diff --git a/.clang-tidy b/.clang-tidy index 868a22c259602..cbe5b04fda95e 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -140,7 +140,7 @@ clang-analyzer-optin.portability.UnixAPI, clang-analyzer-security.insecureAPI.vfork, -clang-analyzer-unix.API, -clang-analyzer-unix.DynamicMemoryModeling, --clang-analyzer-unix.Malloc, +clang-analyzer-unix.Malloc, -clang-analyzer-unix.MallocSizeof, -clang-analyzer-unix.MismatchedDeallocator, clang-analyzer-unix.Vfork, @@ -158,7 +158,7 @@ cppcoreguidelines-explicit-virtual-functions, cppcoreguidelines-init-variables, cppcoreguidelines-narrowing-conversions, cppcoreguidelines-no-malloc, --cppcoreguidelines-pro-type-const-cast, +cppcoreguidelines-pro-type-const-cast, -cppcoreguidelines-pro-type-member-init, -cppcoreguidelines-slicing, -hicpp-avoid-goto, diff --git a/.flake8 b/.flake8 index 91137a006d088..5187a0cdefe03 100644 --- a/.flake8 +++ b/.flake8 @@ -28,9 +28,3 @@ per-file-ignores = # Ignore compare with True in sot unittest test/sot/test_dup_top.py:E712 - - # temp ignore base directory - python/paddle/base/*: - E712, - E266, - E714 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d8ff330bc18b..9b9d5d49c28f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -160,3 +160,13 @@ repos: hooks: - id: cmakelint args: [--config=./tools/codestyle/.cmakelintrc] +# Others +- repo: local + hooks: + - id: sort-txt-file + name: sort-txt-file + description: Sorts each line string in a text file + entry: python ./tools/codestyle/sort_txt_file.py + language: python + files: test/white_list/pir_op_test_white_list + args: [] diff --git a/cmake/cinn.cmake b/cmake/cinn.cmake index e4e3b86936885..ca828dc48ae1b 100644 --- a/cmake/cinn.cmake +++ b/cmake/cinn.cmake @@ -164,8 +164,8 @@ cinn_cc_library( add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) if(NOT CINN_ONLY) - target_link_libraries(cinnapi pd_op_dialect phi) - add_dependencies(cinnapi pd_op_dialect phi) + target_link_libraries(cinnapi op_dialect_vjp phi) + add_dependencies(cinnapi op_dialect_vjp phi) endif() target_link_libraries(cinnapi ${PYTHON_LIBRARIES}) @@ -222,8 +222,8 @@ function(gen_cinncore LINKTYPE) add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) if(NOT CINN_ONLY) - target_link_libraries(${CINNCORE_TARGET} pd_op_dialect phi) - add_dependencies(${CINNCORE_TARGET} pd_op_dialect phi) + target_link_libraries(${CINNCORE_TARGET} op_dialect_vjp phi) + add_dependencies(${CINNCORE_TARGET} op_dialect_vjp phi) endif() add_dependencies(${CINNCORE_TARGET} pybind) diff --git a/cmake/cinn/external/absl.cmake b/cmake/cinn/external/absl.cmake index 56befafecea21..0b3f3d685ed80 100644 --- a/cmake/cinn/external/absl.cmake +++ b/cmake/cinn/external/absl.cmake @@ -5,7 +5,7 @@ set(ABSL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/absl) set(ABSL_PREFIX_DIR ${THIRD_PARTY_PATH}/absl) set(ABSL_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) -set(ABSL_REPOSITORY "https://github.com/abseil/abseil-cpp.git") +set(ABSL_REPOSITORY "${GIT_URL}/abseil/abseil-cpp.git") set(ABSL_TAG "20210324.2") set(OPTIONAL_ARGS diff --git a/cmake/cinn/external/jitify.cmake b/cmake/cinn/external/jitify.cmake index 7750934d8056c..8e478a00176b0 100644 --- a/cmake/cinn/external/jitify.cmake +++ b/cmake/cinn/external/jitify.cmake @@ -7,7 +7,7 @@ include(ExternalProject) # clone jitify to Paddle/third_party set(JITIFY_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/jitify) -set(JITIFY_URL https://github.com/NVIDIA/jitify.git) +set(JITIFY_URL ${GIT_URL}/NVIDIA/jitify.git) set(JITIFY_TAG 57de649139c866eb83acacfe50c92ad7c6278776) ExternalProject_Add( diff --git a/cmake/external/brpc.cmake b/cmake/external/brpc.cmake index c1c514def7619..ad414418caefe 100755 --- a/cmake/external/brpc.cmake +++ b/cmake/external/brpc.cmake @@ -40,7 +40,7 @@ include_directories(${BRPC_INCLUDE_DIR}) # clone brpc to Paddle/third_party set(BRPC_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/brpc) -set(BRPC_URL https://github.com/apache/brpc.git) +set(BRPC_URL ${GIT_URL}/apache/brpc.git) set(BRPC_TAG 1.4.0) # Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args diff --git a/cmake/external/cudnn-frontend.cmake b/cmake/external/cudnn-frontend.cmake index 16c21c8dbf26d..37625f88d9ded 100644 --- a/cmake/external/cudnn-frontend.cmake +++ b/cmake/external/cudnn-frontend.cmake @@ -34,7 +34,7 @@ if((NOT DEFINED CUDNN_FRONTEND_NAME) OR (NOT DEFINED CUDNN_FRONTEND_URL)) "cudnn-frontend" CACHE STRING "" FORCE) set(CUDNN_FRONTEND_URL - "https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/${CUDNN_FRONTEND_VER}.tar.gz" + "${GIT_URL}/NVIDIA/cudnn-frontend/archive/refs/tags/${CUDNN_FRONTEND_VER}.tar.gz" CACHE STRING "" FORCE) endif() set(CUDNN_FRONTEND_CACHE_FILENAME "${CUDNN_FRONTEND_VER}.tar.gz") diff --git a/cmake/external/dirent.cmake b/cmake/external/dirent.cmake index 9c212a237f0a4..7bec37d5f1b7e 100644 --- a/cmake/external/dirent.cmake +++ b/cmake/external/dirent.cmake @@ -25,7 +25,7 @@ if((NOT DEFINED DIRENT_NAME) OR (NOT DEFINED DIRENT_URL)) "dirent" CACHE STRING "" FORCE) set(DIRENT_URL - "https://github.com/tronkko/dirent/archive/refs/tags/1.23.2.tar.gz" + "${GIT_URL}/tronkko/dirent/archive/refs/tags/1.23.2.tar.gz" CACHE STRING "" FORCE) set(DIRENT_CACHE_FILENAME "1.23.2.tar.gz") endif() diff --git a/cmake/external/jemalloc.cmake b/cmake/external/jemalloc.cmake index 1fc2d508fb735..183c9369a2b2c 100644 --- a/cmake/external/jemalloc.cmake +++ b/cmake/external/jemalloc.cmake @@ -5,9 +5,7 @@ set(JEMALLOC_DOWNLOAD_DIR set(JEMALLOC_PROJECT "extern_jemalloc") set(JEMALLOC_BUILD ${THIRD_PARTY_PATH}/jemalloc/src/extern_jemalloc) set(JEMALLOC_PREFIX_DIR ${THIRD_PARTY_PATH}/jemalloc) -set(JEMALLOC_URL - https://github.com/jemalloc/jemalloc/releases/download/5.1.0/jemalloc-5.1.0.tar.bz2 -) +set(JEMALLOC_URL https://paddle-ci.gz.bcebos.com/jemalloc-5.1.0.tar.bz2) set(JEMALLOC_INSTALL ${THIRD_PARTY_PATH}/install/jemalloc) set(JEMALLOC_INCLUDE_DIR ${JEMALLOC_INSTALL}/include) diff --git a/cmake/external/libxsmm.cmake b/cmake/external/libxsmm.cmake index cbe951211b5a1..0f06fe0952968 100644 --- a/cmake/external/libxsmm.cmake +++ b/cmake/external/libxsmm.cmake @@ -31,9 +31,8 @@ set(LIBXSMMNOBLAS_LIB "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a") file(GLOB LIBXSMM_SOURCE_FILE_LIST ${LIBXSMM_SOURCE_DIR}) list(LENGTH LIBXSMM_SOURCE_FILE_LIST RES_LEN) if(RES_LEN EQUAL 0) - execute_process( - COMMAND ${GIT_EXECUTABLE} clone -b ${LIBXSMM_TAG} - "https://github.com/hfp/libxsmm.git" ${LIBXSMM_SOURCE_DIR}) + execute_process(COMMAND ${GIT_EXECUTABLE} clone -b ${LIBXSMM_TAG} + "${GIT_URL}/hfp/libxsmm.git" ${LIBXSMM_SOURCE_DIR}) else() # check git tag execute_process( diff --git a/cmake/external/onnxruntime.cmake b/cmake/external/onnxruntime.cmake index 57969e8c76c8e..1a2f7662fea24 100644 --- a/cmake/external/onnxruntime.cmake +++ b/cmake/external/onnxruntime.cmake @@ -44,19 +44,19 @@ set(ONNXRUNTIME_DOWNLOAD_DIR if(WIN32) set(ONNXRUNTIME_URL - "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-win-x64-${ONNXRUNTIME_VERSION}.zip" + "${GIT_URL}/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-win-x64-${ONNXRUNTIME_VERSION}.zip" ) set(ONNXRUNTIME_URL_MD5 f21d6bd1feef15935a5f4e1007797593) set(ONNXRUNTIME_CACHE_EXTENSION "zip") elseif(APPLE) set(ONNXRUNTIME_URL - "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-osx-x86_64-${ONNXRUNTIME_VERSION}.tgz" + "${GIT_URL}/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-osx-x86_64-${ONNXRUNTIME_VERSION}.tgz" ) set(ONNXRUNTIME_URL_MD5 6a6f6b7df97587da59976042f475d3f4) set(ONNXRUNTIME_CACHE_EXTENSION "tgz") else() set(ONNXRUNTIME_URL - "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz" + "${GIT_URL}/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz" ) set(ONNXRUNTIME_URL_MD5 ce3f2376854b3da4b483d6989666995a) set(ONNXRUNTIME_CACHE_EXTENSION "tgz") diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index f2ef9fd845434..5c9112a4d4e89 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -46,9 +46,8 @@ endif() file(GLOB CBLAS_SOURCE_FILE_LIST ${CBLAS_SOURCE_DIR}) list(LENGTH CBLAS_SOURCE_FILE_LIST RES_LEN) if(RES_LEN EQUAL 0) - execute_process( - COMMAND ${GIT_EXECUTABLE} clone -b ${CBLAS_TAG} - "https://github.com/xianyi/OpenBLAS.git" ${CBLAS_SOURCE_DIR}) + execute_process(COMMAND ${GIT_EXECUTABLE} clone -b ${CBLAS_TAG} + "${GIT_URL}/xianyi/OpenBLAS.git" ${CBLAS_SOURCE_DIR}) else() # check git tag execute_process( diff --git a/cmake/external/paddle2onnx.cmake b/cmake/external/paddle2onnx.cmake index 0a80c87e8e5fa..decb6c9168274 100644 --- a/cmake/external/paddle2onnx.cmake +++ b/cmake/external/paddle2onnx.cmake @@ -71,19 +71,19 @@ endif() if(WIN32) set(PADDLE2ONNX_URL - "https://github.com/PaddlePaddle/Paddle2ONNX/releases/download/v${PADDLE2ONNX_VERSION}/paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip" + "${GIT_URL}/PaddlePaddle/Paddle2ONNX/releases/download/v${PADDLE2ONNX_VERSION}/paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip" ) set(PADDLE2ONNX_URL_MD5 "122b864cb57338191a7e9ef5f607c4ba") set(PADDLE2ONNX_CACHE_EXTENSION "zip") elseif(APPLE) set(PADDLE2ONNX_URL - "https://github.com/PaddlePaddle/Paddle2ONNX/releases/download/v${PADDLE2ONNX_VERSION}/paddle2onnx-osx-x86_64-${PADDLE2ONNX_VERSION}.tgz" + "${GIT_URL}/PaddlePaddle/Paddle2ONNX/releases/download/v${PADDLE2ONNX_VERSION}/paddle2onnx-osx-x86_64-${PADDLE2ONNX_VERSION}.tgz" ) set(PADDLE2ONNX_URL_MD5 "32a4381ff8441b69d58ef0fd6fd919eb") set(PADDLE2ONNX_CACHE_EXTENSION "tgz") else() set(PADDLE2ONNX_URL - "https://github.com/PaddlePaddle/Paddle2ONNX/releases/download/v${PADDLE2ONNX_VERSION}/paddle2onnx-linux-x64-${PADDLE2ONNX_VERSION}.tgz" + "${GIT_URL}/PaddlePaddle/Paddle2ONNX/releases/download/v${PADDLE2ONNX_VERSION}/paddle2onnx-linux-x64-${PADDLE2ONNX_VERSION}.tgz" ) set(PADDLE2ONNX_URL_MD5 "3fbb074987ba241327797f76514e937f") set(PADDLE2ONNX_CACHE_EXTENSION "tgz") diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 85836a33f8c08..0dc93d47ec92b 100755 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -244,7 +244,7 @@ function(build_protobuf TARGET_NAME BUILD_FOR_HOST) set(PROTOBUF_TAG 01a05a53f40ca2ac5f0af10c6cc0810bee39b792) else() if(WITH_PSLIB) - set(PROTOBUF_REPOSITORY "https://github.com/google/protobuf.git") + set(PROTOBUF_REPOSITORY "${GIT_URL}/google/protobuf.git") set(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546") else() set(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 447c744da39c3..34d31d299eb89 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so") set(XPU_XPTI_LIB_NAME "libxpti.so") if(NOT DEFINED XPU_BASE_DATE) - set(XPU_BASE_DATE "20230926") + set(XPU_BASE_DATE "20231103") endif() set(XPU_XCCL_BASE_VERSION "1.0.53.6") if(NOT DEFINED XPU_XFT_BASE_VERSION) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 9f4ffd23a57e1..92aaa69cb46f6 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -499,12 +499,15 @@ function(cc_test_run TARGET_NAME) NAME ${TARGET_NAME} COMMAND ${cc_test_COMMAND} ${cc_test_ARGS} WORKING_DIRECTORY ${cc_test_DIR}) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT - FLAGS_cpu_deterministic=true) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT - FLAGS_init_allocated_mem=true) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT - FLAGS_cudnn_deterministic=true) + set_property( + TEST ${TARGET_NAME} + PROPERTY + ENVIRONMENT + FLAGS_cpu_deterministic=true + FLAGS_init_allocated_mem=true + FLAGS_cudnn_deterministic=true + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${PADDLE_BINARY_DIR}/python/paddle/libs:${PADDLE_BINARY_DIR}/python/paddle/base + ) # No unit test should exceed 2 minutes. if(WIN32) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150) @@ -726,6 +729,7 @@ function(nv_test TARGET_NAME) # 2. cuda_add_executable does not support ccache. # Reference: https://cmake.org/cmake/help/v3.10/module/FindCUDA.html add_executable(${TARGET_NAME} ${nv_test_SRCS}) + target_compile_definitions(${TARGET_NAME} PUBLIC STATIC_PADDLE) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} ${os_dependency_modules} paddle_gtest_main phi) diff --git a/cmake/hip.cmake b/cmake/hip.cmake index d3972e577a800..4f005e95bb98a 100644 --- a/cmake/hip.cmake +++ b/cmake/hip.cmake @@ -118,6 +118,11 @@ list(APPEND HIP_CXX_FLAGS -Wno-unused-value) list(APPEND HIP_CXX_FLAGS -Wno-braced-scalar-init) list(APPEND HIP_CXX_FLAGS -Wno-return-type) list(APPEND HIP_CXX_FLAGS -Wno-pragma-once-outside-header) +list(APPEND HIP_CXX_FLAGS -Wno-deprecated-builtins) +list(APPEND HIP_CXX_FLAGS -Wno-switch) +list(APPEND HIP_CXX_FLAGS -Wno-literal-conversion) +list(APPEND HIP_CXX_FLAGS -Wno-constant-conversion) +list(APPEND HIP_CXX_FLAGS -Wno-defaulted-function-deleted) if(WITH_CINN) list(APPEND HIP_CXX_FLAGS -std=c++14) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index a0f5d2c82eeb8..61813e3f5e2ff 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -684,6 +684,9 @@ function(prune_pybind_h) list(APPEND op_list "load_combine") list(APPEND op_list "tensorrt_engine") + # TODO(ming1753): conditional_block_infer is temporarily reserved here to avoid link errors in functions of standalone_executor + list(APPEND op_list "conditional_block_infer") + # add fused_op in op_list list(APPEND op_list "fc") list(APPEND op_list "conv2d_fusion") diff --git a/paddle/cinn/README.md b/paddle/cinn/README.md new file mode 100644 index 0000000000000..204feab7f2798 --- /dev/null +++ b/paddle/cinn/README.md @@ -0,0 +1,121 @@ +``` + ___ ___ ___ + /\__\ /\ \ /\ \ + /:/ / ___ \:\ \ \:\ \ + /:/ / /\__\ \:\ \ \:\ \ + /:/ / ___ /:/__/ _____\:\ \ _____\:\ \ + /:/__/ /\__\/::\ \ /::::::::\__\/::::::::\__\ + \:\ \ /:/ /\/\:\ \__\:\~~\~~\/__/\:\~~\~~\/__/ + \:\ /:/ / \:\/\__\\:\ \ \:\ \ + \:\/:/ / \::/ / \:\ \ \:\ \ + \::/ / /:/ / \:\__\ \:\__\ + \/__/ \/__/ \/__/ \/__/ + +``` + + +# CINN : Compiler Infrastructure for Neural Networks + +The project CINN is a machine learning compiler and executor for multiple hardware backends. +It is designed to provide multiple layers of APIs to make tensor computation easier to define, faster to execute, and more convenient to extend with hardware backends. +Currently, it targets x86 CPUs and Nvidia GPUs. + +This project is under active development. + +## How it works + +The CINN lowers a traditional DNN model into a two-level intermediate representation(IR), the high-level IR(HLIR) and CINN IR. + +The HLIR helps to define some domain-specific computation and perform some overall optimization on the IR-graph; +the CINN IR helps to represent some computation semantic and finally lower to a hardware backend. + +Both levels of IR have the similar SSA graph, analysis and optimization facilities. +The schedule transform is applied on the CINN IR to do optimizations. + +For more details, you can refer to: +https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/cinn + +## Getting Started + +### Compile + +Clone PaddlePaddle first. + +``` +git clone https://github.com/PaddlePaddle/Paddle.git +cd Paddle +mkdir build +cd build +``` + +Build paddle with cinn: + +``` +cmake .. -DCINN_ONLY=OFF -DWITH_CINN=ON -DWITH_GPU=ON +``` + +Build cinn only: + +``` +cmake .. -DCINN_ONLY=ON -DWITH_CINN=ON -DWITH_GPU=ON +``` + +And then + +``` +make -j +``` + +### Install + +Install paddle with cinn: + +``` +pip install python/dist/paddlepaddle_gpu-xxx.whl +``` + +Install cinn only: + +``` +pip install python/dist/cinn_gpu-xxx.whl +``` + +Then you can import paddle in the python environment and check if a paddle version with CINN is installed. + +``` +import paddle +paddle.is_compiled_with_cinn() +``` + +### Concepts + +There are two levels of APIs in CINN, the higher level is HLIR and the lower level is CINN IR, both contain some concepts. + +In HLIR + +- `frontend::Program`, the program helps to define a machine learning computation, +- `hlir::framework::Tensor`, multi-dimensional arrays helps to manage a memory buffer. +- `hlir::framework::Program`, the final executable program in runtime. It holds many basic executable elements. +- `hlir::framework::Graph`, the graph that represents the structure of a model. Each node in the graph represents an operator (conv2d, relu, mul, etc.). +- `hlir::framework::GraphCompiler`, the compiler that transforms the graph representation(hlir::framework::Graph) of a model into an executable program(hlir::framework::Program). + +In CINN IR + +- `Compute`, the method to define a computation, +- `Lower`, the method to lower a computation to the corresponding IR, +- `LoweredFunc`, the function defined in CINN IR, +- `Var`, a scalar variable, +- `Expr`, an expression represents any CINN IR node(no specified Statement node), + +## License + +CINN is licensed under the [Apache 2.0 license](LICENSE). + +## Acknowledgement + +CINN learned a lot from the following projects: + +- [Halide](https://github.com/halide/Halide): Referenced the design of most IR nodes, +- [TVM](https://github.com/apache/tvm): We learned many ideas including the semantics of some schedule primitives, TOPI, NNVM, and so on, +- [tiramisu](https://github.com/Tiramisu-Compiler): The isl usage, polyhedral compilation, schedule primitive implementation, and so on, +- [tensorflow/xla](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla): Referenced the semantics of the primitive operations. diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc index ef408b7b7778a..f7fffa0e0ff4b 100644 --- a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc +++ b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc @@ -20,6 +20,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/common/context.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" @@ -49,9 +50,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { ir::Tensor B = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - poly::StageMap stages = poly::CreateStages({A, B}); - std::vector funcs = lang::LowerVec( - "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {A, B}, &tensor_group, target); ASSERT_FALSE(funcs.empty()); ir::Expr ast_expr = funcs[0]->body; @@ -115,9 +116,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { ir::Tensor C = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = poly::CreateStages({C}); - std::vector funcs = lang::LowerVec( - "AddDiffShape", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("AddDiffShape", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before MultiLevelTiling: "; @@ -169,9 +170,9 @@ TEST(AnalyzeIr, ContainsNodeType) { ir::Tensor B = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - poly::StageMap stages = poly::CreateStages({A, B}); - std::vector funcs = lang::LowerVec( - "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {A, B}, &tensor_group, target); ASSERT_FALSE(funcs.empty()); ir::Expr ast_expr = funcs[0]->body; diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc index 9364374156f4a..3b51eac2600e3 100644 --- a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc +++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/common/context.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" @@ -48,9 +49,9 @@ TEST(FeatureExtractor, SimpleAssign) { ir::Tensor B = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - poly::StageMap stages = poly::CreateStages({A, B}); - std::vector funcs = lang::LowerVec( - "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {A, B}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr to test: " << ast_expr; @@ -109,9 +110,9 @@ TEST(FeatureExtractor, MatrixMultiply) { [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); - poly::StageMap stages = poly::CreateStages({C}); - std::vector funcs = lang::LowerVec( - "MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {C}, &tensor_group, target); std::vector vec_ast{funcs[0]->body}; ir::ModuleExpr mod_expr(vec_ast); diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc index 5d6d1be6e0c13..5db6f8999b18a 100644 --- a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc +++ b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc @@ -20,6 +20,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/search_state.h" #include "paddle/cinn/auto_schedule/task/task_registry.h" #include "paddle/cinn/cinn.h" @@ -47,8 +48,8 @@ std::vector LowerCompute(const std::vector& shape, C = Compute( domain, [&B](Var i, Var j) { return B(i, j); }, "C"); - return cinn::lang::LowerVec( - "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + return cinn::lang::LowerToAstVec("test_func", {A, B}, &tensor_group, target); } // Create a new IRSchedule with copied ir::LoweredFunc AST diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc index 92c3a542d135c..e59ba8b423293 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc @@ -18,6 +18,7 @@ #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule_block_graph.h" #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" @@ -94,6 +95,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, auto all_loops = ir_schedule->GetLoops(block_name); CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops"; + CHECK_GE(num_loops_to_bind, 0) + << "The number of loops to be bind should be greater than 0"; // check whether it is the case that threadIdx has been binded but blockIdx // not, the threadIdx can only be binded in the first loop after // num_loops_to_bind loops because we has excluded other cases in @@ -101,6 +104,17 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, bool gpu_thread_has_binded = num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As()->is_gpu_thread_binded(); + ir::BlockOrderConstructor block_order_constructor; + std::map, ir::Expr> blocks_order_with_ctrl_stmt = + block_order_constructor(&all_loops[num_loops_to_bind - 1]); + for (auto& pair : blocks_order_with_ctrl_stmt) { + if (pair.first.size() == 2) { + ir::Expr stmt = pair.second; + if (stmt.As() && stmt.As()->is_gpu_thread_binded()) { + gpu_thread_has_binded = true; + } + } + } Expr fused_loop = ir_schedule->Fuse( {all_loops.begin(), all_loops.begin() + num_loops_to_bind}); int32_t extent = fused_loop.As()->extent.as_int32(); @@ -181,5 +195,18 @@ std::vector AutoBind::ApplyOnBlock(SearchState state, return {new_state}; } +void AutoBind::Apply(ir::IRSchedule* ir_schedule, + const std::string& block_name) { + int num_loop_can_bind = + CountLoopCanBinded(ir_schedule->GetLoops(block_name)[0].As()); + if (num_loop_can_bind > 0) { + BindGPUIndex(ir_schedule, + block_name, + num_loop_can_bind, + kMaxBlocks, + target_->max_num_threads()); + } +} + } // namespace auto_schedule } // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h index a45bd31d4b33f..c4baf8e7797e3 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h @@ -42,6 +42,8 @@ class AutoBind : public AutoGenRule { std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + void Apply(ir::IRSchedule* ir_schedule, const std::string& block_name); + private: std::vector applicable_schedule_blocks_; }; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc index e8cab5dd63fa2..57e13c00a1c76 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc @@ -28,6 +28,7 @@ #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" @@ -49,6 +50,11 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); // Check the schedule block to be inlined is not a reduce tensor. + for (const ir::Var& iter_var : sche_block->iter_vars) { + if (iter_var->is_reduce_axis) { + return false; + } + } std::set find_store = ir::ir_utils::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }); if (find_store.size() != 1UL) { @@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, return false; } + // the xxx_reduce_init block cannot be inlined. + if (ir::IsReduceInitTensorName(tensor->name)) { + return false; + } + + // Skip external calls + std::vector consumers = + ir::GetConsumers(sche_block_realize_expr, root); + for (const ir::Expr& consumer : consumers) { + std::set find_load = ir::ir_utils::CollectIRNodesWithoutTensor( + consumer.As() + ->schedule_block.As() + ->body, + [&](const ir::Expr* x) { + return x->As() && + x->As()->tensor.as_tensor_ref()->name == + tensor->name; + }); + if (find_load.empty()) { + return false; + } + } + // write_buffers.size() = 1 and read_buffers is empty, means const // we can inline to consumer if (sche_block->read_buffers.empty()) { diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h index 0ef60a01a9b0f..9a0fc3e823361 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h @@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule { std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; - private: void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT private: diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc index 4cfef12e030e0..e69d3069f1939 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/cinn.h" @@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) { ir::Tensor C = Compute( {M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C"); - poly::StageMap stages = CreateStages({A, B, C}); + ast_gen_ius::TensorGroup tensor_group({A, B, C}); std::vector funcs = - lang::LowerVec("TestAutoInline_SingleLoopInline", - stages, - {A, C}, - {}, - {}, - nullptr, - target, - true); + lang::LowerToAstVec("TestAutoInline_SingleLoopInline", + + {A, C}, + &tensor_Group, + target); VLOG(6) << "Expr after lowering:"; VLOG(6) << funcs[0]->body; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc index 8b08d2c0658b3..e4b0597cfeed7 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/lang/lower.h" @@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) { #else Target target = common::DefaultHostTarget(); #endif - auto stages = CreateStages({C}); - auto funcs = cinn::lang::LowerVec( - "test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + auto funcs = + cinn::lang::LowerToAstVec("test_init", {A, B, C}, &tensor_group, target); auto ast_expr = funcs[0]->body; ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr})); diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc index 5a5c68537e9a7..62f1bb74f4ac0 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/cinn.h" @@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = - lang::LowerVec("TestMultiLevelTile_SimpleLoops", - stages, - {C}, - {}, - {}, - nullptr, - target, - true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = lang::LowerToAstVec( + "TestMultiLevelTile_SimpleLoops", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before MultiLevelTiling: "; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc index c44d067610123..c8b8fdeb0f554 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc @@ -161,15 +161,16 @@ void ReductionFactoring::Apply(const std::string& block_name, // 5. Split the reduction loop into 2 part VLOG(6) << "before Split: " << ir_schedule->GetModule().GetExprs()[0]; int factor = 1; + int max_factor = 1024; int extent = ir::GetLoopExtent(fused_reduce_loop); - for (int i = ceil(sqrt(extent)); i >= 1; --i) { + for (int i = max_factor; i >= 1; --i) { if (extent % i == 0) { factor = i; break; } } std::vector splited_reduction_loops = - ir_schedule->Split(fused_reduce_loop, {-1, factor}); + ir_schedule->Split(fused_reduce_loop, {factor, -1}); // 6. Apply FactorizeReduction VLOG(6) << "before FactorizeReduction: " << ir_schedule->GetModule().GetExprs()[0]; @@ -177,6 +178,25 @@ void ReductionFactoring::Apply(const std::string& block_name, num_spatial_loops); VLOG(6) << "after FactorizeReduction: " << ir_schedule->GetModule().GetExprs()[0]; + + // 7. Loop fusion and cross thread reduction + std::vector rb_loops = ir_schedule->GetLoops(block_name); + ir::Expr rf_block = ir_schedule->GetBlock(block_name + "_rf"); + ir_schedule->SimpleComputeAt(rf_block, rb_loops.back()); + + rb_loops = ir_schedule->GetLoops(block_name); + ir::Expr rf_init_block = + ir_schedule->GetBlock(block_name + "_rf__reduce_init"); + ir_schedule->SimpleComputeAt(rf_init_block, rb_loops.back()); + + if (*target_ == common::DefaultNVGPUTarget()) { + rb_loops = ir_schedule->GetLoops(block_name); + rf_block = ir_schedule->GetBlock(block_name + "_rf"); + ir_schedule->Bind(rb_loops.back(), "threadIdx.x"); + ir_schedule->SetBuffer(rf_block, "shared"); + } + VLOG(6) << "Loop fusion and cross thread reduction: " + << ir_schedule->GetModule().GetExprs()[0]; } } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc index 63e808cfbd4a5..6848fba586944 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc @@ -25,6 +25,8 @@ #include "paddle/cinn/ir/ir_printer.h" #include "test/cpp/cinn/concrete_program_builder.h" +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace auto_schedule { @@ -37,7 +39,9 @@ class TestReductionFactoring : public TestAutoGenRuleBase { const std::vector& reduce_dim, const std::string& block_name, const std::string& expected_ir) { - Initialize(common::DefaultHostTarget()); + Initialize(common::DefaultNVGPUTarget()); + // In order to forcibly use the most basic Compute of reduction + FLAGS_cinn_new_group_scheduler = 1; auto test_program = tests::ReduceBuilder().Build( {{"X", shape}}, {{"reduce_dim", reduce_dim}}); // construct input parameter @@ -66,7 +70,8 @@ class TestReductionFactoring : public TestAutoGenRuleBase { }; TEST_F(TestReductionFactoring, AnalyseApplyType) { - Initialize(common::DefaultHostTarget()); + Context::Global().ResetNameId(); + Initialize(common::DefaultNVGPUTarget()); auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", {4, 5}}, {"Y", {4, 5}}}); ir::IRSchedule ir_schedule = MakeIRSchedule(test_program); @@ -77,43 +82,44 @@ TEST_F(TestReductionFactoring, AnalyseApplyType) { RuleApplyType::kCannotApply); } +#ifdef CINN_WITH_CUDA + TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) { + Context::Global().ResetNameId(); std::string expected_ir = R"({ ScheduleBlock(root) { { serial for (i, 0, 32) { - serial for (reduce_k_0_0, 0, 8) + ScheduleBlock(var_0__reduce_init) + { + i0_0 = axis.bind(i) + var_0__reduce_init[i0_0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_k_0_0, 0, 64) { ScheduleBlock(var_0_rf__reduce_init) { vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i) var_0_rf__reduce_init[i0_0, vreduce_k_0_0] = 0.00000000f } - serial for (reduce_k_0_1, 0, 8) { - ScheduleBlock(var_0_rf) + serial for (reduce_k_0_1, 0, 1) { - vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1) - var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, ((8 * vreduce_k_0_0) + vreduce_k_0_1)]) + ScheduleBlock(var_0_rf) + { + vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1) + var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, (vreduce_k_0_0 + vreduce_k_0_1)]) + } + } + { + ScheduleBlock(var_0) + { + vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i) + var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0]) + } } - } - } - } - serial for (i, 0, 32) - { - ScheduleBlock(var_0__reduce_init) - { - i0_0 = axis.bind(i) - var_0__reduce_init[i0_0] = 0.00000000f - } - serial for (reduce_k_0_0, 0, 8) - { - ScheduleBlock(var_0) - { - vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i) - var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0]) } } } @@ -124,42 +130,41 @@ TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) { } TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) { + Context::Global().ResetNameId(); std::string expected_ir = R"({ ScheduleBlock(root) { { serial for (i, 0, 32) { - serial for (reduce_k_0_reduce_k_1_fused, 0, 128) + ScheduleBlock(var_0__reduce_init) + { + i0_0 = axis.bind(i) + var_0__reduce_init[i0_0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_fused, 0, 1024) { ScheduleBlock(var_0_rf__reduce_init) { vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i) var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_fused] = 0.00000000f } - serial for (reduce_k_0_reduce_k_1_fused_0, 0, 64) { - ScheduleBlock(var_0_rf) + serial for (reduce_k_0_reduce_k_1_fused_0, 0, 8) { - vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0) - var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((64 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((64 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)]) + ScheduleBlock(var_0_rf) + { + vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0) + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)]) + } + } + { + ScheduleBlock(var_0) + { + vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i) + var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused]) + } } - } - } - } - serial for (i, 0, 32) - { - ScheduleBlock(var_0__reduce_init) - { - i0_0 = axis.bind(i) - var_0__reduce_init[i0_0] = 0.00000000f - } - serial for (reduce_k_0_reduce_k_1_fused, 0, 128) - { - ScheduleBlock(var_0) - { - vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i) - var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused]) } } } @@ -170,42 +175,41 @@ TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) { } TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) { + Context::Global().ResetNameId(); std::string expected_ir = R"({ ScheduleBlock(root) { { serial for (i, 0, 32) { - serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 512) + ScheduleBlock(var_0__reduce_init) + { + i0_0 = axis.bind(i) + var_0__reduce_init[i0_0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 1024) { ScheduleBlock(var_0_rf__reduce_init) { vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i) var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = 0.00000000f } - serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 512) { - ScheduleBlock(var_0_rf) + serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 256) { - vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0) - var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)]) + ScheduleBlock(var_0_rf) + { + vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0) + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)]) + } + } + { + ScheduleBlock(var_0) + { + vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i) + var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused]) + } } - } - } - } - serial for (i, 0, 32) - { - ScheduleBlock(var_0__reduce_init) - { - i0_0 = axis.bind(i) - var_0__reduce_init[i0_0] = 0.00000000f - } - serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 512) - { - ScheduleBlock(var_0) - { - vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i) - var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused]) } } } @@ -214,6 +218,7 @@ TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) { })"; TestApplyOnReduce({32, 64, 64, 64}, {1, 2, 3}, "var_0", expected_ir); } +#endif } // namespace auto_schedule } // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc index 52f38e0b65b03..5ba15a46fef18 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/ir.h" @@ -52,9 +53,9 @@ TEST(SkipRule, Basic) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = lang::LowerVec( - "TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before SkipRule: "; @@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = lang::LowerVec( - "TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before SkipRule: "; diff --git a/paddle/cinn/auto_schedule/search_space/search_state_test.cc b/paddle/cinn/auto_schedule/search_space/search_state_test.cc index 61547d228302f..b0f216c4895aa 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state_test.cc +++ b/paddle/cinn/auto_schedule/search_space/search_state_test.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/common/context.h" @@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) { ir::Tensor C = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); + ast_gen_ius::TensorGroup const_group_1({A, B}); cinn::common::Context::Global().ResetNameId(); - auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const", - poly::CreateStages({A, B}), - {A, B}, - {}, - {}, - nullptr, - target, - true); - + auto a_plus_const_funcs_1 = + lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_1, target); cinn::common::Context::Global().ResetNameId(); - auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const", - poly::CreateStages({A, B}), - {A, B}, - {}, - {}, - nullptr, - target, - true); - + ast_gen_ius::TensorGroup const_group_2({A, B}); + auto a_plus_const_funcs_2 = + lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_2, target); cinn::common::Context::Global().ResetNameId(); - auto a_plus_b_funcs = lang::LowerVec("A_plus_B", - poly::CreateStages({A, C}), - {A, C}, - {}, - {}, - nullptr, - target, - true); + ast_gen_ius::TensorGroup plus_group({A, C}); + auto a_plus_b_funcs = + lang::LowerToAstVec("A_plus_B", {A, C}, &plus_group, target); std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B) { diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc index 5bb351767e8cb..dcb6e1ca93914 100644 --- a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc +++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc @@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate( // ir_schedule const auto& task_key = tune_task_.serialized_key; InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); - ir::IRSchedule new_ir_sch( + ir::IRSchedule pir_sch( ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr), utils::ForkRandomState(rand_seed)); - new_trace.Replay(&new_ir_sch, true); - ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_); - auto res = SearchState(std::move(new_ir_sch)); + new_trace.Replay(&pir_sch, true); + ApplyPostScheduleRules(&pir_sch, post_schedule_rules_); + auto res = SearchState(std::move(pir_sch)); VLOG(5) << JoinStatesDebugString( "EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6)); diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc index 443c297c5e722..94222d748c054 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" @@ -46,16 +47,13 @@ TEST(MutateTileSize, Basic) { [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); - poly::StageMap stages = CreateStages({A, B, C}); + ast_gen_ius::TensorGroup tensor_group({A, B, C}); std::vector funcs = - lang::LowerVec("TestMutateTileSize_Basic", - stages, - {A, B, C}, - {}, - {}, - nullptr, - target, - true); + lang::LowerToAstVec("TestMutateTileSize_Basic", + + {A, B, C}, + &tensor_group, + target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Original Expr: "; @@ -65,7 +63,7 @@ TEST(MutateTileSize, Basic) { // repeated. utils::LinearRandomEngine::StateType rand_seed = 123; ir::IRSchedule ir_schedule(module_expr, rand_seed); - ir::IRSchedule new_ir_schedule(ir_schedule); + ir::IRSchedule pir_schedule(ir_schedule); // apply schedule auto loops = ir_schedule.GetLoops("C"); @@ -76,13 +74,13 @@ TEST(MutateTileSize, Basic) { MutateTileSize mutator; ir::ScheduleDesc sch_desc = mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed); - sch_desc.Replay(&new_ir_schedule, true); + sch_desc.Replay(&pir_schedule, true); VLOG(6) << "Expr before mutate tile size: \n" << ir_schedule.GetModule().GetExprs()[0]; VLOG(6) << "Expr after mutate tile size: \n" - << new_ir_schedule.GetModule().GetExprs()[0]; + << pir_schedule.GetModule().GetExprs()[0]; - std::string target_new_ir = R"ROC({ + std::string target_pir = R"ROC({ ScheduleBlock(root) { serial for (i_1, 0, 2) @@ -117,7 +115,7 @@ TEST(MutateTileSize, Basic) { ss << exprs[0]; return ss.str(); }; - ASSERT_EQ(get_ir_str(&new_ir_schedule), target_new_ir); + ASSERT_EQ(get_ir_str(&pir_schedule), target_pir); std::vector last_tile_factors = {2, 16}; for (int i = 0; i < 10; ++i) { diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc index 0a64b24712f48..f63869730a11f 100644 --- a/paddle/cinn/backends/compiler.cc +++ b/paddle/cinn/backends/compiler.cc @@ -304,6 +304,8 @@ void Compiler::CompileCudaModule(const Module& module, auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); CHECK(fn_kernel); + fn_ptr_.push_back(reinterpret_cast(fn_kernel)); + symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast(fn_kernel)); } diff --git a/paddle/cinn/backends/compiler.h b/paddle/cinn/backends/compiler.h index a468193d4d85a..f269b00492a42 100644 --- a/paddle/cinn/backends/compiler.h +++ b/paddle/cinn/backends/compiler.h @@ -121,6 +121,8 @@ class Compiler final { */ void* Lookup(absl::string_view fn_name); + std::vector GetFnPtr() const { return fn_ptr_; } + private: void CompileCudaModule(const ir::Module& module, const std::string& code = ""); @@ -136,6 +138,7 @@ class Compiler final { Target target_; std::unique_ptr engine_; + std::vector fn_ptr_; #ifdef CINN_WITH_CUDA std::unique_ptr cuda_module_; #endif diff --git a/paddle/cinn/cinn.h b/paddle/cinn/cinn.h index 333bc051ead98..e81771ba0c7e7 100644 --- a/paddle/cinn/cinn.h +++ b/paddle/cinn/cinn.h @@ -29,6 +29,7 @@ namespace cinn { +using ast_gen_ius::TensorGroup; using backends::CodeGenC; using backends::CodeGenCX86; using backends::Outputs; @@ -39,6 +40,7 @@ using lang::CallExtern; using lang::CallLowered; using lang::Compute; using lang::Lower; +using lang::LowerToAst; using lang::Placeholder; using lang::ReduceAll; using lang::ReduceAny; diff --git a/paddle/cinn/common/context.h b/paddle/cinn/common/context.h index 0e1566cca9932..9ade36f6de6c5 100644 --- a/paddle/cinn/common/context.h +++ b/paddle/cinn/common/context.h @@ -52,6 +52,22 @@ struct NameGenerator { mutable std::mutex mutex_; }; +struct PrettyNamer { + const std::string& GetOrNew(const size_t hash_key, + const std::string& name_hint) { + if (pretty_names_.find(hash_key) == pretty_names_.end()) { + pretty_names_[hash_key] = name_generator_.New(name_hint); + } + return pretty_names_.at(hash_key); + } + + NameGenerator& GetNameGenerator() { return name_generator_; } + + private: + absl::flat_hash_map pretty_names_; + NameGenerator name_generator_; +}; + class Context { public: static Context& Global(); @@ -61,10 +77,15 @@ class Context { * @param name_hint The prefix. */ std::string NewName(const std::string& name_hint) { - return name_generator_.New(name_hint); + return pretty_namer_.GetNameGenerator().New(name_hint); } - void ResetNameId() { name_generator_.ResetID(); } + std::string PrettyUniqName(const size_t hash_key, + const std::string& name_hint) { + return pretty_namer_.GetOrNew(hash_key, name_hint); + } + + void ResetNameId() { pretty_namer_.GetNameGenerator().ResetID(); } const std::vector& runtime_include_dir(); @@ -82,7 +103,7 @@ class Context { private: Context() = default; - NameGenerator name_generator_; + PrettyNamer pretty_namer_; std::vector runtime_include_dir_; mutable std::mutex mutex_; diff --git a/paddle/cinn/common/macros.h b/paddle/cinn/common/macros.h index 1c49109eb97b5..dbae22549331c 100644 --- a/paddle/cinn/common/macros.h +++ b/paddle/cinn/common/macros.h @@ -67,10 +67,10 @@ __test_global_namespace_##uniq_name##__>::value, \ msg) -#define USE_FUSION_PASS(pass_name) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __use_fusion_pass_##pass_name, \ - "USE_OP_ITSELF must be called in global namespace"); \ - extern int TouchFusionPassRegistrar_##pass_name(); \ - [[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = \ - TouchFusionPassRegistrar_##pass_name() +#define USE_FUSION_PASS(pass_name) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __use_cinn_fusion_pass_##pass_name, \ + "USE_FUSION_PASS must be called in global namespace"); \ + extern int TouchCinnFusionPassRegistrar_##pass_name(); \ + [[maybe_unused]] static int __use_cinn_fusion_pass_##pass_name##_ = \ + TouchCinnFusionPassRegistrar_##pass_name() diff --git a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt index 1a5857fd2cfe2..a8bd23352bd5c 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt @@ -1,4 +1,4 @@ -# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could +# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could # not found under CINN_ONLY mode if(NOT CINN_ONLY) set(CINN_DIALECT_BINARY_DIR @@ -35,6 +35,7 @@ if(NOT CINN_ONLY) COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir} COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_parsed_yaml_file} --op_yaml_path ${cinn_op_yaml_file} --output_path ${cinn_op_parsed_yaml_file} + DEPENDS ${cinn_op_gen_parsed_yaml_file} ${cinn_op_yaml_file} VERBATIM) add_custom_command( @@ -61,7 +62,7 @@ if(NOT CINN_ONLY) manual_op.cc op_attribute.cc DEPS - pd_op_dialect) + op_dialect_vjp) target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_BINARY_DIR}) endif() diff --git a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h index 99e12a3d13ab4..9c6959db093e4 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h +++ b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h @@ -18,8 +18,8 @@ #include #include #include -#include "paddle/cinn/hlir/framework/new_ir/utils.h" #include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/pir/core/attribute_base.h" #include "paddle/pir/core/operation.h" @@ -51,7 +51,7 @@ struct GroupInfo { private: void Initialize() { op_pattern_kind = hlir::framework::OpPatternKind::kElementWise; - fn_name = hlir::framework::newir::CompatibleInfo::GroupOpsName(ops); + fn_name = hlir::framework::pir::CompatibleInfo::GroupOpsName(ops); } }; @@ -77,5 +77,27 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage { ParamKey data_; }; +struct JITInfoAttributeStorage : public pir::AttributeStorage { + using ParamKey = cinn::hlir::framework::pir::CUDAJITInfo; + explicit JITInfoAttributeStorage(const ParamKey& key) : data_(key) {} + + static JITInfoAttributeStorage* Construct(const ParamKey& key) { + return new JITInfoAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey& key) { + return std::hash()(*(reinterpret_cast(key.fn_ptr))); + } + + bool operator==(const ParamKey& key) const { + return data_.fn_ptr == key.fn_ptr; + } + + const ParamKey& GetAsKey() const { return data_; } + + private: + ParamKey data_; +}; + } // namespace dialect } // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 3a4ebb63679f3..0ba418dbec811 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -15,16 +15,17 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include +#include "glog/logging.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/enforce.h" #include "paddle/pir/core/op_base.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" namespace cinn { namespace dialect { const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; -// TODO(Aurlius84): Need to figure out how to rebuild relation info of ops outer -// GroupOp void GroupOp::Build(pir::Builder &builder, pir::OperationArgument &argument, const std::vector &output_types) { @@ -32,6 +33,20 @@ void GroupOp::Build(pir::Builder &builder, argument.output_types = output_types; } +void GroupOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + std::unique_ptr &&block) { + VLOG(4) << "Start build GroupOp"; + if (block && !block->empty()) { + IR_ENFORCE(block->back()->isa()); + auto *op = block->back(); + for (size_t i = 0; i < op->num_operands(); ++i) { + argument.AddOutput(op->operand(i).type()); + } + } + argument.AddRegion()->push_back(block.release()); +} + pir::Block *GroupOp::block() { pir::Region ®ion = (*this)->region(0); if (region.empty()) region.emplace_back(); diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 39d433790be78..ba116d52a98c0 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -33,6 +33,10 @@ class GroupOp : public pir::Op { pir::OperationArgument &argument, // NOLINT const std::vector &output_types); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + std::unique_ptr &&block); + pir::Block *block(); std::vector ops(); diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc index 554d7357af970..1899d5f44bee1 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc @@ -19,7 +19,13 @@ namespace dialect { const GroupInfo &GroupInfoAttribute::data() const { return storage()->GetAsKey(); } + +const cinn::hlir::framework::pir::CUDAJITInfo &CUDAJITInfoAttribute::data() + const { + return storage()->GetAsKey(); +} } // namespace dialect } // namespace cinn IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute) diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h index 6e92b45002785..10bd5ebc300a4 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h +++ b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h @@ -33,7 +33,22 @@ class GroupInfoAttribute : public pir::Attribute { const GroupInfo& data() const; }; +class CUDAJITInfoAttribute : public pir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CUDAJITInfoAttribute, + JITInfoAttributeStorage); + + bool operator<(const CUDAJITInfoAttribute& right) const { + return storage() < right.storage(); + } + + const cinn::hlir::framework::pir::CUDAJITInfo& data() const; +}; + } // namespace dialect } // namespace cinn IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute) +IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute) diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc index 6d2f0409f24e9..11ccd77bb109d 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc @@ -39,20 +39,31 @@ void OperatorDialect::initialize() { >(); RegisterOp(); RegisterAttribute(); + RegisterAttribute(); } void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {} void OperatorDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { - os << "(" << attr.dialect().name(); - os << '.'; - if (auto group_info_attr = attr.dyn_cast()) { - const GroupInfo &data = group_info_attr.data(); - os << "GroupInfo)" - << "[" << data.fn_name << "]"; + if (attr.isa()) { + os << "(" << attr.dialect().name(); + os << '.'; + if (auto group_info_attr = attr.dyn_cast()) { + const GroupInfo &data = group_info_attr.data(); + os << "GroupInfo)" + << "[" << data.fn_name << "]"; + } + { os << "<#AttrNotImplemented>"; } + } else if (attr.isa()) { + auto cuda_jit_info = attr.dyn_cast(); + + os << "(" << cuda_jit_info.data().fn_ptr; + os << ')'; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "cinn dialect only support GrupInfo and CUDAJITInfo")); } - { os << "<#AttrNotImplemented>"; } } void OperatorDialect::PrintOperation(pir::Operation *op, diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 9f14c6e406661..27aa274156d96 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -9,7 +9,7 @@ param : [x, broadcast_axes] - op : reduce_max - args : (Tensor x, int64_t[] axis, bool keep_dim) + args : (Tensor x, int64_t[] dim, bool keep_dim) output : Tensor(out) infer_meta : func : ReduceInferMeta @@ -17,7 +17,7 @@ func : frobenius_norm - op : reduce_sum - args : (Tensor x, int64_t[] axis, bool keep_dim) + args : (Tensor x, int64_t[] dim, bool keep_dim) output : Tensor(out) infer_meta : func : ReduceInferMeta diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 770e78d191e3d..20ee7cb1c9baa 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -4,7 +4,28 @@ if(NOT CINN_ONLY) SRCS group_with_group_merge_pass.cc op_with_group_merge_pass.cc + cinn_group_lowering_pass.cc tensor_node.cc DEPS - pd_op_dialect) + op_dialect_vjp + pir_compiler + cinn_runtime_dialect) + + cinn_cc_library( + pd_to_cinn_pass + SRCS + pd_to_cinn_pass.cc + DEPS + drr + cinn_op_dialect + op_dialect_vjp) + + cinn_cc_library( + add_broadcast_to_elementwise_pass + SRCS + add_broadcast_to_elementwise_pass.cc + DEPS + pir + cinn_op_dialect + op_dialect_vjp) endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc new file mode 100644 index 0000000000000..7e8635b951fc8 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/pattern_applicator.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +int64_t GetDimByIndex(const phi::DDim& first, + const phi::DDim& second, + int short_align_axis, + int idx) { + // rank of first less than rank of second + if (idx < short_align_axis) { + return second[idx]; + } else { + return first[idx - short_align_axis] > second[idx] + ? first[idx - short_align_axis] + : second[idx]; + } +} + +std::vector GetOutputShape(const phi::DDim& x, const phi::DDim& y) { + std::vector vec_res; + if (x.size() >= y.size()) { + int short_align_axis = x.size() - y.size(); + int max_rank = x.size(); + vec_res.resize(max_rank); + for (size_t i = 0; i < max_rank; ++i) { + vec_res[i] = GetDimByIndex(y, x, short_align_axis, i); + } + } else { + int short_align_axis = y.size() - x.size(); + int max_rank = y.size(); + + vec_res.resize(max_rank); + for (size_t i = 0; i < max_rank; ++i) { + vec_res[i] = GetDimByIndex(x, y, short_align_axis, max_rank); + } + } + + return vec_res; +} + +bool IsSameDim(const phi::DDim& first, const std::vector& second) { + if (first.size() == second.size()) { + bool same = true; + + for (size_t i = 0; i < first.size(); ++i) { + if (first[i] != second[i]) { + same = false; + break; + } + } + + return same; + } + return false; +} + +bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { + auto x_dims = op->operand_source(0) + .type() + .dyn_cast() + .dims(); + auto y_dims = op->operand_source(1) + .type() + .dyn_cast() + .dims(); + + if (x_dims != y_dims) { + auto output_shape = GetOutputShape(x_dims, y_dims); + std::vector vec_dims; + for (int64_t i = 0; i < output_shape.size(); ++i) { + vec_dims.push_back(i); + } + if (!IsSameDim(x_dims, output_shape)) { + // add broadcast to input 0 + auto new_transpose_op = rewriter->Build( + op->operand_source(0), vec_dims, output_shape); + + op->operand(0).set_source(new_transpose_op->result(0)); + } + + if (!IsSameDim(y_dims, output_shape)) { + auto new_transpose_op = rewriter->Build( + op->operand_source(1), vec_dims, output_shape); + + op->operand(1).set_source(new_transpose_op->result(0)); + } + + return true; + } + + return false; +} + +template +class AddBrodcastToElementwisePattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(OPTYPE op, + pir::PatternRewriter& rewriter) const override { + return ProcessOp(op, &rewriter); + } +}; + +AddBroadcastToElementwisePass::AddBroadcastToElementwisePass() + : pir::Pass("add_broadcast_to_elementwise_pass", 1) {} + +bool AddBroadcastToElementwisePass::Initialize(pir::IrContext* context) { + pir::RewritePatternSet ps(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps)); + return true; +} + +void AddBroadcastToElementwisePass::Run(pir::Operation* op) { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); +} + +bool AddBroadcastToElementwisePass::CanApplyOn(pir::Operation* op) const { + return op->isa() && op->num_regions() > 0; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h new file mode 100644 index 0000000000000..6478479248fd9 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class AddBroadcastToElementwisePass : public pir::Pass { + public: + AddBroadcastToElementwisePass(); + + bool Initialize(pir::IrContext *context) override; + + void Run(pir::Operation *op) override; + + bool CanApplyOn(pir::Operation *op) const override; + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc new file mode 100644 index 0000000000000..ba5c946ff3164 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" + +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace dialect { +namespace ir { + +std::vector GetBlockOutsideInput( + const std::vector op_list) { + std::vector vec_res; + std::unordered_set<::pir::Value> block_inner_output; + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_results(); ++i) { + block_inner_output.insert(op_list[k]->result(i)); + } + } + + std::unordered_set<::pir::Value> insert_value; + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_operands(); ++i) { + if (!block_inner_output.count(op_list[k]->operand_source(i)) && + !insert_value.count(op_list[k]->operand_source(i))) { + vec_res.push_back(op_list[k]->operand_source(i)); + insert_value.insert(op_list[k]->operand_source(i)); + } + } + } + return vec_res; +} + +std::vector GetBlockOutsideOutput( + const std::vector op_list, + const std::vector group_all_list) { + assert(group_all_list.size() >= 2); + assert(group_all_list.back()->isa()); + + auto yeild_op = group_all_list.back()->dyn_cast(); + + std::unordered_set yeild_inputs; + for (size_t i = 0; i < yeild_op.num_operands(); ++i) { + yeild_inputs.insert(yeild_op.operand_source(i)); + } + + std::unordered_set innner_op_set(op_list.begin(), + op_list.end()); + std::unordered_set outside_group_set; + + for (size_t i = 0; i < group_all_list.size(); ++i) { + if (!innner_op_set.count(group_all_list[i])) { + outside_group_set.insert(group_all_list[i]); + } + } + + std::vector vec_res; + + for (auto* op : op_list) { + for (size_t i = 0; i < op->num_results(); ++i) { + if (yeild_inputs.count(op->result(i))) { + vec_res.push_back(op->result(i)); + } else { + for (auto it = op->result(i).use_begin(); it != op->result(i).use_end(); + ++it) { + if (outside_group_set.count(it->owner())) { + vec_res.push_back(op->result(i)); + break; + } + } + } + } + } + return vec_res; +} + +std::vector GetOpListNotIncludeYield( + const std::vector& op_list) { + std::vector vec_res; + for (size_t i = 0; i < op_list.size(); ++i) { + if (!op_list[i]->isa()) { + vec_res.push_back(op_list[i]); + } + } + + return vec_res; +} + +std::unique_ptr CINNGroupLoweringPass(::pir::Program* program) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + std::string jit_op_name = cinn::dialect::JitKernelOp::name(); + ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); + + auto ir_program = std::make_unique<::pir::Program>(ctx); + std::unordered_map value_map; + + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + + for (auto it = program->block()->begin(); it != program->block()->end(); + ++it) { + if ((*it)->isa()) { + // GetOpList and Call cinn CodeGen + auto group_op = (*it)->dyn_cast(); + + // op fusion + auto op_fusion = cinn::dialect::ir::OpFusionPassInternal( + GetOpListNotIncludeYield(group_op.ops())); + + // fusion merge + auto group_list = + cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion); + + PADDLE_ENFORCE_EQ(group_list.size(), + 1u, + phi::errors::Unimplemented( + "Only support one group after group fusion")); + for (auto group : group_list) { + auto ir_compiler = std::make_shared( + *program, target, scope); + hlir::framework::PirCompilerManager::Instance().insert(ir_compiler); + auto group1 = + std::make_shared(group->ops); + auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group1}); + std::unordered_map op_attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, + }; + + // Generate jit kernel op input and output + auto vec_ins = GetBlockOutsideInput(group->ops); + + std::vector vec_new_ins; + for (size_t i = 0; i < vec_ins.size(); ++i) { + vec_new_ins.push_back(value_map.at(vec_ins[i])); + } + + auto vec_outs = GetBlockOutsideOutput(group->ops, group_op.ops()); + + std::vector vec_types; + for (auto& out : vec_outs) { + vec_types.push_back(out.type()); + } + + ::pir::Operation* cinn_op = + ::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info); + + for (size_t i = 0; i < group_op.num_results(); ++i) { + value_map[group_op.result(i)] = cinn_op->result(i); + } + + ir_program->block()->push_back(cinn_op); + } + + } else { + std::vector vec_ins; + + for (size_t i = 0; i < (*it)->num_operands(); ++i) { + vec_ins.push_back(value_map.at((*it)->operand_source(i))); + } + + std::vector vec_types; + for (size_t i = 0; i < (*it)->num_results(); ++i) { + vec_types.push_back((*it)->result(i).type()); + } + + ::pir::OpInfo info1 = ctx->GetRegisteredOpInfo((*it)->name()); + ::pir::Operation* op = ::pir::Operation::Create( + vec_ins, (*it)->attributes(), vec_types, info1); + + ir_program->block()->push_back(op); + for (size_t i = 0; i < (*it)->num_results(); ++i) { + value_map[(*it)->result(i)] = op->result(i); + } + } + } + return ir_program; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h new file mode 100644 index 0000000000000..99d113555a39f --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h @@ -0,0 +1,27 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/program.h" + +namespace cinn { +namespace dialect { +namespace ir { + +std::unique_ptr CINNGroupLoweringPass(::pir::Program* program); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc index e9c165bbcec52..fcf7f2242f09b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc @@ -26,6 +26,8 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" #include "paddle/phi/core/flags.h" +#include "paddle/cinn/common/is_reachable_predicator.h" + PD_DECLARE_bool(enhance_vertical_fusion_with_recompute); namespace cinn { @@ -154,49 +156,45 @@ class GraphGroupFuseHelper final : public FuseHelper { private: bool IsReachableInDag(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { - // const auto& MinDepth4Node = [&](const OpGroupPtr& node) { - // return node.GetGroup()->min_depth; - // }; - // const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { - // return node.GetGroup()->max_depth; - // }; - // const auto& VisitNextNodes = - // [&](const OpGroupPtr& node, - // const std::function& Visit) { - // for (const auto& node_producer : node.producers()) { - // Visit(node_producer); - // } - // }; - // common::IsReachablePredicator is_reachable( - // MinDepth4Node, MaxDepth4Node, VisitNextNodes); - // return is_reachable(consumer, producer, [](OpGroupPtr) {}); - // TODO(phlrain) : support IsReachable - return false; + const auto& MinDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->min_depth; + }; + const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->max_depth; + }; + const auto& VisitNextNodes = + [&](const OpGroupPtr& node, + const std::function& Visit) { + for (const auto& node_producer : node.producers()) { + Visit(node_producer); + } + }; + ::cinn::common::IsReachablePredicator is_reachable( + MinDepth4Node, MaxDepth4Node, VisitNextNodes); + return is_reachable(consumer, producer, [](OpGroupPtr) {}); } bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { - // const auto& MinDepth4Node = [&](const OpGroupPtr& node) { - // return node.GetGroup()->min_depth; - // }; - // const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { - // return node.GetGroup()->max_depth; - // }; - // const auto& VisitNextNodes = - // [&](const OpGroupPtr& node, - // const std::function& Visit) { - // for (const auto& node_producer : node.producers()) { - // if (node == consumer && node_producer == producer) { - // continue; - // } - // Visit(node_producer); - // } - // }; - // common::IsReachablePredicator is_reachable( - // MinDepth4Node, MaxDepth4Node, VisitNextNodes); - // return is_reachable(consumer, producer, [](OpGroupPtr) {}); - // TODO(phlrain) : support IsReachable - return false; + const auto& MinDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->min_depth; + }; + const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->max_depth; + }; + const auto& VisitNextNodes = + [&](const OpGroupPtr& node, + const std::function& Visit) { + for (const auto& node_producer : node.producers()) { + if (node == consumer && node_producer == producer) { + continue; + } + Visit(node_producer); + } + }; + common::IsReachablePredicator is_reachable( + MinDepth4Node, MaxDepth4Node, VisitNextNodes); + return is_reachable(consumer, producer, [](OpGroupPtr) {}); } const FusePassCtxT* ctx_; @@ -390,7 +388,7 @@ bool GraphGroupFuseHelper::ReduceFuseBroadcast( template bool GraphGroupFuseHelper::ReduceFuseReduce( const OpGroupPtr& src, const OpGroupPtr& dst) const { - return reduce_fuse_reduce(src.GetGroup(), dst.GetGroup()); + return ReduceFuseReduce1(src, dst); } template @@ -407,6 +405,7 @@ struct HorizontalFuseUtil { return false; } auto out = iter->second(src, dst); + return out; } @@ -419,25 +418,29 @@ struct HorizontalFuseUtil { static std::map RawConditionMap() { return std::map{ - {{kElementWise, kElementWise}, &IsSameSize}, - {{kElementWise, kBroadcast}, &IsSameSize}, - {{kElementWise, kInjective}, &IsSameSize}, - {{kElementWise, kReduction}, &HorizontalElementwiseFuseReduce}, - - {{kBroadcast, kElementWise}, &IsSameSize}, - {{kBroadcast, kBroadcast}, &IsSameSize}, - {{kBroadcast, kInjective}, &IsSameSize}, - {{kBroadcast, kReduction}, &IsSameSize}, - - {{kInjective, kElementWise}, &IsSameSize}, - {{kInjective, kBroadcast}, &IsSameSize}, - {{kInjective, kInjective}, &IsSameSize}, - {{kInjective, kReduction}, &IsSameSize}, - - {{kReduction, kElementWise}, &HorizontalElementwiseFuseReduce}, - {{kReduction, kBroadcast}, &IsSameSize}, - {{kReduction, kInjective}, &IsSameSize}, - {{kReduction, kReduction}, &ReduceFuseReduce}, + {{OpPatternKind::kElementWise, OpPatternKind::kElementWise}, + &IsSameSize}, + {{OpPatternKind::kElementWise, OpPatternKind::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kElementWise, OpPatternKind::kInjective}, &IsSameSize}, + {{OpPatternKind::kElementWise, OpPatternKind::kReduction}, + &HorizontalElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, OpPatternKind::kElementWise}, &IsSameSize}, + {{OpPatternKind::kBroadcast, OpPatternKind::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kBroadcast, OpPatternKind::kInjective}, &IsSameSize}, + {{OpPatternKind::kBroadcast, OpPatternKind::kReduction}, &IsSameSize}, + + {{OpPatternKind::kInjective, OpPatternKind::kElementWise}, &IsSameSize}, + {{OpPatternKind::kInjective, OpPatternKind::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kInjective, OpPatternKind::kInjective}, &IsSameSize}, + {{OpPatternKind::kInjective, OpPatternKind::kReduction}, &IsSameSize}, + + {{OpPatternKind::kReduction, OpPatternKind::kElementWise}, + &HorizontalElementwiseFuseReduce}, + {{OpPatternKind::kReduction, OpPatternKind::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kReduction, OpPatternKind::kInjective}, &IsSameSize}, + {{OpPatternKind::kReduction, OpPatternKind::kReduction}, + &ReduceFuseReduce}, }; } @@ -455,7 +458,7 @@ struct HorizontalFuseUtil { const OpGroupPtr* ele_group = nullptr; const OpGroupPtr* reduce_group = nullptr; - if (src.kind() == kReduction) { + if (src.kind() == OpPatternKind::kReduction) { ele_group = &dst; reduce_group = &src; } else { @@ -481,7 +484,7 @@ struct HorizontalFuseUtil { static bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { // return ctx->fuse_helper().ReduceFuseReduce(src, dst); - return reduce_fuse_reduce(src.GetGroup(), dst.GetGroup()); + return ReduceFuseReduce1(src, dst); } }; @@ -524,8 +527,10 @@ class DefaultInputFusePass final : public InputFusePass { [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& consumer : consumer_set) { - if (consumer.kind() == kElementWise || consumer.kind() == kBroadcast || - consumer.kind() == kInjective || consumer.kind() == kReduction) { + if (consumer.kind() == OpPatternKind::kElementWise || + consumer.kind() == OpPatternKind::kBroadcast || + consumer.kind() == OpPatternKind::kInjective || + consumer.kind() == OpPatternKind::kReduction) { consumers.insert(consumer); } } @@ -613,8 +618,10 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& consumer : producer.consumers()) { - if (consumer.kind() == kElementWise || consumer.kind() == kBroadcast || - consumer.kind() == kInjective || consumer.kind() == kReduction) { + if (consumer.kind() == OpPatternKind::kElementWise || + consumer.kind() == OpPatternKind::kBroadcast || + consumer.kind() == OpPatternKind::kInjective || + consumer.kind() == OpPatternKind::kReduction) { consumers.insert(consumer); } } @@ -715,7 +722,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { candidates.push_back(consumer); } if (candidates.size() == consumers.size() && - producer.kind() == kElementWise) { + producer.kind() == OpPatternKind::kElementWise) { return; } @@ -756,40 +763,40 @@ class DefaultVerticalFusePass final : public VerticalFusePass { static std::map RawConditionMap() { return std::map{ - {{OpPatternKind::kElementWise, kElementWise}, + {{OpPatternKind::kElementWise, OpPatternKind::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, kBroadcast}, + {{OpPatternKind::kElementWise, OpPatternKind::kBroadcast}, &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, - {{OpPatternKind::kElementWise, kInjective}, + {{OpPatternKind::kElementWise, OpPatternKind::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kElementWise, kReduction}, + {{OpPatternKind::kElementWise, OpPatternKind::kReduction}, &DefaultVerticalFusePass::ElementwiseFuseReduce}, - {{OpPatternKind::kBroadcast, kElementWise}, + {{OpPatternKind::kBroadcast, OpPatternKind::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, kBroadcast}, + {{OpPatternKind::kBroadcast, OpPatternKind::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, kInjective}, + {{OpPatternKind::kBroadcast, OpPatternKind::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kBroadcast, kReduction}, + {{OpPatternKind::kBroadcast, OpPatternKind::kReduction}, &DefaultVerticalFusePass::BroadcastFuseReduce}, - {{OpPatternKind::kInjective, kElementWise}, + {{OpPatternKind::kInjective, OpPatternKind::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, kBroadcast}, + {{OpPatternKind::kInjective, OpPatternKind::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, kInjective}, + {{OpPatternKind::kInjective, OpPatternKind::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kInjective, kReduction}, + {{OpPatternKind::kInjective, OpPatternKind::kReduction}, &DefaultVerticalFusePass::InjectiveHorizontalWithReduce}, - {{OpPatternKind::kReduction, kElementWise}, + {{OpPatternKind::kReduction, OpPatternKind::kElementWise}, &DefaultVerticalFusePass::ReduceFuseElementwise}, - {{OpPatternKind::kReduction, kBroadcast}, + {{OpPatternKind::kReduction, OpPatternKind::kBroadcast}, &DefaultVerticalFusePass::ReduceFuseBroadcast}, - {{OpPatternKind::kReduction, kInjective}, + {{OpPatternKind::kReduction, OpPatternKind::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kReduction, kReduction}, + {{OpPatternKind::kReduction, OpPatternKind::kReduction}, &DefaultVerticalFusePass::ReduceFuseReduce}, }; } @@ -895,7 +902,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && - producer.kind() == kElementWise) { + producer.kind() == OpPatternKind::kElementWise) { for (const auto& consumer : consumers) { ctx->MarkFusible(producer, consumer); } @@ -1023,9 +1030,7 @@ class FusionPassRegistrar final : public Registrar { // code generation. class GeneralFusionMergePassHelper { public: - explicit GeneralFusionMergePassHelper(const ::pir::Program* graph, - const GroupList& group_list) - : graph_(graph) { + explicit GeneralFusionMergePassHelper(const GroupList& group_list) { fusion_groups_ = group_list; // init input to consumers. InitInputToConsumers(); @@ -1078,8 +1083,10 @@ class GeneralFusionMergePassHelper { VLOG(3) << "DoFusionMerge...!"; while (DoGeneralHorizontalFusion()) { } + while (DoGeneralVerticalFusion()) { } + while (DoGeneralRecomputeAndVerticalFusion()) { } } @@ -1355,6 +1362,10 @@ class GeneralFusionMergePassHelper { } else { fused_group->group_id = consumer->group_id; } + + for (auto* op : consumer->ops) { + fused_group->ops.push_back(op); + } // set op pattern kind fused_group->op_pattern_kind = static_cast(fused_group->op_pattern_kind) >= @@ -1362,27 +1373,27 @@ class GeneralFusionMergePassHelper { ? fused_group->op_pattern_kind : consumer->op_pattern_kind; // input nodes - for (auto& node : consumer->input_nodes) { - if (fused_group->input_nodes.count(node.first)) { - fused_group->input_nodes[node.first] += node.second; + for (auto& node : consumer->input_ops) { + if (fused_group->input_ops.count(node.first)) { + fused_group->input_ops[node.first] += node.second; } else { - fused_group->input_nodes.insert(node); + fused_group->input_ops.insert(node); } } // output node - for (auto& node : consumer->output_nodes) { - fused_group->output_nodes.insert(node); + for (auto& node : consumer->output_ops) { + fused_group->output_ops.insert(node); } // internal node if (consumer->fused_sub_groups.size()) { - for (auto& node : consumer->internal_nodes) { - fused_group->internal_nodes.insert(node); + for (auto& node : consumer->internal_ops) { + fused_group->internal_ops.insert(node); } } // master node - for (auto& node : consumer->master_nodes) { - if (GetOpKind(node->name()) == kReduction) { - fused_group->master_nodes.insert(node); + for (auto& node : consumer->master_ops) { + if (GetOpKind(node->name()) == OpPatternKind::kReduction) { + fused_group->master_ops.insert(node); } } // insert sub group @@ -1440,27 +1451,27 @@ class GeneralFusionMergePassHelper { // if node is output nodes of sub_group, check it can't be internal node. for (auto& sub_group : repeat_sub_groups) { // check each output node in sub_group. - for (auto& node : sub_group->output_nodes) { + for (auto& node : sub_group->output_ops) { // if node is not output node of fused_group. - if (!fused_group->output_nodes.count(node)) { - fused_group->internal_nodes.insert(node); + if (!fused_group->output_ops.count(node)) { + fused_group->internal_ops.insert(node); } } } - if (static_cast(kReduction) > + if (static_cast(OpPatternKind::kReduction) > static_cast((consumers.back())->op_pattern_kind)) { auto consumer = consumers.back(); - for (auto& node : consumer->master_nodes) { - fused_group->master_nodes.insert(node); + for (auto& node : consumer->master_ops) { + fused_group->master_ops.insert(node); } } else { for (auto consumer = consumers.rbegin(); consumer != consumers.rend(); ++consumer) { ::pir::Operation* master_node = nullptr; - for (auto& node : (*consumer)->master_nodes) { - if (GetOpKind(node->name()) != kReduction) { + for (auto& node : (*consumer)->master_ops) { + if (GetOpKind(node->name()) != OpPatternKind::kReduction) { master_node = node; break; } @@ -1468,7 +1479,7 @@ class GeneralFusionMergePassHelper { if (master_node) { // VLOG(3) << "Insert Master node : " << master_node->id() // << " into group : " << fused_group->group_id; - fused_group->master_nodes.insert(master_node); + fused_group->master_ops.insert(master_node); break; } } @@ -1478,7 +1489,7 @@ class GeneralFusionMergePassHelper { fusion_groups_[postion] = fused_group; fusion_groups_index_[fused_group] = postion; - CHECK(fused_group->output_nodes.size()) + CHECK(fused_group->output_ops.size()) << "No output node is found, " << fused_group->group_id; } @@ -1502,6 +1513,7 @@ class GeneralFusionMergePassHelper { } const auto& fuse_passes = GetVerticalFusePasses(); for (const auto& fuse_pass : fuse_passes) { + // TODO(Aurelius84): Broadcast_Test_2 failed here (*fuse_pass)(ctx); } } @@ -1570,27 +1582,32 @@ class GeneralFusionMergePassHelper { ? producer->op_pattern_kind : consumer->op_pattern_kind; // input nodes - fused_group->input_nodes = producer->input_nodes; + fused_group->input_ops = producer->input_ops; + + fused_group->ops = producer->ops; + for (size_t i = 0; i < consumer->ops.size(); ++i) { + fused_group->ops.push_back(consumer->ops[i]); + } // internal nodes if (producer->fused_sub_groups.size()) { - for (auto& node : producer->internal_nodes) { - fused_group->internal_nodes.insert(node); + for (auto& node : producer->internal_ops) { + fused_group->internal_ops.insert(node); } } // convert producer's output node to internal. - for (auto node : producer->output_nodes) { + for (auto node : producer->output_ops) { // if node is used more than 1 time. - if (consumer->input_nodes.count(node)) { - if (consumer->input_nodes[node] > 1 && node->num_operands() > 0) { - fused_group->internal_nodes.insert(node); + if (consumer->input_ops.count(node)) { + if (consumer->input_ops[node] > 1 && node->num_operands() > 0) { + fused_group->internal_ops.insert(node); } } } // master nodes - for (auto& node : producer->master_nodes) { - if (GetOpKind(node->name()) == kReduction) { - fused_group->master_nodes.insert(node); + for (auto& node : producer->master_ops) { + if (GetOpKind(node->name()) == OpPatternKind::kReduction) { + fused_group->master_ops.insert(node); } } @@ -1616,32 +1633,32 @@ class GeneralFusionMergePassHelper { producer->belong_groups.insert(fused_group); // input nodes - for (auto& input_node : consumer->input_nodes) { + for (auto& input_node : consumer->input_ops) { // if input node not in producer output. - if (!producer->output_nodes.count(input_node.first)) { - if (fused_group->input_nodes.count(input_node.first)) { - fused_group->input_nodes[input_node.first] += input_node.second; + if (!producer->output_ops.count(input_node.first)) { + if (fused_group->input_ops.count(input_node.first)) { + fused_group->input_ops[input_node.first] += input_node.second; } else { - fused_group->input_nodes.insert(input_node); + fused_group->input_ops.insert(input_node); } } } // output nodes - for (auto& node : consumer->output_nodes) { - fused_group->output_nodes.insert(node); + for (auto& node : consumer->output_ops) { + fused_group->output_ops.insert(node); } // internal nodes if (consumer->fused_sub_groups.size()) { - for (auto& node : consumer->internal_nodes) { - fused_group->internal_nodes.insert(node); + for (auto& node : consumer->internal_ops) { + fused_group->internal_ops.insert(node); } } // master nodes - for (auto& node : consumer->master_nodes) { - fused_group->master_nodes.insert(node); + for (auto& node : consumer->master_ops) { + fused_group->master_ops.insert(node); } // producer nodes @@ -1690,36 +1707,36 @@ class GeneralFusionMergePassHelper { if (!master_fuesd_group.get()) { master_fuesd_group = fused_group; } - CHECK(fused_group->output_nodes.size()) + CHECK(fused_group->output_ops.size()) << "No output node is found, " << fused_group->group_id; } - for (auto& node : producer->output_nodes) { + for (auto& node : producer->output_ops) { bool be_output = true; for (const auto& consumer : producer->consumer_groups()) { // if consumer is in fusionable. if (fusionable_consumers.count(consumer)) { - if (consumer->input_nodes.count(node)) { + if (consumer->input_ops.count(node)) { be_output = false; } continue; } // if consumer is not in fusionable. - if (consumer->input_nodes.count(node)) { + if (consumer->input_ops.count(node)) { be_output = true; break; } // others node is as graph output. } - if (output_nodes_set_.count(node)) { + if (output_ops_set_.count(node)) { be_output = true; } if (be_output) { // VLOG(4) << "Insert Id " << node->id() << " Into Group " // << master_fuesd_group->group_id; - master_fuesd_group->output_nodes.insert(node); + master_fuesd_group->output_ops.insert(node); } } // insert unfusionable consumer groups @@ -1820,11 +1837,11 @@ class GeneralFusionMergePassHelper { // just merge the node into group. auto& sub_group = consumer->fused_sub_groups.front(); sub_group->group_id = producer->group_id + "_" + sub_group->group_id; - sub_group->nodes.insert(sub_group->nodes.begin(), - producer->CollectNodes()[0]); - sub_group->nodes_set.insert(producer->CollectNodes()[0]); + sub_group->ops.insert(sub_group->ops.begin(), + producer->CollectOps()[0]); + sub_group->ops_set.insert(producer->CollectOps()[0]); // remove depency. - consumer->input_nodes.erase(producer->CollectNodes()[0]); + consumer->input_ops.erase(producer->CollectOps()[0]); consumer->mut_producer_groups()->erase(producer); producer->mut_consumer_groups()->erase(consumer); } @@ -1832,7 +1849,7 @@ class GeneralFusionMergePassHelper { CHECK_GE(producer->consumer_groups().size(), candidates.size()); if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && - output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { + output_ops_set_.count(producer->CollectOps()[0]) == 0) { producer->belong_groups.insert(*fusionable_consumers->begin()); } @@ -1849,19 +1866,19 @@ class GeneralFusionMergePassHelper { if (false) { std::vector candidates; for (auto& consumer : *fusionable_consumers) { - if (consumer->op_pattern_kind == kElementWise) { + if (consumer->op_pattern_kind == OpPatternKind::kElementWise) { candidates.push_back(consumer); continue; } auto producer_output_shape = phi::vectorize( - GetValueShape((*producer->output_nodes.begin())->result(0))); + GetValueShape((*producer->output_ops.begin())->result(0))); auto consumer_output_shape = phi::vectorize( - GetValueShape((*consumer->output_nodes.begin())->result(0))); + GetValueShape((*consumer->output_ops.begin())->result(0))); auto consumer_master_input_shape = phi::vectorize(GetValueShape( - (*(consumer->master_nodes.begin()))->operand_source(0))); + (*(consumer->master_ops.begin()))->operand_source(0))); int producer_output_numel = std::accumulate(producer_output_shape.begin(), @@ -1883,8 +1900,8 @@ class GeneralFusionMergePassHelper { continue; } - if (producer->op_pattern_kind != kInjective && - consumer->op_pattern_kind == kReduction && + if (producer->op_pattern_kind != OpPatternKind::kInjective && + consumer->op_pattern_kind == OpPatternKind::kReduction && producer_output_numel == consumer_master_input_numel) { candidates.push_back(consumer); } @@ -1902,15 +1919,15 @@ class GeneralFusionMergePassHelper { } else { std::vector candidates; for (auto& consumer : *fusionable_consumers) { - if (consumer->op_pattern_kind == kElementWise) { + if (consumer->op_pattern_kind == OpPatternKind::kElementWise) { candidates.push_back(consumer); continue; } auto shape0 = phi::vectorize( - GetValueShape((*producer->output_nodes.begin())->result(0))); + GetValueShape((*producer->output_ops.begin())->result(0))); auto shape1 = phi::vectorize( - GetValueShape((*consumer->output_nodes.begin())->result(0))); + GetValueShape((*consumer->output_ops.begin())->result(0))); if (std::accumulate( shape0.begin(), shape0.end(), 1, std::multiplies()) == @@ -2042,7 +2059,7 @@ class GeneralFusionMergePassHelper { VLOG(3) << "InitInputToConsumers...!"; // init input data node -> fusion group map. for (auto& group : fusion_groups_) { - for (auto& node : group->nodes_set) { + for (auto& node : group->ops_set) { // collect producer node data. for (size_t i = 0; i < node->num_operands(); ++i) { auto in = node->operand_source(i); @@ -2064,10 +2081,11 @@ class GeneralFusionMergePassHelper { belong_group->max_depth = group->depth; belong_group->min_depth = group->depth; belong_group->group_id = group->group_id; - belong_group->input_nodes = group->input_nodes; - belong_group->output_nodes = group->output_nodes; + belong_group->ops = group->ops; + belong_group->input_ops = group->input_ops; + belong_group->output_ops = group->output_ops; belong_group->op_pattern_kind = group->op_pattern_kind; - belong_group->master_nodes = group->master_nodes; + belong_group->master_ops = group->master_ops; (*belong_group->mut_producer_groups()) = group->producer_groups(); (*belong_group->mut_consumer_groups()) = group->consumer_groups(); belong_group->fused_sub_groups.push_back(group); @@ -2099,23 +2117,21 @@ class GeneralFusionMergePassHelper { } } - const ::pir::Program* graph_; GroupList fusion_groups_; std::unordered_map fusion_groups_index_; - std::unordered_set output_nodes_set_; + std::unordered_set output_ops_set_; std::unordered_map<::pir::Value, std::unordered_set> input_to_consumers_; }; -GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph, - const GroupList& group_list) { +GroupList GeneralFusionMergePassInternal(const GroupList& group_list) { if (group_list.size() <= 1) { VLOG(3) << "Don't do Fusoin Merge Pass...!"; return group_list; } - GeneralFusionMergePassHelper fusion_merge_pass_helper(graph, group_list); + GeneralFusionMergePassHelper fusion_merge_pass_helper(group_list); auto res = fusion_merge_pass_helper(); return res; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h index 19ea891531b87..f38042569191c 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h @@ -111,7 +111,7 @@ static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { } bool WithoutLastDimInReduce(const phi::DDim& inshape, - const std::vector& axes) { + const std::vector& axes) { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || std::find(axes.begin(), axes.end(), -1) != axes.end()) { @@ -132,10 +132,8 @@ bool WithoutLastDimInReduce(const phi::DDim& inshape, static int GetSharedSize(const cinn::dialect::ir::OpNode& op_node) { const auto& inshape = op_node.inputs()[0].shape(); - // const auto& axes = op_node.GetAttr>("dim"); - // const auto& axes = op_node.Op()->attributes().at("dim").dyn_cast<> - // TODO(phlrain): get vector from attribute - std::vector axes = {1}; + const auto& axes = op_node.GetAttr>("dim"); + if (WithoutLastDimInReduce(inshape, axes)) { int lane = 1; for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { @@ -179,14 +177,15 @@ static int GetSharedSize(const cinn::dialect::ir::OpNode& op_node) { return 0; } -static bool ReduceFuseReduce(const OpGroupPtr& first, - const OpGroupPtr& second) { - if (!limit_args(first, second)) { - return false; - } +static bool ReduceFuseReduce1(const OpGroupPtr& first, + const OpGroupPtr& second) { + // return false; + // if (!limit_args(first, second)) { + // return false; + // } std::unique_ptr reducer_0 = nullptr; first.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { - if (!reducer_0 && op.kind() == kReduction) { + if (!reducer_0 && op.kind() == OpPatternKind::kReduction) { reducer_0.reset(new cinn::dialect::ir::OpNode(op)); } }); @@ -194,7 +193,7 @@ static bool ReduceFuseReduce(const OpGroupPtr& first, std::unique_ptr reducer_1 = nullptr; second.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { - if (!reducer_1 && op.kind() == kReduction) { + if (!reducer_1 && op.kind() == OpPatternKind::kReduction) { reducer_1.reset(new cinn::dialect::ir::OpNode(op)); } }); @@ -208,12 +207,8 @@ static bool ReduceFuseReduce(const OpGroupPtr& first, const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); - // TODO(phlrain): get attribute from op node - // auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); - // auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); - - std::vector reducer_0_reduce_dim = {0}; - std::vector reducer_1_reduce_dim = {0}; + auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); + auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); for (auto& dim : reducer_0_reduce_dim) { // if dim = -1, set as shape.size() - 1 diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h index 1b8f5b6aeacd7..458d2fda1ed8f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h @@ -40,11 +40,11 @@ inline bool limit_args(const std::shared_ptr& first, const std::shared_ptr& second) { std::unordered_set args; for (auto& group : {first, second}) { - for (auto node : group->input_nodes) { - args.insert(node.first); + for (auto iter : group->input_ops) { + args.insert(iter.first); } - for (auto node : group->output_nodes) { - args.insert(node); + for (auto op : group->output_ops) { + args.insert(op); } } @@ -66,8 +66,8 @@ inline bool is_same_shape(const std::shared_ptr& first, return false; } - auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0)); - auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0)); + auto output_var_0 = GetValueShape((*first->master_ops.begin())->result(0)); + auto output_var_1 = GetValueShape((*second->master_ops.begin())->result(0)); return output_var_0 == output_var_1; } @@ -77,8 +77,8 @@ inline bool is_same_size(const std::shared_ptr& first, return false; } - auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0)); - auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0)); + auto output_var_0 = GetValueShape((*first->master_ops.begin())->result(0)); + auto output_var_1 = GetValueShape((*second->master_ops.begin())->result(0)); if (output_var_0 == output_var_1) { return true; } @@ -89,8 +89,8 @@ inline bool is_same_size(const std::shared_ptr& first, } inline bool is_const_group(const std::shared_ptr& group) { - return group->CollectNodes().size() == 1 && - ConstantOps.count(group->CollectNodes()[0]->name()); + return group->CollectOps().size() == 1 && + ConstantOps.count(group->CollectOps()[0]->name()); } inline bool elementwise_fuse_broadcast( @@ -105,9 +105,9 @@ inline bool elementwise_fuse_broadcast( return true; } // if first's output is not all in second's input - for (auto output : first->output_nodes) { + for (auto output : first->output_ops) { return true; - if (!second->input_nodes.count(output)) { + if (!second->input_ops.count(output)) { return false; } @@ -130,7 +130,7 @@ inline bool honrizontal_elementwise_fuse_reduce( const std::shared_ptr& first, const std::shared_ptr& second) { std::shared_ptr ele_group, reduce_group; - if (first->op_pattern_kind == kReduction) { + if (first->op_pattern_kind == OpPatternKind::kReduction) { ele_group = second; reduce_group = first; } else { @@ -143,10 +143,10 @@ inline bool honrizontal_elementwise_fuse_reduce( } auto ele_node_shape = - GetValueShape((*ele_group->master_nodes.begin())->result(0)); + GetValueShape((*ele_group->master_ops.begin())->result(0)); int32_t size_ele = phi::product(ele_node_shape); // TODO(phlrain): seems extrame danger herem, why compare multi Master Node? - for (auto* master : reduce_group->master_nodes) { + for (auto* master : reduce_group->master_ops) { auto master_node_shape = GetValueShape(master->result(0)); int32_t size_master = phi::product(master_node_shape); if (size_ele == size_master) { @@ -169,9 +169,9 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr& first, // if reduce nodes not in consumers of first group std::queue<::pir::Operation*> candidates; - std::unordered_set<::pir::Operation*> first_node_set = first->NodeSet(); - std::unordered_set<::pir::Operation*> second_node_set = second->NodeSet(); - for (const auto& pair : second->input_nodes) { + std::unordered_set<::pir::Operation*> first_node_set = first->OpSet(); + std::unordered_set<::pir::Operation*> second_node_set = second->OpSet(); + for (const auto& pair : second->input_ops) { if (first_node_set.find(pair.first) != first_node_set.end()) { candidates.push(pair.first); } @@ -195,7 +195,7 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr& first, visited.insert(consumer); candidates.push(consumer); } - if (second->master_nodes.count(consumer)) { + if (second->master_ops.count(consumer)) { masters_in_consumers.insert(consumer); } } @@ -203,7 +203,7 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr& first, if (!masters_in_consumers.empty()) { bool flag = true; auto first_node_shape = - GetValueShape((*first->master_nodes.begin())->result(0)); + GetValueShape((*first->master_ops.begin())->result(0)); int32_t size_first = phi::product(first_node_shape); for (::pir::Operation* master : masters_in_consumers) { @@ -221,8 +221,8 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr& first, // if reduce using block_reduce, can't fuse producer. ::pir::Operation* reducer = nullptr; - for (auto& node : second->master_nodes) { - if (GetOpKind(node->name()) == kReduction) { + for (auto& node : second->master_ops) { + if (GetOpKind(node->name()) == OpPatternKind::kReduction) { reducer = node; break; } @@ -240,7 +240,7 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr& first, // } auto input_shape = GetValueShape(reducer->operand_source(0)); - std::vector reduce_axes = GetVectorAttr(reducer, "axis"); + auto reduce_axes = GetVectorAttr(reducer, "dim"); // int max_num_threads = helper->target_.max_num_threads(); int max_num_threads = 1000; @@ -289,8 +289,8 @@ inline bool broadcast_fuse_reduce(const std::shared_ptr& first, return true; } ::pir::Operation* reducer = nullptr; - for (auto& node : second->master_nodes) { - if (GetOpKind(node->name()) == kReduction) { + for (auto& node : second->master_ops) { + if (GetOpKind(node->name()) == OpPatternKind::kReduction) { reducer = node; break; } @@ -300,7 +300,7 @@ inline bool broadcast_fuse_reduce(const std::shared_ptr& first, auto input_shape = GetValueShape(reducer->operand_source(0)); auto input_size = phi::product(input_shape); - auto output_shape = GetValueShape((*first->master_nodes.begin())->result(0)); + auto output_shape = GetValueShape((*first->master_ops.begin())->result(0)); auto output_size = phi::product(output_shape); if (input_size == output_size) { @@ -325,10 +325,9 @@ inline bool horizontal_relation(const std::shared_ptr& first, const OpPatternKind op_pattern_kind) { // merge injective auto merge_nodes_set = [](const std::shared_ptr& group) { - std::unordered_set<::pir::Operation*> nodes_set = group->nodes_set; + std::unordered_set<::pir::Operation*> nodes_set = group->ops_set; for (auto& sub_group : group->fused_sub_groups) { - nodes_set.insert(sub_group->nodes_set.begin(), - sub_group->nodes_set.end()); + nodes_set.insert(sub_group->ops_set.begin(), sub_group->ops_set.end()); } return nodes_set; }; @@ -398,14 +397,14 @@ inline bool horizontal_with_injective( if (!is_same_size(first, second)) { return false; } - return horizontal_relation(first, second, kInjective); + return horizontal_relation(first, second, OpPatternKind::kInjective); } inline bool injective_horizontal_with_reduce( const std::shared_ptr& first, const std::shared_ptr& second) { // check injective with injective. - if (!horizontal_relation(first, second, kInjective)) { + if (!horizontal_relation(first, second, OpPatternKind::kInjective)) { return false; } return elementwise_fuse_reduce(first, second); @@ -424,8 +423,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr& first, // each reducer and its consumers with type of Broadcast needs to meet. It is // required that each consumer of type Broadcast meet the same shape after // broadcast as before reduce. - for (auto& node_in_master : first->master_nodes) { - if (GetOpKind(node_in_master->name()) != kReduction) { + for (auto& node_in_master : first->master_ops) { + if (GetOpKind(node_in_master->name()) != OpPatternKind::kReduction) { continue; } ::pir::Operation* reducer = node_in_master; @@ -435,7 +434,7 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr& first, phi::vectorize(GetValueShape(reducer->operand_source(0))); auto reducer_output_shape = phi::vectorize(GetValueShape(reducer->result(0))); - std::vector reduce_axes = GetVectorAttr(reducer, "axis"); + std::vector reduce_axes = GetVectorAttr(reducer, "dim"); auto keep_dim = false; for (auto& axis : reduce_axes) { @@ -480,8 +479,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr& first, visited_set.insert(consumer); candidates.push(consumer); } - if (GetOpKind(consumer->name()) == kBroadcast && - second->NodeSet().find(consumer) != second->NodeSet().end()) { + if (GetOpKind(consumer->name()) == OpPatternKind::kBroadcast && + second->OpSet().find(consumer) != second->OpSet().end()) { broadcasters.insert(consumer); } } @@ -543,8 +542,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr& first, return false; } ::pir::Operation* reducer_0 = nullptr; - for (auto& reducer : first->master_nodes) { - if (GetOpKind(reducer->name()) == kReduction) { + for (auto& reducer : first->master_ops) { + if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) { reducer_0 = reducer; break; } @@ -552,8 +551,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr& first, // CHECK(reducer_0) << "Can't find reduce op in group " << first->group_id; ::pir::Operation* reducer_1 = nullptr; - for (auto& reducer : second->master_nodes) { - if (GetOpKind(reducer->name()) == kReduction) { + for (auto& reducer : second->master_ops) { + if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) { reducer_1 = reducer; break; } @@ -566,13 +565,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr& first, auto reducer_1_input_shape = GetValueShape(reducer_1->operand_source(0)); auto reducer_1_output_shape = GetValueShape(reducer_1->result(0)); - // auto reducer_0_reduce_dim = - // absl::get>(reducer_0->attrs.attr_store.at("dim")); - // auto reducer_1_reduce_dim = - // absl::get>(reducer_1->attrs.attr_store.at("dim")); - // TODO(phlrain) - std::vector reducer_0_reduce_dim = GetVectorAttr(reducer_0, "axis"); - std::vector reducer_1_reduce_dim = GetVectorAttr(reducer_1, "axis"); + auto reducer_0_reduce_dim = GetVectorAttr(reducer_0, "dim"); + auto reducer_1_reduce_dim = GetVectorAttr(reducer_1, "dim"); for (auto& dim : reducer_0_reduce_dim) { // if dim = -1, set as shape.size() - 1 @@ -594,8 +588,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr& first, reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { - for (auto* master : fusion_group->master_nodes) { - if (GetOpKind(master->name()) == kReduction) { + for (auto* master : fusion_group->master_ops) { + if (GetOpKind(master->name()) == OpPatternKind::kReduction) { shared_size += GetSharedSize(master); } } @@ -615,8 +609,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr& first, reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { - for (auto* master : fusion_group->master_nodes) { - if (GetOpKind(master->name()) == kReduction) { + for (auto* master : fusion_group->master_ops) { + if (GetOpKind(master->name()) == OpPatternKind::kReduction) { shared_size += GetSharedSize(master); } } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_group.h b/paddle/cinn/hlir/dialect/operator/transforms/op_group.h index 87138df17be85..4914d80f75709 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_group.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_group.h @@ -153,7 +153,7 @@ class OpGroup { // group.WalkOpNodes(get_reduce_op); void WalkOpNodes( const std::function& VisitOpNode) const { - group_.lock()->WalkNodes( + group_.lock()->WalkOps( [&](::pir::Operation* node) { VisitOpNode(OpNode(node)); }); } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_node.h b/paddle/cinn/hlir/dialect/operator/transforms/op_node.h index 8579d11b19bb9..d7f0542a3bec9 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_node.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_node.h @@ -31,11 +31,11 @@ class OpNode { OpPatternKind kind() const { auto kind = GetOpKind(node_->name()); - if (kind == kBroadcast) { + if (kind == OpPatternKind::kBroadcast) { // As binary op was defined as broadcast, actually it should be // element-wise. if (node_->name() != "broadcast_to") { - return kElementWise; + return OpPatternKind::kElementWise; } } return kind; @@ -137,10 +137,10 @@ class OpNode { const OutputTensorListView& outputs() const { return output_tensors_; } template - const T& GetAttr(const std::string& attr_name) const { + T GetAttr(const std::string& attr_name) const { auto attr = paddle::dialect::GetAttributeData(node_->attributes().at(attr_name)); - return PADDLE_GET_CONST(T, attr); + return paddle::get(attr); } private: diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc index 3039d81ff83a3..1e5a5965005f1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" + #include #include #include @@ -19,8 +21,6 @@ #include #include -#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" - #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/operation.h" @@ -40,6 +40,8 @@ std::unordered_map OpKindMap = { {"pd_op.full", OpPatternKind::kElementWise}, {"pd_op.relu", OpPatternKind::kElementWise}, {"pd_op.exp", OpPatternKind::kElementWise}, + {"pd_op.sin", OpPatternKind::kElementWise}, + {"pd_op.cos", OpPatternKind::kElementWise}, {"pd_op.sum", OpPatternKind::kReduction}, {"cinn_op.reduce_sum", OpPatternKind::kReduction}, {"cinn_op.reduce_max", OpPatternKind::kReduction}, @@ -61,7 +63,7 @@ phi::DDim GetFirstInputShape(const ::pir::Operation* op) { return in.type().dyn_cast().dims(); } -phi::DDim GetValueShape(const ::pir::Value value) { +phi::DDim GetValueShape(const ::pir::Value& value) { return value.type().dyn_cast().dims(); } @@ -85,10 +87,10 @@ bool WithoutLastDimInReduce(const std::vector& inshape, } } -int GetSharedSize(::pir::Operation* node) { - auto inshape = phi::vectorize(GetValueShape(node->result(0))); +int GetSharedSize(::pir::Operation* op) { + auto inshape = phi::vectorize(GetValueShape(op->result(0))); - auto axes = GetVectorAttr(node, "axis"); + auto axes = GetVectorAttr(op, "dim"); if (WithoutLastDimInReduce(inshape, axes)) { int lane = 1; @@ -143,49 +145,47 @@ using ConditionFunction = // code generation. class OpFusionPassHelper { public: - explicit OpFusionPassHelper(const ::pir::Program& graph) { + explicit OpFusionPassHelper(const std::vector& op_list) { // init fusion relation InitFusionRelation(); - // filter node data, create group for each node - // auto nodes_inorder = std::get<0>(graph->topological_order()); + // filter op data, create group for each op + // auto ops_inorder = std::get<0>(graph->topological_order()); - for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) { - auto node = *it; - local_ops_.insert(node); + for (auto it = op_list.begin(); it != op_list.end(); ++it) { + local_ops_.insert(*it); } int index = 0; - for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) { - auto node = *it; - if (node) { - nodes_.push_back(node); + for (auto it = op_list.begin(); it != op_list.end(); ++it) { + auto op = *it; + if (op) { + ops_.push_back(op); auto group = std::make_shared(); // init group - group->nodes.push_back(node); - group->nodes_set.insert(node); - group->output_nodes.insert(node); - // input node - - for (size_t i = 0; i < node->num_operands(); ++i) { - auto input = - node->operand_source(i).dyn_cast().owner(); + group->ops.push_back(op); + group->ops_set.insert(op); + group->output_ops.insert(op); + // input op + + for (size_t i = 0; i < op->num_operands(); ++i) { + auto input = op->operand_source(i).dyn_cast().owner(); if (input && (local_ops_.count(input))) { - group->input_nodes[input] = 1; + group->input_ops[input] = 1; } } // group type - group->op_pattern_kind = GetOpKind(node->name()); - // use current node as master node for schedule - group->master_nodes.insert(node); + group->op_pattern_kind = GetOpKind(op->name()); + // use current op as master op for schedule + group->master_ops.insert(op); // get opration unique id group->group_id = "id_" + std::to_string(index++); - fusion_groups_[node] = group; + fusion_groups_[op] = group; } } - // reverse node for output to input - std::reverse(nodes_.begin(), nodes_.end()); + // reverse op for output to input + std::reverse(ops_.begin(), ops_.end()); } // return a vector of groups in topological order. @@ -198,23 +198,23 @@ class OpFusionPassHelper { // find all fusion group. GroupList fusion_groups; std::unordered_set groups_set; - for (auto node : nodes_) { - auto& group = fusion_groups_[node]; + for (auto op : ops_) { + auto& group = fusion_groups_[op]; if (!groups_set.count(group.get())) { groups_set.insert(group.get()); fusion_groups.push_back(group); - // reverse nodes order to producer->consumer. - std::reverse(group->nodes.begin(), group->nodes.end()); + // reverse ops order to producer->consumer. + std::reverse(group->ops.begin(), group->ops.end()); } } // producer consumer for (auto& consumer : fusion_groups) { - for (auto& input_node : consumer->input_nodes) { - if (!local_ops_.count(input_node.first)) { + for (auto& input_op : consumer->input_ops) { + if (!local_ops_.count(input_op.first)) { continue; } - auto& producer = fusion_groups_[input_node.first]; + auto& producer = fusion_groups_[input_op.first]; consumer->mut_producer_groups()->insert(producer); producer->mut_consumer_groups()->insert(consumer); } @@ -236,16 +236,16 @@ class OpFusionPassHelper { private: void DoOpFusion() { - for (auto consumer : nodes_) { + for (auto consumer : ops_) { auto consumer_kind = GetOpKind(consumer->name()); // kNonFusible op can't fuse any other op. - if (consumer_kind == kNonFusible) { + if (consumer_kind == OpPatternKind::kNonFusible) { continue; } // fusion op for consumer auto consumer_fusion = fusion_groups_[consumer]; // - // check all linkin node + // check all linkin op for (size_t i = 0; i < consumer->num_operands(); ++i) { auto producer_data = consumer->operand_source(i); @@ -255,7 +255,7 @@ class OpFusionPassHelper { } // if producer is fused. - if (consumer_fusion->nodes_set.count(producer)) { + if (consumer_fusion->ops_set.count(producer)) { // VLOG(3) << "Op " << producer->id() << " is fused."; continue; } @@ -265,7 +265,7 @@ class OpFusionPassHelper { } // kNonFusible op can't fuse any other op. auto producer_kind = GetOpKind(producer->name()); - if (producer_kind == kNonFusible) { + if (producer_kind == OpPatternKind::kNonFusible) { continue; } // VLOG(3) << "Producer Op: " << producer->id() @@ -273,17 +273,17 @@ class OpFusionPassHelper { // << " -> Consumer Op: " << consumer->id() // << ", Op Pattern: " << consumer_kind; bool can_fuse = true; - // checkout producer node outputs are all in fusion op + // checkout producer op outputs are all in fusion op // find all the op use by size_t producer_data_used_num = 0; for (auto it = producer_data.use_begin(); it != producer_data.use_end(); ++it) { - auto consumer_node = it->owner(); + auto consumer_op = it->owner(); producer_data_used_num++; - // if fusion group can't find node, can't merge - if (consumer_fusion->nodes_set.find(consumer_node) == - consumer_fusion->nodes_set.end()) { + // if fusion group can't find op, can't merge + if (consumer_fusion->ops_set.find(consumer_op) == + consumer_fusion->ops_set.end()) { can_fuse = false; break; } @@ -299,39 +299,39 @@ class OpFusionPassHelper { // producer->id() + "_" + consumer_fusion->group_id; consumer_fusion->group_id = consumer_fusion->group_id; - consumer_fusion->nodes.push_back(producer); - consumer_fusion->nodes_set.insert(producer); - consumer_fusion->input_nodes.erase(producer); + consumer_fusion->ops.push_back(producer); + consumer_fusion->ops_set.insert(producer); + consumer_fusion->input_ops.erase(producer); consumer_fusion->op_pattern_kind = static_cast(consumer_fusion->op_pattern_kind) > static_cast(producer_kind) ? consumer_fusion->op_pattern_kind : producer_kind; - if (producer_kind == kReduction) { - consumer_fusion->master_nodes.insert(producer); + if (producer_kind == OpPatternKind::kReduction) { + consumer_fusion->master_ops.insert(producer); } - if (output_nodes_set_.count(producer)) { + if (output_ops_set_.count(producer)) { // VLOG(3) << "Insert Global Output Node : " << producer->id(); - consumer_fusion->output_nodes.insert(producer); + consumer_fusion->output_ops.insert(producer); } else if (producer_data_used_num > 1 && producer->num_operands() > 0 && is_same_size(producer, consumer_fusion)) { - // producer is not a const value node. - consumer_fusion->internal_nodes.insert(producer); + // producer is not a const value op. + consumer_fusion->internal_ops.insert(producer); } - // fuse input node + // fuse input op auto producer_fusion = fusion_groups_[producer]; - for (auto input_node : producer_fusion->input_nodes) { - if (consumer_fusion->input_nodes.count(input_node.first)) { - consumer_fusion->input_nodes[input_node.first] += input_node.second; + for (auto input_op : producer_fusion->input_ops) { + if (consumer_fusion->input_ops.count(input_op.first)) { + consumer_fusion->input_ops[input_op.first] += input_op.second; } else { - consumer_fusion->input_nodes.insert(input_node); + consumer_fusion->input_ops.insert(input_op); } } - // update node group + // update op group fusion_groups_[producer] = consumer_fusion; } } @@ -343,119 +343,127 @@ class OpFusionPassHelper { { FusionRelation relation; // producer -> consumer - relation.op_kind = {kElementWise, kBroadcast, kReduction, kInjective}; + relation.op_kind = {OpPatternKind::kElementWise, + OpPatternKind::kBroadcast, + OpPatternKind::kReduction, + OpPatternKind::kInjective}; // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Elementwise + *Elementwise*). As // has same output shape, can always fuse. - {kElementWise, always_fuse}, + {OpPatternKind::kElementWise, always_fuse}, // must be horizontal, as Elementwise + Broadcast is left to fusion // merge pass. - {kBroadcast, + {OpPatternKind::kBroadcast, [](::pir::Operation* producer, const GroupPtr& consumer) -> bool { // NOTE, producer and consumer NEVER be same size if (is_same_size(producer, consumer)) { return true; } - // NOTE, original code is below, if produer is not output node, + // NOTE, original code is below, if produer is not output op, // result always be true - // !helper->output_nodes_set_.count(producer); + // !helper->output_ops_set_.count(producer); return true; }}, // horizontal or vertical relation, check with same output shape with // horizontal relation or with last // successive dimension less than 1024 for gpu. - {kReduction, horizontal_or_vertical_reduce_relation}, + {OpPatternKind::kReduction, horizontal_or_vertical_reduce_relation}, // can be horizontal or can compute inline, check with same output // shape or can compute inline. - {kInjective, horizontal_or_can_inline}, + {OpPatternKind::kInjective, horizontal_or_can_inline}, // must be horizontal, check with same output shape. - {kOutFusible, is_same_shape}}; - fusion_relation_map_[kElementWise] = std::move(relation); + {OpPatternKind::kOutFusible, is_same_shape}}; + fusion_relation_map_[OpPatternKind::kElementWise] = std::move(relation); } // 2.kBroadcast as producer { FusionRelation relation; // producer -> consumer - relation.op_kind = {kElementWise, kReduction, kInjective}; + relation.op_kind = {OpPatternKind::kElementWise, + OpPatternKind::kReduction, + OpPatternKind::kInjective}; // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Broadcast + *Elementwise*), check // with same output shape. - {kElementWise, is_same_size}, + {OpPatternKind::kElementWise, is_same_size}, // must be horizontal, as Broadcast + Broadcast is not allowed. - {kBroadcast, is_same_size}, + {OpPatternKind::kBroadcast, is_same_size}, // horizontal or vertical relation(Broadcast + Reduce). - {kReduction, horizontal_or_vertical_reduce_relation}, + {OpPatternKind::kReduction, horizontal_or_vertical_reduce_relation}, // can be horizontal or can compute inline, check with same output // shape or just one consumer. - {kInjective, horizontal_or_can_inline}, + {OpPatternKind::kInjective, horizontal_or_can_inline}, // must be horizontal, check with same output shape. - {kOutFusible, is_same_shape}}; - fusion_relation_map_[kBroadcast] = std::move(relation); + {OpPatternKind::kOutFusible, is_same_shape}}; + fusion_relation_map_[OpPatternKind::kBroadcast] = std::move(relation); } // 3.kReduction as producer { FusionRelation relation; // producer -> consumer - relation.op_kind = {kElementWise, kBroadcast}; + relation.op_kind = {OpPatternKind::kElementWise, + OpPatternKind::kBroadcast}; // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Reduce + Elementwise*), check // without last dimension in reduce. - {kElementWise, is_same_size}, + {OpPatternKind::kElementWise, is_same_size}, // must be horizontal relation, check with same output shape and // without last dimension in reduce. - {kBroadcast, reduce_fuse_broadcast}, + {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, // must be horizontal relation and with same reduce attr. - {kReduction, reduce_fuse_reduce}, + {OpPatternKind::kReduction, reduce_fuse_reduce}, // no_fuse - {kInjective, no_fuse}, + {OpPatternKind::kInjective, no_fuse}, // can't fuse. - {kOutFusible, no_fuse}}; - fusion_relation_map_[kReduction] = std::move(relation); + {OpPatternKind::kOutFusible, no_fuse}}; + fusion_relation_map_[OpPatternKind::kReduction] = std::move(relation); } // 4.kInjective { FusionRelation relation; // producer -> consumer - relation.op_kind = {kElementWise, kInjective}; + relation.op_kind = {OpPatternKind::kElementWise, + OpPatternKind::kInjective}; // producer -> fusion relation.fusion_op_kind = { // can be horizontal or vertical(Injective + Elementwise), check with // same output shape. - {kElementWise, is_same_size}, + {OpPatternKind::kElementWise, is_same_size}, // must be horizontal relation, check with same output shape. - {kBroadcast, horizontal_with_same_size}, + {OpPatternKind::kBroadcast, horizontal_with_same_size}, // left to fusion merge pass. - {kReduction, no_fuse}, + {OpPatternKind::kReduction, no_fuse}, // must be horizontal relation, check with same output shape. - {kInjective, horizontal_or_can_inline}, + {OpPatternKind::kInjective, horizontal_or_can_inline}, // can't fuse. - {kOutFusible, no_fuse}, + {OpPatternKind::kOutFusible, no_fuse}, }; - fusion_relation_map_[kInjective] = std::move(relation); + fusion_relation_map_[OpPatternKind::kInjective] = std::move(relation); } // 5.kOutFusible { FusionRelation relation; // producer -> consumer - relation.op_kind = {kElementWise, kBroadcast}; + relation.op_kind = {OpPatternKind::kElementWise, + OpPatternKind::kBroadcast}; // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation, check has same shape. - {kElementWise, is_same_shape}, + {OpPatternKind::kElementWise, is_same_shape}, // it must be horizontal relation, check has same shape. - {kBroadcast, is_same_shape}, + {OpPatternKind::kBroadcast, is_same_shape}, // can't fuse. - {kReduction, no_fuse}, + {OpPatternKind::kReduction, no_fuse}, // must be horizontal relation, check has same shape. - {kInjective, is_same_shape}, + {OpPatternKind::kInjective, is_same_shape}, // can't fuse. - {kOutFusible, no_fuse}, + {OpPatternKind::kOutFusible, no_fuse}, }; - fusion_relation_map_[kOutFusible] = std::move(relation); + fusion_relation_map_[OpPatternKind::kOutFusible] = std::move(relation); } } @@ -474,9 +482,9 @@ class OpFusionPassHelper { return false; } - std::vector<::pir::Operation*> nodes_; + std::vector<::pir::Operation*> ops_; std::unordered_map fusion_groups_; - std::unordered_set output_nodes_set_; + std::unordered_set output_ops_set_; std::vector> groups_; @@ -491,38 +499,16 @@ class OpFusionPassHelper { std::unordered_map fusion_relation_map_; }; -GroupList OpFusionPassInternal(const ::pir::Program& program) { +GroupList OpFusionPassInternal(const std::vector& op_list) { VLOG(3) << "OpFusionPass...!"; - auto op_fusion_helper = OpFusionPassHelper(program); + auto op_fusion_helper = OpFusionPassHelper(op_list); auto res = op_fusion_helper(); - for (size_t i = 0; i < res.size(); ++i) { - auto group = res[i]; - - for (size_t j = 0; j < group->nodes.size(); ++j) { - } - } - - // for (auto& group : graph->fusion_groups) { - // VLOG(3) << "Group Id : " << group->group_id; - // for (const auto& producer : group->producer_groups()) { - // VLOG(3) << " producer group -> " << producer->group_id; - // } - // for (const auto& consumer : group->consumer_groups()) { - // VLOG(3) << " consumer group -> " << consumer->group_id; - // } - // } VLOG(3) << "OpFusionPass Finish...!"; return res; } -// void BuildNonFusedGroupsPassInternal(framework::Graph* graph) { -// auto op_fusion_helper = OpFusionPassHelper(graph); -// VLOG(3) << "Apply OpFusionPass to generate initial non-fusion groups"; -// graph->fusion_groups = op_fusion_helper(false); -// } - } // namespace ir } // namespace dialect } // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h index c784140c1cf36..d9e07273791fe 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h @@ -24,10 +24,9 @@ namespace ir { using GroupPtr = std::shared_ptr; using GroupList = std::vector; -GroupList OpFusionPassInternal(const ::pir::Program& program); +GroupList OpFusionPassInternal(const std::vector& op_list); -GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph, - const GroupList& group_list); +GroupList GeneralFusionMergePassInternal(const GroupList& group_list); } // namespace ir } // namespace dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h index 1ba6ba85b5158..2864a8a5a142d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/hlir/framework/pir/group.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/pir/core/operation.h" @@ -29,33 +30,9 @@ namespace cinn { namespace dialect { namespace ir { - -enum OpPatternKind { - // The relation between input tensor index and output tensor index is - // one-to-one correspondence. - // for example :code:`out[i, j] = input[i, j] + 1`. - // Note that the axis need to be in order. - kElementWise = 0, - // The relation between input tensor index and output tensor index is - // one-to-many correspondence. - // for example :code:`out[i, j, k] = input[i, j]`. - // Note that the axis need to be in order. - kBroadcast = 1, - // Injective operator, we can always injectively map a output axis to a input - // axis. - // for example :code:`out[i, j] = input[j, i]`. - kInjective = 2, - // The relation between input tensor index and output tensor index is - // many-to-one correspondence. - // for example :code:`out[i, j] = sum(input[i, j, k]) along k`. - kReduction = 3, - // Complex operation, can still fuse one-to-one operations into its output. - kOutFusible = 4, - // Operation that cannot fuse anything. - kNonFusible = 8 -}; - -OpPatternKind GetOpKind(const std::string& op_name); +// alias OpPatternKind and pir::Group +using OpPatternKind = hlir::framework::OpPatternKind; +using Group = hlir::framework::pir::Group; template std::vector GetVectorAttr(const ::pir::Operation* op, @@ -84,147 +61,16 @@ std::vector GetVectorAttr(const ::pir::Operation* op, return vec_res; } -struct Group { - Group() = default; - - // distance to last group. - int depth{0}; - int max_depth{0}; - int min_depth{INT_MAX}; - // group id, consisted of node's id. - std::string group_id{""}; - // global unique id. - std::string unique_id{"uniq"}; - // node in this group - std::vector<::pir::Operation*> nodes; - std::unordered_set<::pir::Operation*> nodes_set; - // input nodes of the group. - std::unordered_map<::pir::Operation*, int> input_nodes; - // output nodes of the group. - std::unordered_set<::pir::Operation*> output_nodes; - // op pattern kind. - OpPatternKind op_pattern_kind{kElementWise}; - // internal node, the output is used by multi-node. - // internal node can't use compute inline, should use buffer. - std::unordered_set<::pir::Operation*> internal_nodes; - // master node for schedule - std::unordered_set<::pir::Operation*> master_nodes; - - // fused sub-groups, used for fusion merge pass - std::vector> fused_sub_groups; - // if as sub-group, used for belong groups. - std::unordered_set> belong_groups; - - // for op lowering. - std::vector input_names; - std::vector output_names; - - struct SharedGroupHasher { - size_t operator()(const std::shared_ptr& group) const noexcept { - return std::hash()(reinterpret_cast(group.get())); - } - }; - struct SharedGroupComparator { - bool operator()(const std::shared_ptr& first, - const std::shared_ptr& second) const noexcept { - return first.get() == second.get(); - } - }; - - std::vector<::pir::Operation*> CollectNodes() { - if (fused_sub_groups.size()) { - std::vector<::pir::Operation*> tmp_nodes; - for (auto& group : fused_sub_groups) { - tmp_nodes.insert( - tmp_nodes.end(), group->nodes.begin(), group->nodes.end()); - } - return tmp_nodes; - } else { - return nodes; - } - } - - void WalkNodes( - const std::function& VisitNode) const { - if (fused_sub_groups.size()) { - for (auto& group : fused_sub_groups) { - for (const auto& node : group->nodes) { - VisitNode(node); - } - } - } else { - for (const auto& node : nodes) { - VisitNode(node); - } - } - } - - std::unordered_set<::pir::Operation*> NodeSet() { - std::unordered_set<::pir::Operation*> node_set; - for (auto node : CollectNodes()) { - node_set.insert(node); - } - return node_set; - } - - // TODO(phlrain) : impliment GetInputNodeDatas GetOutputNodeDatas func - // std::unordered_set<::pir::Value> GetInputNodeDatas() { return {}; } - // std::unordered_set<::pir::Value> GetOutputNodeDatas() { return {}; } - - std::string GetFuncName() { return "fn_" + group_id + unique_id; } - - public: - const std::unordered_set, - SharedGroupHasher, - SharedGroupComparator>& - producer_groups() const { - return producer_groups_; - } - - const std::unordered_set, - SharedGroupHasher, - SharedGroupComparator>& - consumer_groups() const { - return consumer_groups_; - } - - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator>* - mut_producer_groups() { - return &producer_groups_; - } - - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator>* - mut_consumer_groups() { - return &consumer_groups_; - } - - OpPatternKind kind() const { return op_pattern_kind; } - - private: - // input groups - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator> - producer_groups_; - // output grous - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator> - consumer_groups_; -}; +OpPatternKind GetOpKind(const std::string& op_name); phi::DDim GetFirstInputShape(const ::pir::Operation* op); -phi::DDim GetValueShape(const ::pir::Value value); +phi::DDim GetValueShape(const ::pir::Value& value); bool WithoutLastDimInReduce(const std::vector& inshape, const std::vector& axes); -int GetSharedSize(::pir::Operation* node); +int GetSharedSize(::pir::Operation* op); inline bool always_fuse(::pir::Operation* producer, const std::shared_ptr& consumer) { @@ -238,16 +84,16 @@ inline bool no_fuse(::pir::Operation* producer, inline bool is_same_shape(::pir::Operation* producer, const std::shared_ptr& consumer) { - auto master_node = consumer->master_nodes.begin(); + auto master_op = consumer->master_ops.begin(); return GetValueShape(producer->result(0)) == - GetValueShape((*master_node)->result(0)); + GetValueShape((*master_op)->result(0)); } inline bool is_same_size(::pir::Operation* producer, const std::shared_ptr& consumer) { - auto master_node = consumer->master_nodes.begin(); + auto master_op = consumer->master_ops.begin(); auto producer_shape = GetValueShape(producer->result(0)); - auto consumer_shape = GetValueShape((*master_node)->result(0)); + auto consumer_shape = GetValueShape((*master_op)->result(0)); if (producer_shape == consumer_shape) { return true; } @@ -259,15 +105,15 @@ inline bool is_same_size(::pir::Operation* producer, inline bool without_last_dimension_in_reduce( ::pir::Operation* producer, const std::shared_ptr& consumer) { auto in_shape = phi::vectorize(GetFirstInputShape(producer)); - auto reduce_axes = GetVectorAttr(producer, "axis"); + auto reduce_axes = GetVectorAttr(producer, "dim"); return WithoutLastDimInReduce(in_shape, reduce_axes); } inline bool reduce_fuse_reduce(::pir::Operation* producer, const std::shared_ptr& consumer) { ::pir::Operation* reducer = NULL; - for (auto* master : consumer->master_nodes) { - if (GetOpKind(master->name()) == kReduction) { + for (auto* master : consumer->master_ops) { + if (GetOpKind(master->name()) == OpPatternKind::kReduction) { reducer = master; break; } @@ -283,8 +129,8 @@ inline bool reduce_fuse_reduce(::pir::Operation* producer, auto reducer_output_shape = phi::vectorize(GetValueShape(reducer->result(0))); - auto producer_reduce_dim = GetVectorAttr(producer, "axis"); - auto reducer_reduce_dim = GetVectorAttr(reducer, "axis"); + auto producer_reduce_dim = GetVectorAttr(producer, "dim"); + auto reducer_reduce_dim = GetVectorAttr(reducer, "dim"); for (auto& dim : producer_reduce_dim) { // if dim = -1, set as shape.size() - 1 @@ -309,8 +155,8 @@ inline bool reduce_fuse_reduce(::pir::Operation* producer, // check shape is same if (input_shape_same || without_last_dim) { auto shared_size = GetSharedSize(producer); - for (auto* master : consumer->master_nodes) { - if (GetOpKind(master->name()) == kReduction) { + for (auto* master : consumer->master_ops) { + if (GetOpKind(master->name()) == OpPatternKind::kReduction) { shared_size += GetSharedSize(master); } } @@ -328,30 +174,30 @@ inline bool reduce_fuse_reduce(::pir::Operation* producer, inline bool is_horizontal_relation(::pir::Operation* producer, const std::shared_ptr& consumer) { - auto check_depency = [&](::pir::Operation* node) { + auto check_depency = [&](::pir::Operation* op) { std::queue<::pir::Operation*> candidates; std::unordered_set<::pir::Operation*> visited_set; - candidates.push(node); + candidates.push(op); while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - // visit all producer node + // visit all producer op for (size_t i = 0; i < candidate->num_operands(); ++i) { - auto tmp_node = + auto tmp_op = candidate->operand_source(i).dyn_cast().owner(); // check depency. - if (producer == tmp_node) { + if (producer == tmp_op) { return true; } - // check node is in region. - if (!consumer->nodes_set.count(tmp_node)) { + // check op is in region. + if (!consumer->ops_set.count(tmp_op)) { continue; } - // recored visited node. - if (!visited_set.count(tmp_node)) { - visited_set.insert(tmp_node); - candidates.push(tmp_node); + // recored visited op. + if (!visited_set.count(tmp_op)) { + visited_set.insert(tmp_op); + candidates.push(tmp_op); } } } @@ -359,11 +205,11 @@ inline bool is_horizontal_relation(::pir::Operation* producer, return false; }; - for (auto node : consumer->nodes_set) { - if (GetOpKind(node->name()) != consumer->op_pattern_kind) { + for (auto op : consumer->ops_set) { + if (GetOpKind(op->name()) != consumer->op_pattern_kind) { continue; } - if (check_depency(node)) { + if (check_depency(op)) { return false; } } @@ -378,18 +224,18 @@ inline bool horizontal_or_vertical_reduce_relation( return true; } - // reducer node in fusion op. + // reducer op in fusion op. ::pir::Operation* reducer = NULL; - for (auto* master : consumer->master_nodes) { - if (GetOpKind(master->name()) == kReduction) { + for (auto* master : consumer->master_ops) { + if (GetOpKind(master->name()) == OpPatternKind::kReduction) { reducer = master; break; } } - // check producer has same shape with reducer node. + // check producer has same shape with reducer op. auto reduce_shape = phi::vectorize(GetFirstInputShape(reducer)); - auto reduce_axes = GetVectorAttr(reducer, "axis"); + auto reduce_axes = GetVectorAttr(reducer, "dim"); for (auto& axis : reduce_axes) { // if axis = -1, set as shape.size() - 1 @@ -398,14 +244,14 @@ inline bool horizontal_or_vertical_reduce_relation( } } - auto node_shape = phi::vectorize(GetFirstInputShape(producer)); - auto node_size = std::accumulate( - node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto op_shape = phi::vectorize(GetFirstInputShape(producer)); + auto op_size = std::accumulate( + op_shape.begin(), op_shape.end(), 1, std::multiplies()); auto reduce_size = std::accumulate( reduce_shape.begin(), reduce_shape.end(), 1, std::multiplies()); // is not same size with reduce size. - if (node_size != reduce_size) { + if (op_size != reduce_size) { return false; } // check without last axis in reduce. @@ -440,18 +286,18 @@ inline bool horizontal_or_can_inline(::pir::Operation* producer, return true; } else { // if do broadcast, check can compute inline. - // return helper->output_nodes_set_.count(producer) == 0; - // TODO(phlrain): support output node set check + // return helper->output_ops_set_.count(producer) == 0; + // TODO(phlrain): support output op set check return false; } } // vertical relation: 1.can compute inline // if (helper->GetNodeData(producer)->outlinks().size() == 1 && - // helper->output_nodes_set_.count(producer) == 0) { + // helper->output_ops_set_.count(producer) == 0) { // return true; // } - // link to same node. + // link to same op. // auto& out_links = helper->GetNodeData(producer)->outlinks(); // for (auto link : out_links) { // if ((*out_links.begin())->sink() != link->sink()) { @@ -459,7 +305,7 @@ inline bool horizontal_or_can_inline(::pir::Operation* producer, // } // } - // return helper->output_nodes_set_.count(producer) == 0; + // return helper->output_ops_set_.count(producer) == 0; return false; } @@ -484,7 +330,7 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer, // } auto rinput_shape = phi::vectorize(GetFirstInputShape(producer)); - auto reduce_axes = GetVectorAttr(producer, "axis"); + auto reduce_axes = GetVectorAttr(producer, "dim"); auto keep_dim = producer->attributes() .at("keep_dim") .dyn_cast<::pir::BoolAttribute>() @@ -510,11 +356,11 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer, auto routput_shape = phi::vectorize(GetValueShape(producer->result(0))); auto find_reducer = - [&](::pir::Operation* node, + [&](::pir::Operation* op, ::pir::Operation* reducer, - const std::unordered_set<::pir::Operation*>& nodes_set) { + const std::unordered_set<::pir::Operation*>& ops_set) { std::queue<::pir::Operation*> candidates; - candidates.push(node); + candidates.push(op); while (!candidates.empty()) { auto candidate = candidates.front(); @@ -527,7 +373,7 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer, return true; } - if (nodes_set.count(producer)) { + if (ops_set.count(producer)) { candidates.push(producer); } } @@ -536,17 +382,17 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer, return false; }; - for (auto node : consumer->nodes_set) { - if (GetOpKind(node->name()) != kBroadcast) { + for (auto op : consumer->ops_set) { + if (GetOpKind(op->name()) != OpPatternKind::kBroadcast) { continue; } - if (!find_reducer(node, producer, consumer->nodes_set)) { + if (!find_reducer(op, producer, consumer->ops_set)) { continue; } - auto broadcast_shape = GetVectorAttr(node, "out_shape"); - auto broadcast_axes = GetVectorAttr(node, "broadcast_axes"); + auto broadcast_shape = GetVectorAttr(op, "out_shape"); + auto broadcast_axes = GetVectorAttr(op, "broadcast_axes"); for (auto& axis : broadcast_axes) { if (axis < 0) { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc new file mode 100644 index 0000000000000..35bd62f8dba9d --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class SumOpPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern pattern = ctx->SourcePattern(); + const auto &full_int_array = + pattern.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pattern.Attr("axis_info")}, + {"dtype", pattern.Attr("dtype_2")}, + {"place", pattern.Attr("place_2")}}); + + const auto &sum = pattern.Op(paddle::dialect::SumOp::name(), + {{"dtype", pattern.Attr("dtype")}, + {"keepdim", pattern.Attr("keep_dim")}}); + pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array()); + + // Result patterns + pir::drr::ResultPattern res = pattern.ResultPattern(); + const auto &cinn_reduce_sum = + res.Op(cinn::dialect::ReduceSumOp::name(), + {{"dim", pattern.Attr("axis_info")}, + {"keep_dim", pattern.Attr("keep_dim")}}); + res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0")); + } +}; + +class MaxOpPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern pattern = ctx->SourcePattern(); + const auto &full_int_array = + pattern.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pattern.Attr("axis_info")}, + {"dtype", pattern.Attr("dtype_2")}, + {"place", pattern.Attr("place_2")}}); + + const auto &pd_max = pattern.Op(paddle::dialect::MaxOp::name(), + {{"keepdim", pattern.Attr("keep_dim")}}); + pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); + + // Result patterns + pir::drr::ResultPattern res = pattern.ResultPattern(); + const auto &cinn_reduce_max = + res.Op(cinn::dialect::ReduceMaxOp::name(), + {{"dim", pattern.Attr("axis_info")}, + {"keep_dim", pattern.Attr("keep_dim")}}); + res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); + } +}; + +PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::Pass("pd_to_cinn_pass", 1) {} + +bool PdOpToCinnOpPass::Initialize(pir::IrContext *context) { + pir::RewritePatternSet ps(context); + ps.Add(SumOpPattern().Build(context)); + ps.Add(MaxOpPattern().Build(context)); + + patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps)); + return true; +} + +void PdOpToCinnOpPass::Run(pir::Operation *op) { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); +} + +bool PdOpToCinnOpPass::CanApplyOn(pir::Operation *op) const { + return op->isa() && op->num_regions() > 0; +} + +void PdOp2CinnOpConverter(::pir::Program *program) { + pir::IrContext *ctx = pir::IrContext::Instance(); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + + pm.Run(program); +} +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h new file mode 100644 index 0000000000000..d6c0bd2013bbc --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h @@ -0,0 +1,43 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class PdOpToCinnOpPass : public pir::Pass { + public: + PdOpToCinnOpPass(); + + bool Initialize(pir::IrContext *context) override; + + void Run(pir::Operation *op) override; + + bool CanApplyOn(pir::Operation *op) const override; + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +void PdOp2CinnOpConverter(::pir::Program *program); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt index 6023117faee09..3452dcd74ab9f 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt @@ -1,4 +1,10 @@ if(NOT CINN_ONLY) - cinn_cc_library(cinn_runtime_dialect SRCS runtime_dialect.cc jit_kernel_op.cc - DEPS pir_core) + cinn_cc_library( + cinn_runtime_dialect + SRCS + runtime_dialect.cc + jit_kernel_op.cc + DEPS + cinn_op_dialect + pir) endif() diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc index c98eb564b9735..2d8833a6acefc 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc @@ -14,6 +14,8 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/enforce.h" @@ -27,15 +29,17 @@ void JitKernelOp::VerifySig() { auto& attributes = this->attributes(); - IR_ENFORCE(attributes.count(kAttrName) > 0 && - attributes.at(kAttrName).isa<::pir::PointerAttribute>(), - "Type of attribute: instruction is not right."); + IR_ENFORCE( + attributes.count(kAttrName) > 0 && + attributes.at(kAttrName).isa(), + "Type of attribute: instruction is not right."); } -hlir::framework::Instruction* JitKernelOp::instruction() { - void* ptr = - attributes().at(kAttrName).dyn_cast<::pir::PointerAttribute>().data(); - return reinterpret_cast(ptr); +const hlir::framework::pir::CUDAJITInfo& JitKernelOp::cuda_jit_info() { + return attributes() + .at(kAttrName) + .dyn_cast() + .data(); } } // namespace dialect diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h index 62adcf2b1c7f1..0ac3d26c262b7 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h @@ -14,16 +14,11 @@ #pragma once +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/pir/core/op_base.h" namespace cinn { -namespace hlir { -namespace framework { -class Instruction; -} // namespace framework -} // namespace hlir - namespace dialect { /* @@ -46,10 +41,10 @@ class JitKernelOp : public ::pir::Op { static const char* name() { return "cinn_runtime.jit_kernel"; } // TODO(Aurelius84): Think deeply what should contains static constexpr uint32_t attributes_num = 1; - static constexpr char* kAttrName = "instruction"; + static constexpr char* kAttrName = "jit_info"; static const char* attributes_name[attributes_num]; - hlir::framework::Instruction* instruction(); + const hlir::framework::pir::CUDAJITInfo& cuda_jit_info(); void VerifySig(); }; diff --git a/paddle/cinn/hlir/framework/CMakeLists.txt b/paddle/cinn/hlir/framework/CMakeLists.txt index 54da1e2b7dc90..a9385d627828a 100755 --- a/paddle/cinn/hlir/framework/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(new_ir) +add_subdirectory(pir) core_gather_headers() gather_srcs( @@ -24,13 +24,10 @@ gather_srcs( visualize_helper.cc compile_error.cc) -# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could +# TODO(Aurelius84): pir_compiler depends on op_dialect_vjp and could # not found under CINN_ONLY mode if(NOT CINN_ONLY) - cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi - pd_op_dialect) - cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi - cinn_op_dialect) + cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp) endif() if(WITH_CUDA) diff --git a/paddle/cinn/hlir/framework/convert_to_dialect.cc b/paddle/cinn/hlir/framework/convert_to_dialect.cc deleted file mode 100644 index f76b49a54555f..0000000000000 --- a/paddle/cinn/hlir/framework/convert_to_dialect.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/hlir/framework/convert_to_dialect.h" - -#include -#include - -#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" -#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/program.h" -#include "paddle/pir/core/builtin_attribute.h" -#include "paddle/pir/core/program.h" - -namespace cinn { -namespace hlir { -namespace framework { - -std::unique_ptr<::pir::Program> ConvertToRuntimeDialect( - const hlir::framework::Program& program) { - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - auto ir_program = std::make_unique<::pir::Program>(ctx); - - std::string jit_op_name = dialect::JitKernelOp::name(); - ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); - - auto& instrs = program.GetRunInstructions(); - for (auto& instr : instrs) { - std::unordered_map op_attrs{ - {dialect::JitKernelOp::kAttrName, - ::pir::PointerAttribute::get(ctx, instr.get())}, - }; - - ::pir::Operation* cinn_op = - ::pir::Operation::Create({}, op_attrs, {}, op_info); - ir_program->block()->push_back(cinn_op); - } - return std::move(ir_program); -} - -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/CMakeLists.txt b/paddle/cinn/hlir/framework/new_ir/CMakeLists.txt deleted file mode 100755 index e08baf06dbd13..0000000000000 --- a/paddle/cinn/hlir/framework/new_ir/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(NOT CINN_ONLY) - core_gather_headers() - gather_srcs(cinnapi_src SRCS utils.cc op_lowering_impl.cc) -endif() diff --git a/paddle/cinn/hlir/framework/new_ir/group.h b/paddle/cinn/hlir/framework/new_ir/group.h deleted file mode 100644 index 1a67a02e58ca9..0000000000000 --- a/paddle/cinn/hlir/framework/new_ir/group.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/cinn/hlir/framework/new_ir/utils.h" -#include "paddle/cinn/hlir/framework/op.h" -#include "paddle/pir/core/operation.h" - -namespace cinn { -namespace hlir { -namespace framework { -namespace newir { -using framework::OpPatternKind; - -// TODO(Aurelius84): Need to be replaced with CinnGroupOp -struct Group { - public: - explicit Group(const std::vector<::pir::Operation*>& group_ops) - : ops(group_ops) { - Initialize(); - } - - explicit Group(std::initializer_list<::pir::Operation*> group_ops) - : ops(group_ops) { - Initialize(); - } - - int group_id; - std::string fn_name; - OpPatternKind op_pattern_kind; - std::vector<::pir::Operation*> ops; - std::vector input_names; - std::vector output_names; - - private: - void Initialize() { - op_pattern_kind = OpPatternKind::kElementWise; - fn_name = CompatibleInfo::GroupOpsName(ops); - } -}; - -} // namespace newir -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/utils.cc b/paddle/cinn/hlir/framework/new_ir/utils.cc deleted file mode 100644 index 3f938981390fb..0000000000000 --- a/paddle/cinn/hlir/framework/new_ir/utils.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/hlir/framework/new_ir/utils.h" - -namespace cinn { -namespace hlir { -namespace framework { -namespace newir { - -const std::unordered_map CompatibleInfo::OP_NAMES = { - {"pd_op.full", "fill_constant"}}; - -std::string CompatibleInfo::OpName(const ::pir::Operation& op) { - std::string name = op.name(); - if (OP_NAMES.count(name)) { - return OP_NAMES.at(name); - } - auto pos = name.find("."); - if (pos == std::string::npos) { - return name; - } - auto cinn_op_name = name.substr(pos + 1); - VLOG(4) << "GetOpName: " << name << " -> " << cinn_op_name; - return cinn_op_name; -} - -std::string CompatibleInfo::ValueName(const ::pir::Value& value) { - return CompatibleInfo::kNamePrefix + - std::to_string(std::hash<::pir::Value>()(value)); -} - -std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) { - std::string op_name = OpName(op); - std::string func_name = - cinn::common::Context::Global().NewName("fn_" + op_name); - return func_name; -} - -std::string CompatibleInfo::GroupOpsName( - const std::vector<::pir::Operation*>& ops) { - std::string name = "fn"; - for (auto* op : ops) { - std::string op_name = OpName(*op); - name += "_" + cinn::common::Context::Global().NewName(op_name); - } - return name; -} - -std::vector CompatibleInfo::InputNames(const ::pir::Operation& op, - bool allow_duplicate) { - std::vector names; - std::unordered_set repeat; - for (int i = 0; i < op.num_operands(); ++i) { - auto value = op.operand_source(i); - std::string name = CompatibleInfo::ValueName(value); - if (!allow_duplicate && repeat.count(name)) { - continue; - } - repeat.insert(name); - names.push_back(name); - } - return names; -} - -std::vector CompatibleInfo::OutputNames(::pir::Operation& op) { - std::vector names; - for (int i = 0; i < op.num_results(); ++i) { - auto value = op.result(i); - std::string name = CompatibleInfo::ValueName(value); - names.push_back(std::move(name)); - } - return names; -} - -} // namespace newir -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index ac52aea80de71..8ae0d5869c1a4 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -22,7 +22,7 @@ #include "paddle/cinn/hlir/framework/op_lowering_impl_base.h" #include "paddle/cinn/lang/packed_func.h" #ifndef CINN_WITH_ONLY -#include "paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" #endif namespace cinn { @@ -65,13 +65,13 @@ inline OpLowerer CreateOpLowerer( } #ifndef CINN_WITH_ONLY -template +template OpLowerer CreateOpLowerer(const Target&); template <> -inline OpLowerer CreateOpLowerer(const Target& target) { - auto* impl_base = new newir::OpLowererImpl(target); - return OpLowerer(impl_base); +inline OpLowerer CreateOpLowerer(const Target& target) { + auto* impl_base = new pir::OpLowererImpl(target); + return OpLowerer(impl_base); } #endif diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.cc b/paddle/cinn/hlir/framework/op_lowering_impl.cc index b380ee8aaba2e..1b47dbda611d7 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/op_lowering_impl.cc @@ -19,10 +19,13 @@ #include "paddle/cinn/hlir/framework/graph_compiler_util.h" #include "paddle/cinn/hlir/framework/op_lowering_util.h" #include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/cinn/runtime/flags.h" PD_DECLARE_bool(cinn_use_cuda_vectorize); +PD_DECLARE_bool(cinn_new_group_scheduler); namespace cinn { namespace hlir { @@ -123,7 +126,10 @@ std::vector OpLowererImpl::LowerGroup( ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - if (apply_group_schedule) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + if (apply_group_schedule && + !(nodes.size() == 1 && + op_pattern_dict[nodes[0]->op()] == OpPatternKind::kNonFusible)) { DoGroupSchedule(ir_sch, group, tensor_map); VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); @@ -463,6 +469,23 @@ ir::Expr OpLowererImpl::DoGroupSchedule( ir::IRSchedule& ir_sch, const GroupPtr& group, const std::unordered_map& tensor_map) { + if (FLAGS_cinn_new_group_scheduler) { + std::unordered_set output_tensor_names; + std::transform( + group->output_nodes.begin(), + group->output_nodes.end(), + std::inserter(output_tensor_names, output_tensor_names.begin()), + [](const Node* node) { + NodeData* node_data = + (*node->outlinks().begin())->sink()->safe_as(); + CHECK(node_data); + return node_data->id(); + }); + ir::StaticShapeGroupScheduler group_scheduler( + &ir_sch, output_tensor_names, target_); + group_scheduler.Schedule(); + return ir_sch.GetModule().GetExprs().at(0); + } // topological order. auto nodes_set = group->NodeSet(); auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); @@ -558,8 +581,7 @@ ir::Expr OpLowererImpl::DoGroupSchedule( this->shape_dict_); } else { VLOG(3) << "Before assign node " << node->id() - << " into horizontal link reducer " << greducer->id() - << ", ir is:\n" + << " into horizontal link reducer, ir is:\n" << ir_sch.GetModule().GetExprs().at(0); // if node is horizontal with reduce or node is reduce, loop assign // diff --git a/paddle/cinn/hlir/framework/pir/CMakeLists.txt b/paddle/cinn/hlir/framework/pir/CMakeLists.txt new file mode 100755 index 0000000000000..10ce9d7c07275 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/CMakeLists.txt @@ -0,0 +1,5 @@ +if(NOT CINN_ONLY) + core_gather_headers() + gather_srcs(cinnapi_src SRCS utils.cc op_lowering_impl.cc op_mapper.cc + op_lowering_util.cc) +endif() diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h new file mode 100644 index 0000000000000..7b0913525c254 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -0,0 +1,219 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/pir/core/operation.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +using framework::OpPatternKind; + +// TODO(Aurelius84): Need to be replaced with CinnGroupOp +struct Group { + public: + Group() = default; + explicit Group(const std::vector<::pir::Operation*>& group_ops) + : ops(group_ops) {} + + explicit Group(std::initializer_list<::pir::Operation*> group_ops) + : ops(group_ops) {} + + // distance to last group. + int depth{0}; + int max_depth{0}; + int min_depth{INT_MAX}; + // group id, consisted of op's id. + std::string group_id{""}; + // global unique id. + std::string unique_id{"uniq"}; + // op in this group + std::vector<::pir::Operation*> ops; + std::unordered_set<::pir::Operation*> ops_set; + // input ops of the group. + std::unordered_map<::pir::Operation*, int> input_ops; + // output ops of the group. + std::unordered_set<::pir::Operation*> output_ops; + // op pattern kind. + OpPatternKind op_pattern_kind{kReduction}; + // internal op, the output is used by multi-op. + // internal op can't use compute inline, should use buffer. + std::unordered_set<::pir::Operation*> internal_ops; + // master op for schedule + std::unordered_set<::pir::Operation*> master_ops; + + // fused sub-groups, used for fusion merge pass + std::vector> fused_sub_groups; + // if as sub-group, used for belong groups. + std::unordered_set> belong_groups; + + // for op lowering. + std::vector input_names; + std::vector output_names; + std::string fn_name{""}; + + struct SharedGroupHasher { + size_t operator()(const std::shared_ptr& group) const noexcept { + return std::hash()(reinterpret_cast(group.get())); + } + }; + struct SharedGroupComparator { + bool operator()(const std::shared_ptr& first, + const std::shared_ptr& second) const noexcept { + return first.get() == second.get(); + } + }; + + std::vector<::pir::Operation*> CollectOps() { + if (fused_sub_groups.size()) { + std::vector<::pir::Operation*> tmp_ops; + for (auto& group : fused_sub_groups) { + tmp_ops.insert(tmp_ops.end(), group->ops.begin(), group->ops.end()); + } + return tmp_ops; + } else { + return ops; + } + } + + void WalkOps(const std::function& VisitOp) const { + if (fused_sub_groups.size()) { + for (auto& group : fused_sub_groups) { + for (const auto& op : group->ops) { + VisitOp(op); + } + } + } else { + for (const auto& op : ops) { + VisitOp(op); + } + } + } + + std::unordered_set<::pir::Operation*> OpSet() { + std::unordered_set<::pir::Operation*> op_set; + for (auto op : CollectOps()) { + op_set.insert(op); + } + return op_set; + } + + std::unordered_set<::pir::Value> GetInputOpValues() { + std::unordered_set<::pir::Value> group_inputs; + auto ops_set = this->OpSet(); + // count all op's input Value + for (auto op : this->CollectOps()) { + for (auto& value : op->operands_source()) { + if (!value || !value.type()) { + continue; + } + + if (!ops_set.count(value.dyn_cast<::pir::OpResult>().owner())) { + // if the input value owner op is not in OpSet, it's the group's input + group_inputs.insert(value); + continue; + } + + if (std::find(this->input_names.begin(), + this->input_names.end(), + CompatibleInfo::ValueName(value)) != + this->input_names.end()) { + // if the input data in group's input_names + group_inputs.insert(value); + continue; + } + } + } + + return group_inputs; + } + std::unordered_set<::pir::Value> GetOutputOpValues() { + std::unordered_set<::pir::Value> group_outputs; + + for (auto op : this->output_ops) { + for (auto& result : op->results()) { + if (!result || result.type()) { + continue; + } + + group_outputs.insert(result); + } + } + return group_outputs; + } + + std::string GetFuncName() { return "fn_" + group_id + unique_id; } + + public: + const std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>& + producer_groups() const { + return producer_groups_; + } + + const std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>& + consumer_groups() const { + return consumer_groups_; + } + + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>* + mut_producer_groups() { + return &producer_groups_; + } + + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>* + mut_consumer_groups() { + return &consumer_groups_; + } + + OpPatternKind kind() const { return op_pattern_kind; } + + std::string FuncName() const { + if (fn_name == "") { + // TODO(Aurelius84): Polish this implementation. + const_cast(this)->fn_name = CompatibleInfo::GroupOpsName(ops); + } + return this->fn_name; + } + + private: + // input groups + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + producer_groups_; + // output grous + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + consumer_groups_; +}; + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc similarity index 58% rename from paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc rename to paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index ea76d939bc45b..40cd6444a4fed 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -12,17 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" #include -#include "paddle/cinn/hlir/framework/op_lowering_util.h" + +#include "paddle/cinn/ast_gen_ius/tensor_group.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" #include "paddle/cinn/hlir/op/external_api_registry.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" -#include "paddle/cinn/hlir/framework/new_ir/utils.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/lang/placeholder.h" -#include "paddle/cinn/utils/attribute_util.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/ddim.h" @@ -31,7 +33,7 @@ PD_DECLARE_bool(cinn_use_cuda_vectorize); namespace cinn { namespace hlir { namespace framework { -namespace newir { +namespace pir { using cinn::hlir::op::ExternalApiRegistry; using common::Type; @@ -39,13 +41,44 @@ using framework::OpPatternKind; using framework::StrategyFunction; namespace details { + +bool IsInTensorMap( + const std::string& name, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map) { + for (auto iter : tensor_map) { + if (name == CompatibleInfo::ValueName(iter.first)) { + return true; + } + } + return false; +} + +common::Type GetTensorDtype(const ::pir::Value& value) { + auto type_info = value.type().dyn_cast(); + auto in_shape = phi::vectorize(type_info.dims()); + auto dtype = type_info.dtype(); + return CompatibleInfo::ConvertIRType(dtype); +} + +common::Type GetTensorDtype( + const std::string& name, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map) { + for (auto iter : tensor_map) { + if (name == CompatibleInfo::ValueName(iter.first)) { + return GetTensorDtype(iter.first); + } + } + VLOG(4) << name << " is not in tensor map, return FP32 by default."; + return common::F32(); +} + ir::Tensor GetTensor(const ::pir::Value& value) { auto type_info = value.type().dyn_cast(); auto in_shape = phi::vectorize(type_info.dims()); auto dtype = type_info.dtype(); std::string input_id = CompatibleInfo::ValueName(value); return lang::CreatePlaceHolder( - in_shape, utils::ConvertIRType(dtype), input_id); + in_shape, CompatibleInfo::ConvertIRType(dtype), input_id); } std::vector CollectInputTensor( @@ -53,9 +86,8 @@ std::vector CollectInputTensor( std::vector* func_args, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) { std::vector tensors; - for (auto in_value : op->operands_source()) { + for (auto in_value : CompatibleInfo::RealOperandSources(*op)) { VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value); - // NOTE(Aurelius84): Need always to create placeholder for input tensor. ir::Tensor tensor = details::GetTensor(in_value); if (!tensor_map->count(in_value)) { // record tensor. @@ -80,7 +112,7 @@ void CollectOutputInfo(::pir::Operation* op, auto type_info = out_value.type().dyn_cast(); - out_types->push_back(utils::ConvertIRType(type_info.dtype())); + out_types->push_back(CompatibleInfo::ConvertIRType(type_info.dtype())); auto out_shape = phi::vectorize(type_info.dims()); out_shapes->push_back(std::move(out_shape)); } @@ -89,7 +121,7 @@ void CollectOutputInfo(::pir::Operation* op, NodeAttr CollectAttrs(const ::pir::Operation& op) { NodeAttr node_attrs; VLOG(4) << "op.attributes():" << op.attributes().size(); - auto attrs = utils::ConvertAttributes(op.attributes()); + auto attrs = CompatibleInfo::ConvertAttributes(op); node_attrs.node_name = CompatibleInfo::OpName(op); node_attrs.attr_store = std::move(attrs); @@ -106,6 +138,10 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, bool apply_pass) { VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind; + // TODO(Aurelius84): The logic shoule be moved into op_fusion module. + if (group->ops.size() >= 1U & group->output_ops.size() == 0) { + group->output_ops.insert(group->ops[group->ops.size() - 1]); + } group->input_names.clear(); group->output_names.clear(); switch (group->op_pattern_kind) { @@ -138,10 +174,8 @@ bool OpLowererImpl::ElementwiseScheduleDetermineFunction(::pir::Operation* op) { } bool OpLowererImpl::ReduceScheduleDetermineFunction(::pir::Operation* op) { - // TODO(Aurelius84): Support this. - // auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - // return op_pattern_dict[op] == framework::kReduction; - return true; + VLOG(3) << "in ReduceScheduleDetermineFunction"; + return CompatibleInfo::OpKind(*op) == framework::kReduction; } bool OpLowererImpl::NonFusibleScheduleDetermineFunction(::pir::Operation* op) { @@ -160,24 +194,27 @@ std::vector OpLowererImpl::LowerGroup( } std::vector group_func_arg_tensors; std::unordered_map<::pir::Value, ir::Tensor> tensor_map; + // for some op, it will output more tmp value and regard as + // XX_0, XX_1, so we log them in tmp_tensor_info; + std::unordered_map tmp_tensor_info; bool do_op_schedule = apply_group_schedule || apply_op_schedule; std::vector func_bodies = LowerOps(ops, do_op_schedule, schedule_determine_func, &group_func_arg_tensors, - &tensor_map); + &tensor_map, + &tmp_tensor_info); // 2.Do group schedule. ir::ModuleExpr mod_expr(func_bodies); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - // TODO(Aurelius84): Support this. - // if (apply_group_schedule) { - // DoGroupSchedule(ir_sch, group, tensor_map); - // VLOG(3) << "After group schedule, ir is: \n" - // << ir_sch.GetModule().GetExprs().at(0); - // } + if (apply_group_schedule) { + DoGroupSchedule(ir_sch, group, tensor_map, tmp_tensor_info); + VLOG(3) << "After group schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + } // 3.Do post-processing, // including preparing function args and temporary variables, @@ -219,7 +256,7 @@ std::vector OpLowererImpl::LowerCustomCall( // target_); // } std::vector compute_args = { - common::CINNValue(group->fn_name), common::CINNValue(external_api)}; + common::CINNValue(group->FuncName()), common::CINNValue(external_api)}; common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{compute_args}); CHECK_EQ(pack.size(), 1UL); @@ -250,9 +287,8 @@ std::vector OpLowererImpl::PostProcess( } group->output_names.clear(); - // FIXME(Aurelius84): Do we need to use output_ops? - // Currently we regards all ops as output_ops. - for (auto& op : group->ops) { + VLOG(3) << "group->output_ops.size(): " << group->output_ops.size(); + for (auto& op : group->output_ops) { // collect all output tensor. for (auto opresult : op->results()) { if (tensor_map.count(opresult) == 0) { @@ -299,7 +335,7 @@ std::vector OpLowererImpl::PostProcess( auto temp_buffers = lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); // 3.Building LoweredFunc - auto func = ir::_LoweredFunc_::Make(group->fn_name, + auto func = ir::_LoweredFunc_::Make(group->FuncName(), group_func_args, ir_sch->GetModule().GetExprs().at(0), temp_buffers); @@ -316,7 +352,8 @@ std::vector OpLowererImpl::LowerOps( bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func, std::vector* group_func_arg_tensors, - std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) { + std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, + std::unordered_map* tmp_tensor_info) { auto& strategy = Operator::GetAttrs("CINNStrategy"); std::vector func_bodies; for (auto* op : ops) { @@ -335,10 +372,9 @@ std::vector OpLowererImpl::LowerOps( const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name); auto op_impl = OpStrategy::SelectImpl(strategy[cinn_op]( node_attrs, op_func_arg_tensors, out_types, out_shapes, this->target_)); - // 2.Perform the lower process of Op - std::vector funcs = - DoOpLower(op_impl, op, tensor_map, &op_func_arg_tensors); + std::vector funcs = DoOpLower( + op_impl, op, tensor_map, tmp_tensor_info, &op_func_arg_tensors); if (apply_op_schedule && (this->*schedule_determine_func)(op)) { // 3.Perform the schedule of Op @@ -360,6 +396,7 @@ std::vector OpLowererImpl::DoOpLower( std::shared_ptr op_impl, ::pir::Operation* op, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, + std::unordered_map* tmp_tensor_info, std::vector* op_func_arg_tensors) { VLOG(4) << "Do lower with Compute, op: " << op->name(); std::vector cinn_inputs; @@ -386,10 +423,13 @@ std::vector OpLowererImpl::DoOpLower( // Some op may output multiple temp tensors in their Compute // definition, but only one output in the graph, and we use id + // "_0"/"_1" as key. - // FIXME(Aurelius84): It seems that the implementation is relate with - // string name. - // (*tensor_map)[op_results[0] + post] = expr.as_tensor_ref(); - // post = "_" + std::to_string(idx); + if (idx < op_results.size()) { + (*tensor_map)[op_results[idx]] = expr.as_tensor_ref(); + } + std::string tensor_name = CompatibleInfo::ValueName(op_results[0]) + post; + VLOG(3) << "Add tmp tensor name for reducer op: " << tensor_name; + (*tmp_tensor_info)[tensor_name] = expr.as_tensor_ref(); + post = "_" + std::to_string(idx); } else { // If the number of output tensors defined by Compute is less equal than // the output node_data on the graph, then there is a one-to-one @@ -409,16 +449,17 @@ std::vector OpLowererImpl::DoOpLower( // 2.Do lower std::string lower_fn_name = CompatibleInfo::OpFuncName(*op); - std::vector funcs = lang::LowerVec(lower_fn_name, - tmp_stages, - *op_func_arg_tensors, - {}, - {}, - nullptr, - this->target_, - true); + ast_gen_ius::TensorGroup tensor_group = + ast_gen_ius::ConvertStageMapToTensorGroup(tmp_stages); + std::vector funcs = lang::LowerToAstVec( + lower_fn_name, *op_func_arg_tensors, {&tensor_group}, this->target_); VLOG(4) << "Lower op: " << lower_fn_name << ", get " << funcs.size() << " LoweredFunc:\n"; + if (VLOG_IS_ON(4)) { + for (auto fun : funcs) { + VLOG(4) << fun; + } + } op_func_arg_tensors->clear(); for (int idx = 0; idx < pack.size() - 1; ++idx) { @@ -452,7 +493,194 @@ ir::Expr OpLowererImpl::DoOpSchedule( return expr_pack[0].operator ir::Expr(); } -} // namespace newir +ir::Expr OpLowererImpl::DoGroupSchedule( + ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info) { + // topological order. + auto ops_set = group->OpSet(); + auto v_consumers = BuildVirtualConsumer(group); + auto ops_in_order = BFSTopologicalOrderWithPriority(group, v_consumers); + // find reducer. + std::unordered_set<::pir::Operation*> ops_inline; + auto greducer = FindGlobalReducer(ops_in_order); + + // do schedule + for (auto op : ops_in_order) { + VLOG(4) << "Try FUSION " << op->name(); + std::string op_name = CompatibleInfo::OpName(*op); + auto op_kind = CompatibleInfo::OpKind(*op); + // consumers. + auto consumers = GetConsumersInSet(op, ops_set); + auto* reducer = greducer ? FindNearestReducer(op, ops_set) : greducer; + if (!reducer && greducer) { + reducer = v_consumers.count(op) ? v_consumers.find(op)->second : reducer; + if (reducer && + CompatibleInfo::OpKind(*reducer) != framework::kReduction) { + reducer = nullptr; + } + } + + auto masters = GetMasters(op, ops_inline, ops_set); + // TODO(Aurelius84): support inline later. + if (CanbeInline(op, reducer, consumers, masters, group, ops_set) && false) { + VLOG(3) << "Before compute inline, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); + auto block = ir_sch.GetBlock(CompatibleInfo::ValueName(op->result(0))); + ir::ComputeInlineChecker checker(ir_sch, block); + if (!checker.Check()) { + checker.BuildDataDependency(); + continue; + } + + // if exist global reduce node. + if (greducer) { + auto loops = ir_sch.GetLoops(CompatibleInfo::ValueName(op->result(0))); + if (op_kind == framework::kElementWise) { + ir_sch.FlattenLoops(loops, true); + } else { + ir_sch.FlattenLoops(loops, false); + } + } + + ir_sch.ComputeInline(block); + ops_inline.insert(op); + VLOG(3) << "After compute inline, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); + continue; + } + // find master to computeat. + auto master = GetMasterToComputeAt( + op, ops_in_order, ops_inline, ops_set, v_consumers); + std::string op_out_name = CompatibleInfo::ValueName(op->result(0)); + // assign to reducer/master loop. + if (reducer) { + VLOG(3) << "Before assign node " << op_name + << " into vertical link reducer " + << CompatibleInfo::OpName(*reducer) << ", ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); + // if node is vertical with reduce, loop assign reducer. + LoopAssignReduce( + ir_sch, op, reducer, this->target_, tensor_map, tmp_tensor_info); + } else if (greducer) { + auto greducer_out_shape = CompatibleInfo::ValueShape(greducer->result(0)); + auto op_out_shape = CompatibleInfo::ValueShape(op->result(0)); + if (CompatibleInfo::ShapeProduct(greducer_out_shape) != + CompatibleInfo::ShapeProduct(op_out_shape)) { + LoopAssignReduce( + ir_sch, op, greducer, this->target_, tensor_map, tmp_tensor_info); + } + } else { + VLOG(3) << "Before assign node " << op_name + << " into horizontal link reducer, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); + // if node is horizontal with reduce or node is reduce, loop assign + // master. + auto loops = ir_sch.GetLoops(op_out_name); + if (op_kind == framework::kElementWise) { + ir_sch.FlattenLoops(loops, true); + } else if (op_kind != framework::kReduction) { + ir_sch.FlattenLoops(loops, false); + } + + if (master && op_kind != framework::kReduction) { + auto master_loops = + ir_sch.GetLoops(CompatibleInfo::ValueName(master->result(0))); + std::vector splits; + for (auto loop : master_loops) { + splits.push_back(loop.As()->extent.as_int32()); + } + loops = ir_sch.GetLoops(op_out_name); + ir_sch.Split(loops[0], splits); + } + } + VLOG(3) << "Before loop fusion, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); + // do loop fuse. + LoopComputeAt(ir_sch, + op, + master ? master : ops_in_order.front(), + group, + tensor_map, + tmp_tensor_info); + VLOG(3) << "After loop fusion, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); + } + + // do vectorize + auto all_blocks = ir_sch.GetAllBlocks(); + VLOG(4) << "Size of blocks: " << all_blocks.size(); + VLOG(4) << "Op Pattern : " << group->op_pattern_kind; + + // only support first block? + auto block = all_blocks[0]; + + if (block->as() == nullptr || + block->as() + ->schedule_block->as() == nullptr) { + std::string err_msg = + "Group scheduling, the Expr is not wrapped by ScheduleBlockRealize or " + "ScheduleBlock, cannot be scheduled."; + std::ostringstream detail_info; + detail_info << "Expr:\n"; + detail_info << block; + throw CompileErrorHandler(CompilationStatus::LOWERING_FAIL, + err_msg, + detail_info.str(), + __FILE__, + __LINE__); + } + auto is_tensor_block = true; + auto tensor_name = block->as() + ->schedule_block->as() + ->name; + if (!details::IsInTensorMap(tensor_name, tensor_map)) { + is_tensor_block = false; + } + if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block && + (group->op_pattern_kind == framework::kElementWise || + group->op_pattern_kind == framework::kInjective || + group->op_pattern_kind == framework::kBroadcast)) { + // auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); + auto loops = ir_sch.GetLoops(block); + VLOG(4) << "Op Pattern : " << loops.size(); + if (loops.size() >= 1) { + VLOG(4) << "Before vectorize, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + auto loop_inner = loops.back(); + int vector_width = 1; + auto psize = ir::GetLoopExtent(loop_inner); + auto dtype = details::GetTensorDtype(tensor_name, tensor_map); + VLOG(4) << tensor_name << " dtype " << dtype; + if (psize % 8 == 0 && (dtype.is_float16() || dtype.is_bfloat16())) { + vector_width = 8; + } else if (psize % 4 == 0) { + vector_width = 4; + } else if (psize % 2 == 0) { + vector_width = 2; + } + if (vector_width > 1) { + auto splited = ir_sch.Split(loop_inner, {-1, vector_width}); + splited[0].As()->set_bind_info( + loop_inner.As()->bind_info()); + splited[1].As()->set_serial(); + ir_sch.Vectorize(splited[1], vector_width); + } + VLOG(4) << "After vectorize, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + } + } + + VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + SyncThreadWithShared(ir_sch, group, ops_inline, ops_set, tensor_map); + VLOG(4) << "After IRSchedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + return ir_sch.GetModule().GetExprs().at(0); +} + +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h similarity index 87% rename from paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h rename to paddle/cinn/hlir/framework/pir/op_lowering_impl.h index 3fa859bbce880..156e7a399ced5 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -19,9 +19,9 @@ #include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/instruction.h" -#include "paddle/cinn/hlir/framework/new_ir/group.h" #include "paddle/cinn/hlir/framework/op_lowering_impl_base.h" #include "paddle/cinn/hlir/framework/op_strategy.h" +#include "paddle/cinn/hlir/framework/pir/group.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule_util.h" @@ -36,7 +36,7 @@ namespace cinn { namespace hlir { namespace framework { -namespace newir { +namespace pir { using GroupPtr = std::shared_ptr; @@ -119,7 +119,8 @@ class OpLowererImpl : public OpLowererImplBase { bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func, std::vector* group_func_arg_tensors, - std::unordered_map<::pir::Value, ir::Tensor>* tensor_map); + std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, + std::unordered_map* tmp_tensor_info); /** * @brief Lower an Op to CINN IR. The Compute and Lower processes will be @@ -134,6 +135,7 @@ class OpLowererImpl : public OpLowererImplBase { std::shared_ptr op_impl, ::pir::Operation* op, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, + std::unordered_map* tmp_tensor_info, std::vector* op_func_arg_tensors); /** @@ -147,6 +149,20 @@ class OpLowererImpl : public OpLowererImplBase { const std::vector& op_func_arg_tensors, const std::vector& lowered_funcs); + /** + * @brief Apply schedule on a group. + * @param ir_sch The IRSchedule containing the entire group's lowered func + * bodies. + * @param group The group to be scheduled. + * @param tensor_map All tensors used for calculating the group. + * @return The lowered func body after schedule of the group. + */ + ir::Expr DoGroupSchedule( + ir::IRSchedule& ir_sch, // NOLINT + const GroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info); + // Functions used to determine which Ops to schedule at op level, define a // policy for each type of group. inline bool ReduceScheduleDetermineFunction(::pir::Operation* op); @@ -157,7 +173,7 @@ class OpLowererImpl : public OpLowererImplBase { Target target_; }; -} // namespace newir +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_util.cc b/paddle/cinn/hlir/framework/pir/op_lowering_util.cc new file mode 100644 index 0000000000000..a9b14a215107a --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_lowering_util.cc @@ -0,0 +1,1585 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" + +#include +#include +#include "glog/logging.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/pe/nn_util.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/ir/utils/ir_nodes_collector.h" +#include "paddle/cinn/utils/string.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +::pir::Operation* FindGlobalReducer( + const std::vector<::pir::Operation*>& ops_in_order) { + for (auto& op : ops_in_order) { + if (CompatibleInfo::OpKind(*op) == framework::kReduction) { + return op; + } + } + return nullptr; +} + +std::vector<::pir::Operation*> GetConsumersInSet( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_set) { + std::vector<::pir::Operation*> consumers; + for (auto& out : op->results()) { + for (auto use_iter = out.use_begin(); use_iter != out.use_end(); + ++use_iter) { + ::pir::Operation* consumer = use_iter->owner(); + CHECK(consumer); + if (ops_set.count(consumer)) { + consumers.push_back(consumer); + } + } + } + return consumers; +} + +std::vector<::pir::Operation*> GetProducers(::pir::Operation* op) { + std::vector<::pir::Operation*> producers; + for (auto& source : op->operands_source()) { + auto* producer_op = source.dyn_cast<::pir::OpResult>().owner(); + CHECK(producer_op); + producers.push_back(producer_op); + } + return producers; +} + +std::vector<::pir::Operation*> GetProducersInSet( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_set) { + std::vector<::pir::Operation*> producers; + for (auto& producer_op : GetProducers(op)) { + CHECK(producer_op); + if (ops_set.count(producer_op)) { + producers.push_back(producer_op); + } + } + return producers; +} + +std::vector<::pir::Operation*> FindConsumers( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_set, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers) { + auto consumers = GetConsumersInSet(op, ops_set); + if (virtual_consumers.count(op)) { + consumers.push_back(virtual_consumers.find(op)->second); + } + return consumers; +} + +std::vector<::pir::Operation*> FindProducers( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_set, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers) { + auto producers = GetProducersInSet(op, ops_set); + for (const auto& iter : virtual_consumers) { + if (iter.second == op) { + producers.push_back(iter.first); + } + } + + return producers; +} + +using Visitor = std::function( + ::pir::Operation*, const std::unordered_set<::pir::Operation*>&)>; +::pir::Operation* FindReducerInRoute( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_set, + Visitor visitor) { + std::queue<::pir::Operation*> candidates; + candidates.push(op); + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : visitor(candidate, ops_set)) { + if (CompatibleInfo::OpKind(*consumer) == framework::kReduction) { + return consumer; + } + candidates.push(consumer); + } + } + + return nullptr; +} + +::pir::Operation* FindNearestReducer( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_set) { + // from consumers find reducer. + if (auto reducer = FindReducerInRoute(op, ops_set, GetConsumersInSet)) { + return reducer; + } else { + return FindReducerInRoute(op, ops_set, GetProducersInSet); + } +} + +std::unordered_map<::pir::Operation*, ::pir::Operation*> BuildVirtualConsumer( + const GroupPtr& group) { + std::unordered_map<::pir::Operation*, ::pir::Operation*> virtual_consumers; + std::unordered_set<::pir::Operation*> ops_set(group->ops.begin(), + group->ops.end()); + if (group->op_pattern_kind != framework::kReduction) { + return virtual_consumers; + } + + ::pir::Operation* e_op = nullptr; + ::pir::Operation* r_op = nullptr; + for (auto master_op : group->master_ops) { + if (CompatibleInfo::OpKind(*master_op) != framework::kReduction) { + // producer exits reduce-sum and not consumers. + if (!e_op && FindReducerInRoute(master_op, ops_set, GetProducersInSet) && + GetConsumersInSet(master_op, ops_set).size() == 0) { + e_op = master_op; + } + } else if (!r_op) { + r_op = master_op; + } + } + + // try to find reducer with different shape. + for (auto output_op : group->output_ops) { + if (CompatibleInfo::OpKind(*output_op) == framework::kReduction) { + if (isl_ast_expr_op_and_then) { + virtual_consumers[output_op] = e_op; + } + continue; + } + if (FindNearestReducer(output_op, ops_set)) { + continue; + } + + bool found = false; + std::unordered_set<::pir::Operation*> visited; + std::queue<::pir::Operation*> candidates; + + candidates.push(output_op); + visited.insert(output_op); + // from producers find reducer consumer. + while (!found && !candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto producer : GetProducersInSet(candidate, ops_set)) { + if (visited.count(producer)) { + continue; + } + + auto reducer = FindReducerInRoute(producer, ops_set, GetConsumersInSet); + if (reducer) { + virtual_consumers[output_op] = reducer; + found = true; + break; + } + candidates.push(producer); + visited.insert(producer); + } + } + + auto output_shape = CompatibleInfo::ValueShape(output_op->result(0)); + if (!found && output_op != e_op && e_op) { + auto e_output_shape = CompatibleInfo::ValueShape(e_op->result(0)); + if (CompatibleInfo::ShapeProduct(output_shape) == + CompatibleInfo::ShapeProduct(e_output_shape)) { + virtual_consumers[output_op] = e_op; + found = true; + } + } + if (!found && r_op) { + auto r_input_shape = CompatibleInfo::ValueShape(r_op->operand_source(0)); + if (CompatibleInfo::ShapeProduct(output_shape) == + CompatibleInfo::ShapeProduct(r_input_shape)) { + virtual_consumers[output_op] = r_op; + found = true; + } + } + } + // Establish virtual consumer relationships between output nodes with the same + // shape. This allows the calculation of output nodes without affiliation to + // be placed under the same loop. + std::unordered_map numel_consumers; + for (auto out_op : group->output_ops) { + if (virtual_consumers.find(out_op) != virtual_consumers.end() || + !GetConsumersInSet(out_op, ops_set).empty()) { + continue; + } + auto shape = CompatibleInfo::ValueShape(out_op->result(0)); + int numel = CompatibleInfo::ShapeProduct(shape); + if (numel_consumers.find(numel) == numel_consumers.end()) { + numel_consumers.insert(std::make_pair(numel, out_op)); + } else { + virtual_consumers[out_op] = numel_consumers[numel]; + } + } + return virtual_consumers; +} + +std::vector<::pir::Operation*> BFSTopologicalOrderWithPriority( + const GroupPtr& group, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers) { + struct OpWithPriority { + ::pir::Operation* op; + int priority; + }; + + struct Comparator { + bool operator()(const OpWithPriority& lhs, const OpWithPriority& rhs) { + return lhs.priority > rhs.priority; + } + }; + + std::vector<::pir::Operation*> ops_in_order; + std::unordered_set<::pir::Operation*> visited; + std::unordered_set<::pir::Operation*> ops_set(group->ops.begin(), + group->ops.end()); + std::unordered_map<::pir::Operation*, int> degree_map; + std::priority_queue, Comparator> + priority_candidates; + std::vector visited_numel; + + // Calculate the priority of a node. + // The smaller the value, the higher the priority. + // Prioritize the same shape before considering OpPattern + auto PriorityFunc = [&visited_numel](::pir::Operation* op) -> int { + auto op_shape = CompatibleInfo::ValueShape(op->result(0)); + int numel = CompatibleInfo::ShapeProduct(op_shape); + int index = -1; + for (int i = 0; i < visited_numel.size(); ++i) { + if (numel == visited_numel[i]) { + index = i; + break; + } + } + if (index == -1) { + index = visited_numel.size(); + visited_numel.push_back(numel); + } + return index * 10 + static_cast(CompatibleInfo::OpKind(*op)); + }; + + for (::pir::Operation* op : ops_set) { + auto consumers = FindConsumers(op, ops_set, virtual_consumers); + // Some nodes may have multiple edges between them, resulting in duplicates + // in the consumer. We only need to calculate once. + std::unordered_set<::pir::Operation*> consumers_without_duplicate( + consumers.begin(), consumers.end()); + degree_map[op] = consumers_without_duplicate.size(); + if (degree_map.at(op) == 0) { + priority_candidates.push(OpWithPriority{op, PriorityFunc(op)}); + } + } + + // Nested BFS, outer layer traverses priority, inner layer performs BFS on + // current priority. + while (!priority_candidates.empty()) { + ::pir::Operation* cur_priority_op = priority_candidates.top().op; + priority_candidates.pop(); + + std::queue<::pir::Operation*> bfs_queue; + bfs_queue.push(cur_priority_op); + visited.insert(cur_priority_op); + while (!bfs_queue.empty()) { + ::pir::Operation* cur = bfs_queue.front(); + bfs_queue.pop(); + + ops_in_order.push_back(cur); + auto producers = FindProducers(cur, ops_set, virtual_consumers); + std::unordered_set<::pir::Operation*> producers_without_duplicate( + producers.begin(), producers.end()); + for (::pir::Operation* op : producers_without_duplicate) { + --degree_map[op]; + // Ensure that each node is accessed only once and maintain topological + // order. + if (visited.count(op) != 0 || degree_map[op] != 0) { + continue; + } + // Perform BFS access to the current priority producers + int op_priority = PriorityFunc(op); + if (op_priority <= PriorityFunc(cur_priority_op)) { + bfs_queue.push(op); + visited.insert(op); + } else { + priority_candidates.push(OpWithPriority{op, op_priority}); + } + } + } + } + return ops_in_order; +} + +std::unordered_set<::pir::Operation*> GetMasters( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_inline, + const std::unordered_set<::pir::Operation*>& ops_set) { + // find consumer + std::unordered_set<::pir::Operation*> visited; + std::queue<::pir::Operation*> candidates; + candidates.push(op); + std::unordered_set<::pir::Operation*> masters; + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + auto consumers = GetConsumersInSet(candidate, ops_set); + for (auto consumer : consumers) { + if (visited.count(consumer)) { + continue; + } + if (ops_inline.count(consumer)) { + candidates.push(consumer); + visited.insert(consumer); + } else { + masters.insert(consumer); + } + } + } + + return masters; +} + +bool IsConstOp(const ::pir::Operation* op) { + static std::unordered_set const_op_type = { + "const_scalar", "fill_constant", "arange"}; + return const_op_type.count(CompatibleInfo::OpName(*op)); +} + +bool CanbeInline(::pir::Operation* op, + ::pir::Operation* reducer, + const std::vector<::pir::Operation*> consumers, + const std::unordered_set<::pir::Operation*> masters, + const GroupPtr& group, + const std::unordered_set<::pir::Operation*>& ops_set) { + if (group->output_ops.count(op)) { + return false; + } + for (auto consumer : consumers) { + if (CompatibleInfo::OpKind(*consumer) == framework::kReduction) { + return false; + } + } + + if (IsConstOp(op)) { + return true; + } + if (CompatibleInfo::OpKind(*op) == framework::kReduction) { + return false; + } + + if (consumers.size() == 1) { + return true; + } + auto op_shape = CompatibleInfo::ValueShape(op->result(0)); + if (reducer) { + // node is before reducer and node is not after reduce. + if (FindReducerInRoute(op, ops_set, GetConsumersInSet) && + !FindReducerInRoute(op, ops_set, GetProducersInSet)) { + auto input_shape = CompatibleInfo::ValueShape(reducer->result(0)); + // check with same shape with reducer input. + if (CompatibleInfo::ShapeProduct(op_shape) != + CompatibleInfo::ShapeProduct(input_shape)) { + return true; + } + } + + return false; + } else { + auto op_shape_size = CompatibleInfo::ShapeProduct(op_shape); + for (auto master : masters) { + auto master_shape = CompatibleInfo::ValueShape(master->result(0)); + auto master_size = CompatibleInfo::ShapeProduct(master_shape); + if (op_shape_size != master_size) { + return true; + } + } + + return false; + } +} + +::pir::Operation* GetMasterToComputeAt( + ::pir::Operation* op, + const std::vector<::pir::Operation*>& ops_in_order, + const std::unordered_set<::pir::Operation*>& ops_inline, + const std::unordered_set<::pir::Operation*>& ops_set, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers) { + // if node is reduction, try find horizontal to compute at. + if (CompatibleInfo::OpKind(*op) == framework::kReduction) { + // find all reduce node has done schedule. + std::unordered_set<::pir::Operation*> done_schedule; + for (auto tmp : ops_in_order) { + if (tmp == op) { + break; + } + if (CompatibleInfo::OpKind(*tmp) == framework::kReduction) { + done_schedule.insert(tmp); + } + } + // remove all consuemr reducer node of node from done_schedule. + std::unordered_set<::pir::Operation*> visited; + std::queue<::pir::Operation*> candidates; + candidates.push(op); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : GetConsumersInSet(candidate, ops_set)) { + // remove reduction node from done_schedule. + if (CompatibleInfo::OpKind(*consumer) == framework::kReduction) { + done_schedule.erase(consumer); + } + if (visited.count(consumer)) { + continue; + } + candidates.push(consumer); + visited.insert(consumer); + } + } + + if (done_schedule.size()) { + auto shape = CompatibleInfo::ValueShape(op->operand_source(0)); + for (auto r_op : done_schedule) { + auto rshape = CompatibleInfo::ValueShape(r_op->operand_source(0)); + if (shape == rshape) { + return r_op; + } + } + return *done_schedule.begin(); + } + } + + // collect all consumers. + std::unordered_set<::pir::Operation*> visited, masters; + std::queue<::pir::Operation*> candidates; + candidates.push(op); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + auto consumers = FindConsumers(candidate, ops_set, virtual_consumers); + for (auto consumer : consumers) { + if (visited.count(consumer)) { + continue; + } + if (ops_inline.count(consumer)) { + candidates.push(consumer); + visited.insert(consumer); + } else { + masters.insert(consumer); + } + } + } + + // nodes-in-order + for (int idx = 0; idx < ops_in_order.size(); ++idx) { + if (ops_in_order[idx] == op) { + for (int idy = idx - 1; idy >= 0; --idy) { + if (masters.count(ops_in_order[idy])) { + return ops_in_order[idy]; + } + } + break; + } + } + return nullptr; +} + +void LoopOrderAssignReduce(ir::IRSchedule& ir_sch, // NOLINT + const std::string& block_name, + const std::vector& axes, + const common::Target& target, + const bool just_reorder = false) { + // reorder none-last reduce axis to last. + // like: shape = [16,16,16,16,16],axes = [1,3] -> new order = [0, 2, 4, 1, 3]. + std::vector order; + int n_out_dims = ir_sch.GetLoops(block_name).size(); + for (int idx = 0; idx < n_out_dims; ++idx) { + if (std::find(axes.begin(), axes.end(), idx) == axes.end()) { + order.push_back(idx); + } + } + for (auto axis : axes) { + order.push_back(axis); + } + ir_sch.Reorder(ir_sch.GetBlock(block_name), order); + + if (just_reorder) { + return; + } + // fuse others none-reduce axis. + int last_dimension_num = n_out_dims - axes.back() - 1; + int index = n_out_dims - last_dimension_num - axes.size(); + + // fuse last_dimension_num - 1 times + for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { + ir_sch.Fuse(block_name, {index, index + 1}); + } + + auto loops = ir_sch.GetLoops(block_name); + auto psize = ir::GetLoopExtent(loops[index]); + + if (psize > target.max_num_threads()) { + for (int idx = target.max_num_threads(); idx > 0; --idx) { + if (psize % idx == 0) { + ir_sch.Split(loops[index], {-1, idx}); + break; + } + CHECK_GT(idx, 1); + } + } + + // fuse index - 1 times + for (int idx = 0; idx < index - 1; ++idx) { + ir_sch.Fuse(block_name, {0, 1}); + } +} + +void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT + const std::string& block_name, + const std::vector& inshape, + const std::vector& axes, + const common::Target& target) { + // If the number of current device SM is smaller than the number of SM + // required by Warp Reduce, the performance of Warp Reduce is better. + // Otherwise, use Block Reduce. + auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int need_reduce_last_count = 1; + for (int i = 0; i < inshape.size(); i++) { + if (find(axes.begin(), axes.end(), i) == axes.end()) { + need_reduce_last_count *= inshape[i]; + } + } + int warp_reduce_need_sm_count = + ceil((need_reduce_last_count * 32) / + static_cast(target.get_max_threads_per_sm())); + // Set Num_max_threads to 32 is Warp Reduce + if (target.get_multi_processor_count() < warp_reduce_need_sm_count) { + max_num_threads = 32; + } + // find first reduce and second reduce axis. + int lane = 1; + int index = static_cast(axes.size()) - 1; + + for (; index >= 0; --index) { + if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (index == 0 && lane <= max_num_threads) { + LOG(FATAL) + << "Error! lane is less equal than max_num_threads, Please check!"; + } + if (lane >= max_num_threads / 2) { + if (lane <= max_num_threads) { + --index; + } + break; + } + } + std::vector first_axes(axes.begin(), axes.begin() + index + 1); + if (lane > max_num_threads) { + // last reduce axis size > 1024 + if (index == static_cast(axes.size()) - 1) { + int tail = max_num_threads; + bool check_bound = true; + for (; tail >= max_num_threads / 2; --tail) { + if (lane % tail == 0) { + check_bound = false; + break; + } + } + if (check_bound) { + lane = + ((lane + max_num_threads - 1) / max_num_threads) * max_num_threads; + ir_sch.Split(block_name, axes[index], {lane}); + } + int idx = max_num_threads; + do { + if (lane % idx == 0) { + ir_sch.Split(block_name, axes[index], {-1, idx}); + break; + } + --idx; + } while (idx >= max_num_threads / 2); + // if can't be divide by(1024, 512), it's shouldn't be fused. + CHECK_GE(idx, max_num_threads / 2) << "Check bounds exist, can't fuse!"; + } else { + int axis = axes[index]; + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; + --idx) { + if (prefix % idx == 0) { + ir_sch.Split(block_name, axis, {-1, idx}); + break; + } + CHECK_GT(idx, (max_num_threads / 2) / tail) + << "Error, it's shouldn't fuse!"; + } + } + LoopOrderAssignReduce(ir_sch, block_name, first_axes, target); + // The current one-dimensional reduce does not make full use of SM. + // This case is optimized into a two-dimensional. + auto loops = ir_sch.GetLoops(block_name); + auto block_dim_x = loops[1].As()->extent.as_int32(); + int block_dim_y = block_dim_x <= 32 ? 2 : 1; + if (block_dim_y != 1) { + ir_sch.Split(loops[0], {-1, block_dim_y}); + } + } else { + int fuse_times = axes.size() - (index + 1) - 1; + for (int idx = 0; idx < fuse_times; ++idx) { + ir_sch.Fuse(block_name, {axes[index + 1], axes[index + 1] + 1}); + } + LoopOrderAssignReduce(ir_sch, block_name, first_axes, target, true); + // fuse axis before reduce to bind blockidx. + for (int idx = 0; idx < static_cast(inshape.size() - axes.size()) - 1; + ++idx) { + ir_sch.Fuse(block_name, {0, 1}); + } + } +} + +bool WithoutLastDimInReduce(const std::vector& shape, + const std::vector& axes) { + if (axes.empty()) { + return false; + } + // if last axis is in reduce. + if (std::find(axes.begin(), axes.end(), shape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int sum_last_axes = 1; + for (int idx = axes.back() + 1; idx < shape.size(); ++idx) { + sum_last_axes *= shape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +void LoopAssignReduceWithoutLast(ir::IRSchedule& ir_sch, // NOLINT + const std::string& block_name, + const std::vector& inshape, + const std::vector& axes, + const common::Target& target) { + int tail = 0; + bool bound = true; + auto shape = pe::GetFirstStepReduceShape(inshape, axes, bound, tail); + CHECK(bound) << std::accumulate(inshape.begin(), + inshape.end(), + std::string(""), + [](const std::string& left, const int right) { + return left + std::to_string(right) + " "; + }); + + VLOG(4) << "LoopAssignReduceWithoutLast: THe input shape=[" + << cinn::utils::Join(inshape, ", ") << "], first step reduce shape=[" + << cinn::utils::Join(shape, ", ") << "]" + << ", axes=[" << cinn::utils::Join(axes, ", ") << "], tail=" << tail; + + // remove loop size = 1 and remove axis in axes. + std::vector nshape, axes_shift_num(axes.size(), 0); + for (int idx = 0; idx < shape.size(); ++idx) { + if (shape[idx] == 1 && idx < axes.back()) { + for (int j = 0; j < axes.size(); ++j) { + if (axes[j] == idx) { + // the loop size at axis is 1, need remove + axes_shift_num[j] = -1; + } else if (axes[j] > idx) { + // the axies value need left shift + axes_shift_num[j]++; + } + } + } else { + nshape.push_back(shape[idx]); + } + } + + // remove loop size - 1 axes + std::vector naxes; + for (int i = 0; i < axes_shift_num.size(); ++i) { + if (axes_shift_num[i] != -1) { + // the axis do not need remove, but need left shift + naxes.emplace_back(axes[i] - axes_shift_num[i]); + } + } + + // fuse tail for bind threadIdx.x + int ptail = 1; + int index = naxes.back() + 2; + for (int idx = index; idx < nshape.size(); ++idx) { + ptail *= nshape[idx]; + } + nshape.resize(index); + nshape.push_back(ptail); + + ir_sch.Split(block_name, 0, nshape); + LoopOrderAssignReduce(ir_sch, block_name, naxes, target, true); + + // fuse loop for bind blockIdx.x + auto loops = ir_sch.GetLoops(block_name); + auto fsize = nshape.size() - (naxes.size() + 2); + if (fsize > 1) { + ir_sch.Fuse({loops.begin(), loops.begin() + fsize}); + } + + auto get_tile_size = [&](int idx) { + auto range = GetLoopExtent(loops[idx - 1]); + if (range > 32) { + return 8; + } else if (range > 16) { + return 16; + } else if (range > 4) { + return 32; + } else { + return 64; + } + }; + + std::vector new_order; + loops = ir_sch.GetLoops(block_name); + if (fsize) { + int tail_index = 2; + auto tile_size = get_tile_size(tail_index); + if (GetLoopExtent(loops[tail_index]) > tile_size) { + // split index + ir_sch.Split(loops[tail_index], {-1, tile_size}); + loops = ir_sch.GetLoops(block_name); + // order + new_order = {0, 2, 3, 1}; + } else { + // order + new_order = {0, 2, 1}; + } + } else { + int tail_index = 1; + auto tile_size = get_tile_size(tail_index); + if (GetLoopExtent(loops[tail_index]) > tile_size) { + // split index + ir_sch.Split(loops[tail_index], {-1, tile_size}); + loops = ir_sch.GetLoops(block_name); + // order + new_order = {1, 2, 0}; + } else { + // order + new_order = {1, 0}; + } + } + for (int idx = new_order.size(); idx < loops.size(); ++idx) { + new_order.push_back(idx); + } + ir_sch.Reorder(block_name, new_order); +} + +std::vector GetReducerDimAttr(::pir::Operation* reduce_op) { + int rank = reduce_op->operand_source(0) + .type() + .dyn_cast<::pir::DenseTensorType>() + .dims() + .size(); + + auto attr = reduce_op->attributes().at("dim"); + auto attr_vec = attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); + + std::vector dim; + for (auto vec_element : attr_vec) { + auto axis = vec_element.dyn_cast<::pir::Int64Attribute>().data(); + if (axis < 0) { + axis += rank; + } + dim.push_back(axis); + } + return dim; +} + +class InsertExpr : public ir::IRMutator<> { + public: + InsertExpr(Expr& target, Expr& anchor) : target_(target), anchor_(anchor) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } + + void Visit(const ir::For* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } + + void Visit(const ir::Block* expr, Expr* op) override { + auto* node = op->As(); + auto iter = std::find(node->stmts.begin(), node->stmts.end(), anchor_); + if (iter != node->stmts.end()) { + node->stmts.insert(iter, target_); + } else { + for (auto stmt : node->stmts) { + IRMutator::Visit(&stmt, &stmt); + } + } + } + + private: + Expr target_; + Expr anchor_; +}; + +class RemoveExpr : public ir::IRMutator<> { + public: + explicit RemoveExpr(const Expr& target) : target_(target) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } + + void Visit(const ir::For* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } + + void Visit(const ir::Block* expr, Expr* op) override { + auto* node = op->As(); + auto iter = std::find(node->stmts.begin(), node->stmts.end(), target_); + if (iter != node->stmts.end()) { + node->stmts.erase(iter); + } else { + for (auto stmt : node->stmts) { + IRMutator::Visit(&stmt, &stmt); + } + } + } + + private: + const Expr& target_; +}; + +void MergeLoops(ir::Expr root, + std::vector& src, // NOLINT + std::vector& dst, // NOLINT + int index) { + if (index < 0) { + return; + } + CHECK_GT(src.size(), index) << "\nindex -> " << index << "\n" << src[0]; + CHECK_GT(dst.size(), index) << "\nindex -> " << index << "\n" << dst[0]; + + if (src[0] == dst[0]) { + return; + } + + std::vector src_vars; + std::vector dst_vars; + for (int idx = 0; idx <= index; ++idx) { + src_vars.push_back(src[idx].As()->loop_var); + dst_vars.push_back(ir::Expr(dst[idx].As()->loop_var)); + } + + auto src_body = src[index].As()->body; + ReplaceExpr(&src_body, src_vars, dst_vars); + dst[index].As()->body = + ir::Block::Make({src_body, dst[index].As()->body}); + + RemoveExpr remove_expr(src[0]); + remove_expr(&root); +} + +void MergeReduceToReduce( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + ::pir::Operation* master, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info) { + VLOG(3) << "start to MergeReduceToReduce..."; + auto op_out_name = CompatibleInfo::ValueName(op->result(0)); + auto master_out_name = CompatibleInfo::ValueName(master->result(0)); + auto shape = CompatibleInfo::ValueShape(op->operand_source(0)); + + std::vector axes = GetReducerDimAttr(master); + if (axes.empty()) { + for (int idx = 0; idx < shape.size(); idx++) { + axes.push_back(idx); + } + } + if (WithoutLastDimInReduce(shape, axes)) { + auto mshape = CompatibleInfo::ValueShape(master->operand_source(0)); + if (tmp_tensor_info.count(op_out_name + "_1")) { + if (shape == mshape) { + // second step reduce + { + auto block = ir_sch.GetBlock(op_out_name); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(op_out_name + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_out_name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + // first step reduce + { + auto n_tensor = tmp_tensor_info.at(op_out_name + "_0"); + auto m_tensor = tmp_tensor_info.at(master_out_name + "_0"); + + auto block = ir_sch.GetBlock(n_tensor->name); + auto loops = ir_sch.GetLoops(m_tensor->name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(n_tensor->name + "__reduce_init"); + auto loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + } else { + auto n_tensor = tmp_tensor_info.at(op_out_name + "_0"); + auto m_tensor = tmp_tensor_info.at(master_out_name + "_0"); + if (n_tensor->shape == m_tensor->shape) { + // second step reduce + { + auto block = ir_sch.GetBlock(op_out_name); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(op_out_name + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_out_name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + // first step reduce + { + auto n_tensor = tmp_tensor_info.at(op_out_name + "_0"); + auto m_tensor = tmp_tensor_info.at(master_out_name + "_0"); + + auto n_loops = ir_sch.GetLoops(n_tensor->name + "__reduce_init"); + auto m_loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); + + CHECK_EQ(n_loops.size(), m_loops.size()); + MergeLoops(ir_sch.GetModule().GetExprs().at(0), + n_loops, + m_loops, + n_loops.size() - 1); + } + } else { + LOG(FATAL) << "not support this type fusion!"; + } + } + } else { + if (shape == mshape) { + // reduce loop + { + auto block = ir_sch.GetBlock(op_out_name); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(op_out_name + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_out_name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + } else { + // reduce loop + { + auto block = ir_sch.GetBlock(op_out_name); + auto nloops = ir_sch.GetLoops(op_out_name); + auto mloops = ir_sch.GetLoops(master_out_name); + for (int idx = 0; idx < mloops.size(); ++idx) { + if (GetLoopExtent(nloops[idx]) != GetLoopExtent(mloops[idx])) { + ir_sch.SimpleComputeAt(block, mloops[idx - 1]); + break; + } + } + // reduce init + { + auto block = ir_sch.GetBlock(op_out_name + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_out_name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + } + } + } else { + if (tmp_tensor_info.count(op_out_name + "_1")) { + // identity + { + auto block = ir_sch.GetBlock(op_out_name); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // reduce + { + auto n_tensor = tmp_tensor_info.at(op_out_name + "_1"); + auto m_tensor = tmp_tensor_info.at(master_out_name + "_1"); + + auto block = ir_sch.GetBlock(n_tensor->name); + auto loops = ir_sch.GetLoops(m_tensor->name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(n_tensor->name + "__reduce_init"); + auto loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + // block shuffle + { + auto n_tensor = tmp_tensor_info.at(op_out_name + "_0"); + auto m_tensor = tmp_tensor_info.at(master_out_name + "_0"); + + auto n_block = ir_sch.GetBlock(n_tensor->name); + auto m_block = ir_sch.GetBlock(m_tensor->name); + + auto n_loops = ir_sch.GetLoops(n_tensor->name); + auto m_loops = ir_sch.GetLoops(m_tensor->name); + CHECK_EQ(n_loops.size(), m_loops.size()); + + std::vector src_vars; + std::vector dst_vars; + for (int idx = 0; idx < m_loops.size(); ++idx) { + src_vars.push_back(n_loops[idx].As()->loop_var); + dst_vars.push_back(ir::Expr(m_loops[idx].As()->loop_var)); + } + ReplaceExpr(&n_block, src_vars, dst_vars); + + InsertExpr insert_expr(n_block, m_block); + insert_expr(&m_loops.back()); + + RemoveExpr remove_expr(n_loops[0]); + remove_expr(&ir_sch.GetModule().GetExprs().at(0)); + } + } else if (tmp_tensor_info.count(op_out_name + "_0")) { + // identity + { + auto block = ir_sch.GetBlock(op_out_name); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // shuffle reduce + { + auto n_tensor = tmp_tensor_info.at(op_out_name + "_0"); + auto m_tensor = tmp_tensor_info.at(master_out_name + "_0"); + + auto block = ir_sch.GetBlock(n_tensor->name); + auto loops = ir_sch.GetLoops(m_tensor->name); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } else { + LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; + } + } +} + +void InsertSyncThread( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info) { + auto shape = CompatibleInfo::ValueShape(op->operand_source(0)); + auto axes = GetReducerDimAttr(op); + if (axes.empty()) { + for (int idx = 0; idx < shape.size(); idx++) { + axes.push_back(idx); + } + } + if (!WithoutLastDimInReduce(shape, axes)) { + return; + } + + auto op_out_name = CompatibleInfo::ValueName(op->result(0)); + std::string post = ""; + for (int idx = 0;; ++idx) { + if (!tmp_tensor_info.count(op_out_name + post)) { + break; + } + auto tensor = tmp_tensor_info.at(op_out_name + post); + if (!ir_sch.HasBlock(tensor->name)) { + break; + } + + post = "_" + std::to_string(idx); + if (idx > 0) { + // insert syncthreads. + auto loops = ir_sch.GetLoops(op_out_name); + ir_sch.SyncThreads(loops[loops.size() - 2], false); + return; + } + } +} + +void MergeReduceLoop( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + ::pir::Operation* master, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info) { + VLOG(3) << "start to MergeReduceLoop..."; + if (CompatibleInfo::OpKind(*master) == kReduction && op != master) { + MergeReduceToReduce(ir_sch, op, master, tensor_map, tmp_tensor_info); + return; + } + + auto op_out_name = CompatibleInfo::ValueName(op->result(0)); + auto master_out_name = CompatibleInfo::ValueName(master->result(0)); + int min_index_loop = INT_MAX; + std::string post_ = "", post__ = "_0"; + for (int idx = 0;; ++idx) { + if (!tmp_tensor_info.count(op_out_name + post__)) { + break; + } + auto tensor_ = tmp_tensor_info.at(op_out_name + post_); + auto tensor__ = tmp_tensor_info.at(op_out_name + post__); + if (!ir_sch.HasBlock(tensor__->name)) { + break; + } + auto dst_loops = ir_sch.GetLoops(tensor_->name); + auto src_loops = ir_sch.GetLoops(tensor__->name); + int index = -1; + while (src_loops[index + 1].As()->extent.as_int32() == + dst_loops[index + 1].As()->extent.as_int32()) { + ++index; + if (src_loops.size() == index + 1 || dst_loops.size() == index + 1) { + break; + } + } + min_index_loop = std::min(min_index_loop, index); + MergeLoops( + ir_sch.GetModule().GetExprs().at(0), src_loops, dst_loops, index); + post_ = "_" + std::to_string(idx); + post__ = "_" + std::to_string(idx + 1); + } + InsertSyncThread(ir_sch, op, tensor_map, tmp_tensor_info); + + if (op == master) return; + auto node_loops = ir_sch.GetLoops(op_out_name); + auto master_loops = ir_sch.GetLoops(master_out_name); + + int index = std::min(node_loops.size(), master_loops.size()) - 1; + do { + // if loop range is not equal. + if (node_loops[index].As()->extent.as_int32() != + master_loops[index].As()->extent.as_int32()) { + continue; + } + + MergeLoops(ir_sch.GetModule().GetExprs().at(0), + node_loops, + master_loops, + std::min(index, min_index_loop)); + if (index > min_index_loop) { + auto block = ir_sch.GetBlock(op_out_name); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + + if (ir_sch.HasBlock(op_out_name + "__reduce_init")) { + auto block = ir_sch.GetBlock(op_out_name + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_out_name); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + + break; + } while (--index >= 0); +} + +void LoopComputeAt( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + ::pir::Operation* master, + const GroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info) { + auto op_out_name = CompatibleInfo::ValueName(op->result(0)); + if (!group->output_ops.count(op)) { + auto block = ir_sch.GetBlock(op_out_name); + ir_sch.SetBuffer(block, "local"); + } + + if (CompatibleInfo::OpKind(*op) == framework::kReduction) { + MergeReduceLoop(ir_sch, op, master, tensor_map, tmp_tensor_info); + return; + } + + if (op == master) return; + auto master_out = master->result(0); + auto master_out_name = CompatibleInfo::ValueName(master_out); + + auto node_loops = ir_sch.GetLoops(op_out_name); + auto master_loops = ir_sch.GetLoops(master_out_name); + + if (CompatibleInfo::OpKind(*master) == framework::kReduction) { + // find real master loops. + std::string prefix = "", post = ""; + for (int idx = 0;; ++idx) { + if (!tmp_tensor_info.count(master_out_name + post)) { + break; + } + auto tensor = tmp_tensor_info.at(master_out_name + post); + if (!ir_sch.HasBlock(tensor->name)) { + break; + } + + prefix = post; + post = "_" + std::to_string(idx); + } + auto tensor = tmp_tensor_info.at(master_out_name + prefix); + master_loops = ir_sch.GetLoops(tensor->name); + } + + int index = std::min(node_loops.size(), master_loops.size()) - 1; + do { + // if loop range is not equal. + if (node_loops[index].As()->extent.as_int32() != + master_loops[index].As()->extent.as_int32()) { + continue; + } + MergeLoops( + ir_sch.GetModule().GetExprs().at(0), node_loops, master_loops, index); + + break; + } while (--index >= 0); +} + +void LoopAssignReduce( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + ::pir::Operation* reducer, + const Target& target, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info) { + // if node is reducer, return. + if (CompatibleInfo::OpKind(*op) == framework::kReduction) { + return; + } + ::pir::Value op_data = op->result(0); + ::pir::Value reducer_data = reducer->result(0); + std::string op_data_name = CompatibleInfo::ValueName(op_data); + std::string reducer_data_name = CompatibleInfo::ValueName(reducer_data); + + // flatten loops. + auto loops = ir_sch.GetLoops(op_data_name); + // do loop flatten. + if (CompatibleInfo::OpKind(*op) == framework::kElementWise) { + ir_sch.FlattenLoops(loops, true); + } else { + ir_sch.FlattenLoops(loops, false); + } + std::vector shape = + CompatibleInfo::ValueShape(reducer->operand_source(0)); + auto axes = GetReducerDimAttr(reducer); + if (axes.empty()) { + for (int idx = 0; idx < shape.size(); idx++) { + axes.push_back(idx); + } + } + auto copy_loop_info = [](std::vector& loops, + std::vector& rloops) { + for (int idx = 0; idx < std::min(rloops.size(), loops.size()); ++idx) { + auto l0 = rloops[idx].As(); + auto l1 = loops[idx].As(); + l1->set_for_type(l0->for_type()); + l1->set_bind_info(l0->bind_info()); + } + }; + std::vector op_shape = CompatibleInfo::ValueShape(op_data); + // The output shape of node is different from that of reduce node + if (CompatibleInfo::ShapeProduct(shape) != + CompatibleInfo::ShapeProduct(op_shape)) { + // get loop factors of reduce node + int extend = 1; + std::vector factors; + loops = ir_sch.GetLoops(op_data_name); + auto rloops = ir_sch.GetLoops(reducer_data_name); + + for (auto& loop : rloops) { + if (extend >= loops.back().As()->extent.as_int32() && + factors.size() && loop.As()->extent.as_int32() > 1) { + break; + } + extend *= loop.As()->extent.as_int32(); + factors.push_back(loop.As()->extent.as_int32()); + } + + // If there are IfThenElse stmt in loop, we need to find out the indices in + // condition, and special treatment should be applied to loops with these + // indices. We apply two step split on loop of src node to align the loop of + // reduce node. + std::unordered_set loop_index_in_if; + auto first_reduce_loop = rloops.front(); + // collect if + auto if_checker = [](const Expr* x) { return x->As(); }; + auto if_set = ir::ir_utils::CollectIRNodesWithoutTensor( + first_reduce_loop.As()->body, if_checker); + const std::string& reduce_block_name = reducer_data_name; + for (auto if_expr : if_set) { + auto checker = [reduce_block_name](const Expr* x) { + return x->As() && + x->As() + ->schedule_block.As() + ->name == reduce_block_name; + }; + auto blocks_in_if = + ir::ir_utils::CollectIRNodesWithoutTensor(if_expr, checker); + if (!blocks_in_if.empty()) { + ir::Expr condition = if_expr.As()->condition; + auto indices_in_if = ir::ir_utils::CollectIRNodesWithoutTensor( + condition, [](const Expr* x) { return x->As(); }); + for (int i = 0; i < rloops.size(); ++i) { + std::string var_name = rloops[i].As()->loop_var->name; + auto find_var_iter = + std::find_if(indices_in_if.begin(), + indices_in_if.end(), + [&var_name](const ir::Expr& x) { + return x.As()->name == var_name; + }); + if (find_var_iter != indices_in_if.end()) { + loop_index_in_if.insert(i); + } + } + break; + } + } + // prepare factors of two step split + std::vector first_step_factors; + std::vector second_step_factors; + int second_start_loop_index; + for (int i = 0; i < factors.size(); ++i) { + if (loop_index_in_if.count(i) == 0) { + first_step_factors.push_back(factors[i]); + } else if (loop_index_in_if.count(i) != 0 && + second_step_factors.empty()) { + first_step_factors.push_back(-1); + second_step_factors.push_back(factors[i]); + second_start_loop_index = i; + } else if (loop_index_in_if.count(i) != 0 && + !second_step_factors.empty()) { + second_step_factors.push_back(factors[i]); + } + } + // do two step split + if (!first_step_factors.empty()) { + ir_sch.Split(loops.back(), first_step_factors); + loops = ir_sch.GetLoops(op_data_name); + } + if (!second_step_factors.empty()) { + ir_sch.Split(loops.at(second_start_loop_index), second_step_factors); + loops = ir_sch.GetLoops(op_data_name); + } + + // copy loop info form rloops. + copy_loop_info(loops, rloops); + return; + } + // node output is same shape with reduce input. + if (WithoutLastDimInReduce(shape, axes)) { + // if using two strep reduce. + if (tmp_tensor_info.count(reducer_data_name + "_1")) { + VLOG(4) << "Try assign loop of " << op_data_name + << " into two strep reduce loop of " << reducer_data_name; + LoopAssignReduceWithoutLast(ir_sch, op_data_name, shape, axes, target); + auto nloops = ir_sch.GetLoops(op_data_name); + auto rloops = + ir_sch.GetLoops(tmp_tensor_info.at(reducer_data_name + "_0")->name); + + VLOG(4) << op_data_name << "'s loop level is " << nloops.size() + << ", and " << reducer_data_name << "'s loop level is " + << rloops.size(); + if (nloops.size() < rloops.size()) { + ir_sch.Split(nloops[0], {1, -1}); + } + + nloops = ir_sch.GetLoops(op_data_name); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } else { + VLOG(4) << "Try assign loop of " << op_data_name + << " into reduce loop of " << reducer_data_name; + + auto nloops = ir_sch.GetLoops(op_data_name); + ir_sch.Split(nloops.back(), shape); + LoopOrderAssignReduce(ir_sch, op_data_name, axes, target); + nloops = ir_sch.GetLoops(op_data_name); + auto rloops = + ir_sch.GetLoops(tensor_map.find(reducer_data)->second->name); + if (nloops.size() < rloops.size()) { + ir_sch.Split(nloops[0], {1, -1}); + } + + nloops = ir_sch.GetLoops(op_data_name); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } + } else { + if (tmp_tensor_info.count(reducer_data_name + "_1")) { + { + auto nloops = ir_sch.GetLoops(op_data_name); + ir_sch.Split(nloops.back(), shape); + } + LoopAssignReduceWithLast(ir_sch, op_data_name, shape, axes, target); + + auto nloops = ir_sch.GetLoops(op_data_name); + auto rloops = + ir_sch.GetLoops(tmp_tensor_info.at(reducer_data_name + "_1")->name); + if (nloops.size() < rloops.size()) { + ir_sch.Split(nloops[0], {1, -1}); + } + + nloops = ir_sch.GetLoops(op_data_name); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } else if (tmp_tensor_info.count(reducer_data_name + "_0")) { + auto tensor = tmp_tensor_info.at(reducer_data_name + "_0"); + auto rloops = ir_sch.GetLoops(tensor->name); + std::vector factors; + for (auto& loop : rloops) { + // FIXME(Aurelius84): Need add broadcast_to Op + int factor = loop.As()->extent.as_int32(); + if (factor == 1) { + factor = -1; + } + factors.push_back(factor); + } + auto nloops = ir_sch.GetLoops(op_data_name); + ir_sch.Split(nloops.back(), factors); + + nloops = ir_sch.GetLoops(op_data_name); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } else { + LOG(FATAL) << "Error! Unkown Reduce Type!"; + } + } +} + +std::unordered_map GetOutValueSet( + const std::unordered_set<::pir::Operation*>& ops_set) { + std::unordered_map out_value_set; + for (auto* op : ops_set) { + out_value_set[CompatibleInfo::ValueName(op->result(0))] = op->result(0); + } + return out_value_set; +} + +void SyncThreadWithShared( + ir::IRSchedule& ir_sch, // NOLINT + const GroupPtr& group, + const std::unordered_set<::pir::Operation*>& ops_inline, + const std::unordered_set<::pir::Operation*>& ops_set, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map) { + auto exprs_inorder = ir_sch.GetAllBlocks(); + auto op_out_set = GetOutValueSet(ops_set); + + std::unordered_set sync_mark; + auto check_sync_mark = [&](const int start, const std::string& m_id) { + for (int idx = start + 1; exprs_inorder.size(); ++idx) { + auto expr = exprs_inorder[idx]; + CHECK(expr.As()); + CHECK(expr.As() + ->schedule_block.As()); + auto block = expr.As() + ->schedule_block.As(); + + if (sync_mark.count(block->name)) { + return false; + } + + if (block->name == m_id) { + return true; + } + } + return false; + }; + + for (int idx = 0; idx < exprs_inorder.size() - 1; ++idx) { + auto expr = exprs_inorder[idx]; + CHECK(expr.As()); + CHECK(expr.As() + ->schedule_block.As()); + auto block = expr.As() + ->schedule_block.As(); + + if (!op_out_set.count(block->name)) { + continue; + } + auto op_data = op_out_set.find(block->name)->second; + auto* op = op_data.dyn_cast<::pir::OpResult>().owner(); + auto op_shape = CompatibleInfo::ValueShape(op_data); + + auto masters = GetMasters(op, ops_inline, ops_set); + if (masters.empty()) { + continue; + } + + bool do_set_buffer_to_shared = false; + for (auto master : masters) { + auto master_data = master->result(0); + auto master_shape = CompatibleInfo::ValueShape(master_data); + if (CompatibleInfo::OpKind(*master) == framework::kReduction) { + master_shape = CompatibleInfo::ValueShape(master->operand_source(0)); + } + + auto op_shape_size = CompatibleInfo::ShapeProduct(op_shape); + auto master_shape_size = CompatibleInfo::ShapeProduct(master_shape); + std::string master_data_name = CompatibleInfo::ValueName(master_data); + if (op_shape_size != master_shape_size) { + if (check_sync_mark(idx, master_data_name)) { + auto loops = ir_sch.GetLoops(master_data_name); + ir_sch.SyncThreads(loops.back(), false); + sync_mark.insert(master_data_name); + } + do_set_buffer_to_shared = true; + } + } + if (do_set_buffer_to_shared && + group->output_ops.find(op) == group->output_ops.end()) { + auto block = ir_sch.GetBlock(CompatibleInfo::ValueName(op_data)); + ir_sch.SetBuffer(block, "shared"); + } + } +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_util.h b/paddle/cinn/hlir/framework/pir/op_lowering_util.h new file mode 100644 index 0000000000000..784c57f9feea1 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_lowering_util.h @@ -0,0 +1,104 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/tensor.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +using GroupPtr = std::shared_ptr; + +std::unordered_map<::pir::Operation*, ::pir::Operation*> BuildVirtualConsumer( + const GroupPtr& group); + +std::vector<::pir::Value*> GetAllNodeData(::pir::Operation* op); + +std::vector<::pir::Operation*> GetConsumers(::pir::Operation* op); + +bool IsConstOp(const ::pir::Operation* op); + +std::vector<::pir::Operation*> GetConsumersInSet( + ::pir::Operation* op, const std::unordered_set<::pir::Operation*>& ops); + +std::vector<::pir::Operation*> TopologicalOrder( + const GroupPtr& group, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers); + +std::vector<::pir::Operation*> BFSTopologicalOrderWithPriority( + const GroupPtr& group, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers); + +::pir::Operation* FindGlobalReducer( + const std::vector<::pir::Operation*>& ops_in_order); + +::pir::Operation* FindNearestReducer( + ::pir::Operation* op, const std::unordered_set<::pir::Operation*>& ops_set); + +bool CanbeInline(::pir::Operation* op, + ::pir::Operation* reducer, + const std::vector<::pir::Operation*> consumers, + const std::unordered_set<::pir::Operation*> masters, + const GroupPtr& group, + const std::unordered_set<::pir::Operation*>& ops_set); + +::pir::Operation* GetMasterToComputeAt( + ::pir::Operation* op, + const std::vector<::pir::Operation*>& ops_in_order, + const std::unordered_set<::pir::Operation*>& ops_inline, + const std::unordered_set<::pir::Operation*>& ops_set, + const std::unordered_map<::pir::Operation*, ::pir::Operation*>& + virtual_consumers); + +std::unordered_set<::pir::Operation*> GetMasters( + ::pir::Operation* op, + const std::unordered_set<::pir::Operation*>& ops_inline, + const std::unordered_set<::pir::Operation*>& ops_set); + +void LoopAssignReduce( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + ::pir::Operation* reducer, + const Target& target, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info); + +void LoopComputeAt( + ir::IRSchedule& ir_sch, // NOLINT + ::pir::Operation* op, + ::pir::Operation* master, + const GroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info); + +void SyncThreadWithShared( + ir::IRSchedule& ir_sch, // NOLINT + const GroupPtr& group, + const std::unordered_set<::pir::Operation*>& ops_inline, + const std::unordered_set<::pir::Operation*>& ops_set, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map); + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_mapper.cc b/paddle/cinn/hlir/framework/pir/op_mapper.cc new file mode 100644 index 0000000000000..e49f0df2e3c54 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_mapper.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/op_mapper.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +namespace { + +void AppendAttrForReduceOp(const ::pir::Operation& op, + utils::AttributeMap& attrs) { // NOLINT + auto attr = op.attributes().at("dim"); + auto attr_vec = attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); + + std::vector dim; + for (auto vec_element : attr_vec) { + dim.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data()); + } + + attrs["dim"] = dim; +} + +void AppendAttrForBoadcastToOp(const ::pir::Operation& op, + utils::AttributeMap& attrs) { // NOLINT + auto axes_attr = op.attributes().at("broadcast_axes"); + auto attr_vec = axes_attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); + + std::vector axis; + for (auto vec_element : attr_vec) { + axis.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data()); + } + + attrs["broadcast_axes"] = axis; + + auto out_shape_attr = op.attributes().at("out_shape"); + auto out_shape_attr_vec = + out_shape_attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); + + std::vector out_shape; + for (auto vec_element : out_shape_attr_vec) { + out_shape.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data()); + } + + attrs["out_shape"] = out_shape; +} + +} // namespace + +#define REGISTER_OPERAND_RULE(OP, args...) \ + operand_funcs_[paddle::dialect::OP::name()] = []() -> std::vector { \ + return {args}; \ + }; + +#define REGISTER_ATTR_RULE(OP, func) \ + attr_funcs_[cinn::dialect::OP::name()] = func; + +void OpMapper::RegisterMapRules() { + // max(x, dim) -> reduce_max(x) + REGISTER_OPERAND_RULE(MaxOp, 0); + REGISTER_OPERAND_RULE(SumOp, 0); + REGISTER_OPERAND_RULE(MinOp, 0); + REGISTER_OPERAND_RULE(ProdOp, 0); + REGISTER_ATTR_RULE(ReduceMaxOp, AppendAttrForReduceOp); + REGISTER_ATTR_RULE(ReduceSumOp, AppendAttrForReduceOp); + REGISTER_ATTR_RULE(BroadcastOp, AppendAttrForBoadcastToOp); +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_mapper.h b/paddle/cinn/hlir/framework/pir/op_mapper.h new file mode 100644 index 0000000000000..0a0527cf9abf1 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_mapper.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/cinn/utils/type_defs.h" +#include "paddle/pir/core/operation.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +enum MapperType { + OPERAND, + ATTRIBUTE, +}; + +class OpMapper { + using OprandIndexsFunction = std::function()>; + using AppendAttrFunction = + std::function; // NOLINT + + public: + static OpMapper& Instance() { + static OpMapper instance; + return instance; + } + + bool has(const ::pir::Operation& op, MapperType type) const { + if (type == MapperType::OPERAND) { + return operand_funcs_.find(op.name()) != operand_funcs_.end(); + } else if (type == MapperType::ATTRIBUTE) { + return attr_funcs_.find(op.name()) != attr_funcs_.end(); + } + return false; + } + + std::vector<::pir::Value> RealOprandSources( + const ::pir::Operation& op) const { + CHECK(has(op, MapperType::OPERAND)) + << "Not register OprandIndexsFunction for " << op.name(); + std::vector<::pir::Value> inputs; + for (auto idx : operand_funcs_.at(op.name())()) { + inputs.push_back(op.operand_source(idx)); + } + return inputs; + } + + void AppendVariantAttrs(const ::pir::Operation& op, + utils::AttributeMap& attrs) const { // NOLINT + CHECK(has(op, MapperType::ATTRIBUTE)) + << "Not register AppendAttrFunction for " << op.name(); + attr_funcs_.at(op.name())(op, attrs); + } + + private: + OpMapper() { RegisterMapRules(); } + void RegisterMapRules(); + + std::unordered_map operand_funcs_; + std::unordered_map attr_funcs_; +}; + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc new file mode 100644 index 0000000000000..aeb502d511fd8 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -0,0 +1,243 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/utils.h" + +#include +#include +#include "glog/logging.h" + +#include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/framework/pir/op_mapper.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +// Mapping PaddleDialect Op into CINN AST Compute register Op +const std::unordered_map CompatibleInfo::OP_NAMES = { + {"pd_op.full", "fill_constant"}, + {"pd_op.sum", "reduce_sum"}, + {"pd_op.max", "reduce_max"}, + {"pd_op.add", "elementwise_add"}, + {"pd_op.subtract", "subtract"}, + {"pd_op.divide", "divide"}, + {"cinn_op.broadcast", "broadcast_to"}}; + +// Tagging PaddleDialect Op with REGITER_OP_MAPPER(OP) +const std::unordered_set CompatibleInfo::CINN_WHITE_OPS = { + "subtract", "divide", "broadcast_to", "multiply"}; + +bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) { + return CINN_WHITE_OPS.find(CompatibleInfo::OpName(op)) != + CINN_WHITE_OPS.end(); +} + +std::string CompatibleInfo::OpName(const ::pir::Operation& op) { + std::string name = op.name(); + if (OP_NAMES.count(name)) { + return OP_NAMES.at(name); + } + auto pos = name.find("."); + if (pos == std::string::npos) { + return name; + } + auto cinn_op_name = name.substr(pos + 1); + VLOG(4) << "GetOpName: " << name << " -> " << cinn_op_name; + return cinn_op_name; +} + +std::string CompatibleInfo::ValueName(const ::pir::Value& value) { + size_t hash_key = std::hash<::pir::Value>()(value); + return cinn::common::Context::Global().PrettyUniqName( + hash_key, CompatibleInfo::kNamePrefix); +} + +std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) { + std::string op_name = OpName(op); + std::string func_name = + cinn::common::Context::Global().NewName("fn_" + op_name); + return func_name; +} + +std::string CompatibleInfo::GroupOpsName( + const std::vector<::pir::Operation*>& ops) { + std::string name = "fn"; + for (auto* op : ops) { + std::string op_name = OpName(*op); + name += "_" + cinn::common::Context::Global().NewName(op_name); + } + return name; +} + +std::vector CompatibleInfo::InputNames(const ::pir::Operation& op, + bool allow_duplicate) { + std::vector names; + std::unordered_set repeat; + for (int i = 0; i < op.num_operands(); ++i) { + auto value = op.operand_source(i); + std::string name = CompatibleInfo::ValueName(value); + if (!allow_duplicate && repeat.count(name)) { + continue; + } + repeat.insert(name); + names.push_back(name); + } + return names; +} + +std::vector CompatibleInfo::OutputNames(::pir::Operation& op) { + std::vector names; + for (int i = 0; i < op.num_results(); ++i) { + auto value = op.result(i); + std::string name = CompatibleInfo::ValueName(value); + names.push_back(std::move(name)); + } + return names; +} + +std::vector<::pir::Value> CompatibleInfo::RealOperandSources( + const ::pir::Operation& op) { + if (OpMapper::Instance().has(op, MapperType::OPERAND)) { + return OpMapper::Instance().RealOprandSources(op); + } else { + return op.operands_source(); + } +} + +utils::Attribute CompatibleInfo::ConvertAttribute( + const ::pir::Attribute& src_attr) { + utils::Attribute dst_attr; + if (src_attr.isa<::pir::BoolAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data(); + } else if (src_attr.isa<::pir::FloatAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::FloatAttribute>().data(); + } else if (src_attr.isa<::pir::Int32Attribute>()) { + dst_attr = src_attr.dyn_cast<::pir::Int32Attribute>().data(); + } else if (src_attr.isa<::pir::StrAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::StrAttribute>().AsString(); + } else if (src_attr.isa<::pir::Int64Attribute>()) { + dst_attr = src_attr.dyn_cast<::pir::Int64Attribute>().data(); + } else if (src_attr.isa<::pir::DoubleAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::DoubleAttribute>().data(); + } else if (src_attr.isa()) { + auto& arr = src_attr.dyn_cast() + .data() + .GetData(); + std::vector val(arr.begin(), arr.end()); + dst_attr = val; + } else if (src_attr.isa()) { + auto dtype = src_attr.dyn_cast().data(); + dst_attr = phi::DataTypeToString(dtype); + } else if (src_attr.isa<::pir::ArrayAttribute>()) { + auto attr_vec = src_attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); + if (attr_vec.size() > 0) { + if (attr_vec[0].isa<::pir::Int32Attribute>()) { + std::vector vec_int32; + for (auto vec_element : attr_vec) { + vec_int32.push_back( + vec_element.dyn_cast<::pir::Int32Attribute>().data()); + } + dst_attr = vec_int32; + + } else if (attr_vec[0].isa<::pir::Int64Attribute>()) { + std::vector vec_int64; + for (auto vec_element : attr_vec) { + vec_int64.push_back( + vec_element.dyn_cast<::pir::Int64Attribute>().data()); + } + + dst_attr = vec_int64; + } else { + LOG(FATAL) + << "only suuport int32 and int64 attribute in ArrayAttribute"; + } + } + } else { + LOG(FATAL) << "unknown Attribute: " << src_attr; + } + + return dst_attr; +} + +utils::AttributeMap CompatibleInfo::ConvertAttributes( + const ::pir::Operation& op) { + auto& src_attrs = op.attributes(); + utils::AttributeMap dst_attrs; + for (auto& item : src_attrs) { + VLOG(4) << "deal with " << item.first; + if (item.first == ::pir::kStopGradientAttrName) { + continue; + } else if (item.second.isa()) { + auto is_cpu = + item.second.dyn_cast().data() == + phi::CPUPlace(); + dst_attrs["force_cpu"] = is_cpu; + } else { + dst_attrs[item.first] = std::move(ConvertAttribute(item.second)); + } + } + + if (OpMapper::Instance().has(op, MapperType::ATTRIBUTE)) { + OpMapper::Instance().AppendVariantAttrs(op, dst_attrs); + } + VLOG(4) << "dst_attrs.size(): " << dst_attrs.size(); + return dst_attrs; +} + +#define CASE_TYPE(src, dst) \ + else if (type.isa<::pir::src>()) return common::dst(); + +common::Type CompatibleInfo::ConvertIRType(::pir::Type type) { + if (type.isa<::pir::BFloat16Type>()) return common::BF16(); + CASE_TYPE(Float16Type, F16) + CASE_TYPE(Float32Type, F32) + CASE_TYPE(Float64Type, F64) + CASE_TYPE(Int8Type, I8) + CASE_TYPE(UInt8Type, UI8) + CASE_TYPE(Int16Type, I16) + CASE_TYPE(Int32Type, I32) + CASE_TYPE(Int64Type, I64) + CASE_TYPE(IndexType, I32) + CASE_TYPE(BoolType, UI1) + + LOG(FATAL) << "unknown ir::Type " << type; +} + +int CompatibleInfo::ShapeProduct(const std::vector& shape) { + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); +} + +OpPatternKind CompatibleInfo::OpKind(const ::pir::Operation& op) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + const hlir::framework::Operator* cinn_op = + Operator::Get(CompatibleInfo::OpName(op)); + CHECK(op_pattern_dict.Find(cinn_op)); + return op_pattern_dict[cinn_op]; +} + +std::vector CompatibleInfo::ValueShape(const ::pir::Value& value) { + auto& dim = value.type().dyn_cast<::pir::DenseTensorType>().dims(); + return phi::vectorize(dim); +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/utils.h b/paddle/cinn/hlir/framework/pir/utils.h similarity index 58% rename from paddle/cinn/hlir/framework/new_ir/utils.h rename to paddle/cinn/hlir/framework/pir/utils.h index 953dc6672bc18..d588f458990a0 100644 --- a/paddle/cinn/hlir/framework/new_ir/utils.h +++ b/paddle/cinn/hlir/framework/pir/utils.h @@ -15,19 +15,38 @@ #pragma once #include #include +#include #include "paddle/cinn/common/context.h" +#include "paddle/cinn/common/type.h" +#include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/utils/type_defs.h" #include "paddle/pir/core/operation.h" namespace cinn { namespace hlir { namespace framework { -namespace newir { + +namespace pir { + +struct CUDAJITInfo { + void* fn_ptr; + std::vector block_dims; + std::vector grid_dims; + void* compiler; +}; struct CompatibleInfo { - static constexpr char* kNamePrefix = "var_"; + static constexpr char* kNamePrefix = "var"; // TODO(Aurelius): Need add name mapping logic in REGISTER_CINN_OP // macros or attempt to unify Op name with Paddle and CINN. static const std::unordered_map OP_NAMES; + // NOTE(Aurelius): Some ops in CINN register different + // name between OpMapper and Compute/Schedule, such as + // 'subtract': 1. OpMapper: 'elementwise_sub'; 2. Compute/Schedule: + // 'subtract'. + static const std::unordered_set CINN_WHITE_OPS; + + static bool IsSupportCinn(const ::pir::Operation& op); static std::string OpName(const ::pir::Operation& op); @@ -41,9 +60,24 @@ struct CompatibleInfo { bool allow_duplicate = false); static std::vector OutputNames(::pir::Operation& op); // NOLINT + + static std::vector<::pir::Value> RealOperandSources( + const ::pir::Operation& op); + + static utils::Attribute ConvertAttribute(const ::pir::Attribute& src_attr); + + static utils::AttributeMap ConvertAttributes(const ::pir::Operation& op); + + static common::Type ConvertIRType(::pir::Type type); + + static std::vector ValueShape(const ::pir::Value& value); + + static int ShapeProduct(const std::vector& shape); + + static OpPatternKind OpKind(const ::pir::Operation& op); }; -} // namespace newir +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.cc b/paddle/cinn/hlir/framework/pir_compiler.cc similarity index 72% rename from paddle/cinn/hlir/framework/new_ir_compiler.cc rename to paddle/cinn/hlir/framework/pir_compiler.cc index 2a40531196da4..efd1d6999e3af 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.cc +++ b/paddle/cinn/hlir/framework/pir_compiler.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include -#include "paddle/cinn/hlir/framework/new_ir/utils.h" -#include "paddle/cinn/utils/attribute_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/pir/core/builtin_type.h" @@ -26,22 +25,62 @@ namespace framework { // TODO(Aurelius84): Need abstract this logic to implement Proxy for // the co-existance with GraphCompiler. -std::unique_ptr NewIRCompiler::Build() { +std::unique_ptr PirCompiler::Build() { m_builder_.Clear(); // NOTE(Aurelius84): Currently only support each op for one group - std::vector groups; + std::vector groups; for (auto it = program_.block()->begin(); it != program_.block()->end(); ++it) { std::vector<::pir::Operation*> ops = {*it}; - groups.push_back(std::make_shared(ops)); + groups.push_back(std::make_shared(ops)); } VLOG(4) << "Groups size: " << groups.size(); return std::move(Build(groups)); } -std::unique_ptr NewIRCompiler::Build( - const std::vector& groups) { - auto op_lowerer = CreateOpLowerer(target_); +std::vector PirCompiler::BuildCUDAJITInfo( + const std::vector& groups) { + std::vector vec_res; + + auto op_lowerer = CreateOpLowerer(target_); + + std::vector> lowered_funcs; + for (int i = 0; i < groups.size(); ++i) { + lowered_funcs.emplace_back(op_lowerer.Lower(groups[i])); + } + + for (auto&& lowered_func : lowered_funcs) { + ProcessFunction(lowered_func); + } + + compiler_ = backends::Compiler::Create(target_); + auto build_module = m_builder_.Build(); + compiler_->Build(build_module, ""); + + auto instructions = BuildInstructions(groups); + + auto fn_ptrs = compiler_->GetFnPtr(); + + auto* compilter_ptr = compiler_.release(); + for (int idx = 0; idx < groups.size(); ++idx) { + pir::CUDAJITInfo jit_info; + jit_info.fn_ptr = fn_ptrs[idx]; + jit_info.compiler = reinterpret_cast(compilter_ptr); + + lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo( + &(jit_info.block_dims)); + + lowered_funcs[idx][0]->cuda_axis_info.CopyGridDimsTo(&(jit_info.grid_dims)); + + vec_res.push_back(jit_info); + } + + return vec_res; +} + +std::unique_ptr PirCompiler::Build( + const std::vector& groups) { + auto op_lowerer = CreateOpLowerer(target_); std::vector> lowered_funcs; for (int i = 0; i < groups.size(); ++i) { @@ -72,7 +111,7 @@ std::unique_ptr NewIRCompiler::Build( return std::make_unique(scope_, std::move(instructions)); } -void NewIRCompiler::ProcessFunction( +void PirCompiler::ProcessFunction( const std::vector& lowered_funcs) { for (auto&& func : lowered_funcs) { for (auto&& arg : func->args) { @@ -97,18 +136,18 @@ void NewIRCompiler::ProcessFunction( } } -std::vector> NewIRCompiler::BuildInstructions( - const std::vector& groups) { +std::vector> PirCompiler::BuildInstructions( + const std::vector& groups) { std::vector> instructions; for (int idx = 0; idx < groups.size(); ++idx) { - auto& fn_name = groups[idx]->fn_name; + auto fn_name = groups[idx]->FuncName(); auto instr = std::unique_ptr(new Instruction(target_, scope_.get(), groups[idx]->input_names, groups[idx]->output_names, fn_name)); - VLOG(1) << "Lookup kernel name: " << fn_name; + VLOG(4) << "Lookup kernel name: " << fn_name; auto* fn_ptr = compiler_->Lookup(fn_name); CHECK(fn_ptr); instr->SetLoweredFunc(reinterpret_cast(fn_ptr), fn_name); @@ -130,7 +169,7 @@ std::shared_ptr BuildScope(const Target& target, if (visited.count(value) > 0) return; visited.emplace(value); - std::string name = newir::CompatibleInfo::ValueName(value); + std::string name = pir::CompatibleInfo::ValueName(value); auto type_info = value.type().dyn_cast(); auto* var = scope->Var(name); auto& tensor = absl::get(*var); @@ -140,7 +179,7 @@ std::shared_ptr BuildScope(const Target& target, shape.push_back(Shape::dim_t(type_info.dims()[i])); } tensor->Resize(Shape{shape}); - tensor->set_type(utils::ConvertIRType(type_info.dtype())); + tensor->set_type(pir::CompatibleInfo::ConvertIRType(type_info.dtype())); }; for (auto it = program.block()->begin(); it != program.block()->end(); ++it) { diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.h b/paddle/cinn/hlir/framework/pir_compiler.h similarity index 68% rename from paddle/cinn/hlir/framework/new_ir_compiler.h rename to paddle/cinn/hlir/framework/pir_compiler.h index 62c3d97a21a41..acb4b5c1e9e21 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.h +++ b/paddle/cinn/hlir/framework/pir_compiler.h @@ -28,29 +28,32 @@ namespace framework { // TODO(Aurelius84): Need abstract this logic to implement Proxy for // the co-existance with GraphCompiler. -class NewIRCompiler final { +class PirCompiler final { public: - NewIRCompiler(const ::pir::Program& prog, - const Target& target, - const std::shared_ptr& scope) + PirCompiler(const ::pir::Program& prog, + const Target& target, + const std::shared_ptr& scope) : program_(prog), - m_builder_("NewIR", target), + m_builder_("Pir", target), target_(target), scope_(scope) {} std::unique_ptr Build(); - std::unique_ptr Build(const std::vector& groups); + std::vector BuildCUDAJITInfo( + const std::vector& groups); + + std::unique_ptr Build(const std::vector& groups); private: - CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler); + CINN_DISALLOW_COPY_AND_ASSIGN(PirCompiler); std::vector GetOpFunc(const ::pir::Operation& op, int idx); void ProcessFunction(const std::vector& lowered_funcs); std::vector> BuildInstructions( - const std::vector& groups); + const std::vector& groups); const ::pir::Program& program_; ir::Module::Builder m_builder_; @@ -62,6 +65,23 @@ class NewIRCompiler final { std::shared_ptr BuildScope(const Target&, const ::pir::Program&); +class PirCompilerManager { + public: + static PirCompilerManager& Instance() { + static PirCompilerManager instance; + return instance; + } + + void insert(const std::shared_ptr& compiler) { + compilers_.push_back(compiler); + } + + void clear() { compilers_.clear(); } + + private: + std::vector> compilers_; +}; + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/op/contrib/argmax.cc b/paddle/cinn/hlir/op/contrib/argmax.cc index 2a1f19a5d2608..041cfe7dc47a5 100644 --- a/paddle/cinn/hlir/op/contrib/argmax.cc +++ b/paddle/cinn/hlir/op/contrib/argmax.cc @@ -161,6 +161,25 @@ std::shared_ptr StrategyForArgmax( ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); + int iter_var_size = blocks[0] + .As() + ->schedule_block.As() + ->iter_vars.size(); + int real_axis = axis; + if (real_axis < 0) { + real_axis += iter_var_size; + } + blocks[0] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + blocks[1] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + int64_t prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, diff --git a/paddle/cinn/hlir/op/contrib/argmin.cc b/paddle/cinn/hlir/op/contrib/argmin.cc index dfd88deb6f380..3caaf45c46a5e 100644 --- a/paddle/cinn/hlir/op/contrib/argmin.cc +++ b/paddle/cinn/hlir/op/contrib/argmin.cc @@ -158,6 +158,26 @@ std::shared_ptr StrategyForArgmin( // variables, because the size will exceed the limit. ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); + + int iter_var_size = blocks[0] + .As() + ->schedule_block.As() + ->iter_vars.size(); + int real_axis = axis; + if (real_axis < 0) { + real_axis += iter_var_size; + } + blocks[0] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + blocks[1] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + int64_t prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index a396aec315af4..0c279737a2a72 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -29,6 +29,8 @@ #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/ir_simplify.h" +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace hlir { namespace op { @@ -58,7 +60,7 @@ std::shared_ptr StrategyForReduce( const std::string &op_name, BlockReduceFunc gpu_reduce_with_last_axis_func, BlockReduceFunc gpu_reduce_without_last_axis_func, - ReduceFunc cpu_reduce_func) { + ReduceFunc common_reduce_func) { std::vector reduce_axes; auto ndim = inputs[0]->shape.size(); if (attrs.attr_store.count("dim")) { @@ -127,7 +129,8 @@ std::shared_ptr StrategyForReduce( << "The type of input argument " << x->name << " of " << op_name << " should be bool, but get " << x->type() << "! Please check."; - if (target == common::DefaultNVGPUTarget()) { + if (!FLAGS_cinn_new_group_scheduler && + target == common::DefaultNVGPUTarget()) { if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { VLOG(3) << "Do Two Step Block Reduce Compute!"; auto res = gpu_reduce_with_last_axis_func( @@ -155,7 +158,7 @@ std::shared_ptr StrategyForReduce( } } else { VLOG(3) << "Do Reduce Compute!"; - auto out = cpu_reduce_func(x, reduce_axes, keep_dim, tensor_name); + auto out = common_reduce_func(x, reduce_axes, keep_dim, tensor_name); auto stages = CreateStages({out}); std::vector cinn_values{CINNValue(out), CINNValue(stages)}; @@ -193,7 +196,7 @@ std::shared_ptr StrategyForReduce( ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - if (target.arch == Target::Arch::NVGPU) { + if (!FLAGS_cinn_new_group_scheduler && target.arch == Target::Arch::NVGPU) { if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { if (arg_pack.size() == 4) { CHECK_EQ(vec_tensor.size(), 2); @@ -313,7 +316,7 @@ std::shared_ptr StrategyForReduce( reduce_op_, \ gpu_reduce_with_last_axis_func, \ gpu_reduce_without_last_axis_func, \ - cpu_reduce_func) \ + common_reduce_func) \ std::shared_ptr StrategyFor##reduce_op_( \ const framework::NodeAttr &attrs, \ const std::vector &inputs, \ @@ -328,7 +331,7 @@ std::shared_ptr StrategyForReduce( #op_name_, \ gpu_reduce_with_last_axis_func, \ gpu_reduce_without_last_axis_func, \ - cpu_reduce_func); \ + common_reduce_func); \ } STRATEGY_FOR_REDUCE(reduce_sum, diff --git a/paddle/cinn/hlir/op/reduction_test.cc b/paddle/cinn/hlir/op/reduction_test.cc index ca20c0d3fdd76..953dd82017d9b 100644 --- a/paddle/cinn/hlir/op/reduction_test.cc +++ b/paddle/cinn/hlir/op/reduction_test.cc @@ -39,6 +39,9 @@ #include "paddle/cinn/hlir/pe/nn.h" #include "paddle/cinn/runtime/cinn_runtime.h" #include "paddle/cinn/runtime/cuda/cuda_module.h" + +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace hlir { namespace framework { @@ -362,6 +365,9 @@ void TestCaseForReduce(const float init_val, dim3 block; grid = {c, 1, 1}; int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h; + if (FLAGS_cinn_new_group_scheduler) { + block_dim_x = 1; + } block = {block_dim_x, 1, 1}; void* args[] = {&dev_x, &dev_z}; @@ -531,7 +537,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce) { std::vector dim = {1}; auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce"); - CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); } TEST(Operator, Operator_Reduction_Case_Block_Reduce) { @@ -544,7 +551,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce) { std::vector dim = {1}; auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce"); - CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); } TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { @@ -558,7 +566,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1"); - CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); } TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { @@ -572,7 +581,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2"); - CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); } } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h b/paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h index e67afa0cdd2c8..0c9c86fccaa5a 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h @@ -52,11 +52,11 @@ class FusionPassRegistrar final : public Registrar { #define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_pass__##pass_name, \ + __reg_cinn_fusion_pass__##pass_name, \ "CINN_REGISTER_FUSION_PASS must be called in global namespace"); \ static ::cinn::hlir::pass::FusionPassRegistrar \ - __pass_registrar_##pass_name##__(#pass_name); \ - int TouchFusionPassRegistrar_##pass_name() { \ - __pass_registrar_##pass_name##__.Touch(); \ + __cinn_fusion_pass_registrar_##pass_name##__(#pass_name); \ + int TouchCinnFusionPassRegistrar_##pass_name() { \ + __cinn_fusion_pass_registrar_##pass_name##__.Touch(); \ return 0; \ } diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.cc b/paddle/cinn/hlir/pe/ir_schedule_pe.cc index 6600905b083c1..b8f6d170996b3 100644 --- a/paddle/cinn/hlir/pe/ir_schedule_pe.cc +++ b/paddle/cinn/hlir/pe/ir_schedule_pe.cc @@ -31,14 +31,39 @@ #include "paddle/cinn/hlir/pe/schedule.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/poly/isl_utils.h" #include "paddle/cinn/utils/string.h" +PD_DECLARE_bool(cinn_new_group_scheduler); namespace cinn { namespace hlir { namespace pe { +void SetReduceAxis(ir::Expr loop, ir::Expr block) { + std::string var_name = loop.As()->loop_var->name; + std::vector iter_vars = block.As() + ->schedule_block.As() + ->iter_vars; + std::vector iter_values = + block.As()->iter_values; + CHECK_EQ(iter_vars.size(), iter_values.size()); + for (int i = 0; i < iter_values.size(); ++i) { + std::set contains = ir::ir_utils::CollectIRNodesWithoutTensor( + iter_values[i], + [&var_name](const Expr *expr) { + return expr->As() != nullptr && + expr->As()->name == var_name; + }, + true); + if (!contains.empty()) { + iter_vars[i]->is_reduce_axis = true; + } + } +} + void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT const std::vector &output_shape, const common::Target &target) { @@ -457,9 +482,15 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT if (loops_tmp_out.size() == 1) { ir_sch.Bind(loops_tmp_out[0], "threadIdx.x"); ir_sch.Bind(loops_out[0], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler) { + SetReduceAxis(loops_tmp_out[0], ir_sch.GetBlock(tmp_out->name)); + } } else { ir_sch.Bind(loops_tmp_out[0], "blockIdx.x"); ir_sch.Bind(loops_tmp_out[1], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler) { + SetReduceAxis(loops_tmp_out[1], ir_sch.GetBlock(tmp_out->name)); + } if (loops_out.size() == 1) { ir_sch.Split(loops_out[0], {-1, 1}); @@ -471,7 +502,11 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT for (auto &tensor : {tmp_out}) { auto block = ir_sch.GetBlock(tensor->name); - ir_sch.SetBuffer(block, "local", true); + if (FLAGS_cinn_new_group_scheduler) { + ir_sch.SetBuffer(block, "local"); + } else { + ir_sch.SetBuffer(block, "local", true); + } } VLOG(3) << "After IRCudaScheduleBlockReduceInternal : " @@ -600,6 +635,9 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler) { + SetReduceAxis(loops[1], ir_sch.GetBlock(tmp_out->name)); + } } // out { @@ -614,7 +652,11 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT for (auto &tensor : {reduce_tmp_out, tmp_out}) { auto block = ir_sch.GetBlock(tensor->name); - ir_sch.SetBuffer(block, "local", true); + if (FLAGS_cinn_new_group_scheduler) { + ir_sch.SetBuffer(block, "local"); + } else { + ir_sch.SetBuffer(block, "local", true); + } } VLOG(3) << "After IRCudaScheduleBlockReduce : " @@ -673,8 +715,10 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // NOLINT auto load = exprs.front().As(); load->indices = {index}; }; - hand_write_simplify(ir_sch.GetLoops(reshape->name), - ir_sch.GetBlock(reshape->name)); + if (!FLAGS_cinn_new_group_scheduler) { + hand_write_simplify(ir_sch.GetLoops(reshape->name), + ir_sch.GetBlock(reshape->name)); + } auto block = ir_sch.GetBlock(reshape->name); ir_sch.ComputeInline(block); VLOG(4) << "After simplify reshape index : " @@ -955,10 +999,14 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT ir_sch.ComputeInline(reshape_block); auto internal_block = ir_sch.GetBlock(internal->name); - ir_sch.SetBuffer(internal_block, "local", true); - auto tmp_out_block = ir_sch.GetBlock(tmp_out->name); - ir_sch.SetBuffer(tmp_out_block, "local", true); + if (FLAGS_cinn_new_group_scheduler) { + ir_sch.SetBuffer(internal_block, "local"); + ir_sch.SetBuffer(tmp_out_block, "local"); + } else { + ir_sch.SetBuffer(internal_block, "local", true); + ir_sch.SetBuffer(tmp_out_block, "local", true); + } // The current one-dimensional reduce does not make full use of SM. // This case is optimized into a two-dimensional. @@ -978,9 +1026,15 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.y"); ir_sch.Bind(loops[2], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler && tensor->name == tmp_out->name) { + SetReduceAxis(loops[2], ir_sch.GetBlock(tmp_out->name)); + } } else { ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler && tensor->name == tmp_out->name) { + SetReduceAxis(loops[1], ir_sch.GetBlock(tmp_out->name)); + } } } VLOG(3) << "After IRCudaTwoStepReduceSchedule : " diff --git a/paddle/cinn/hlir/pe/reduction.cc b/paddle/cinn/hlir/pe/reduction.cc index e38465babbb38..f809efbd13e67 100644 --- a/paddle/cinn/hlir/pe/reduction.cc +++ b/paddle/cinn/hlir/pe/reduction.cc @@ -1077,6 +1077,31 @@ std::vector TwoStepBlockReduceAny(const ir::Tensor& A, Expr(false)); } +std::string CrossThreadReduceExternalFuncName(const ir::Expr& op, + const ir::Expr& tensor) { + CHECK_NOTNULL(tensor.as_tensor()); + if (op.As()) { + return "cinn_block_reduce_sum" + + Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm"; + } else if (op.As()) { + return "cinn_block_reduce_prod" + + Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm"; + } else if (op.As()) { + return "cinn_block_reduce_max" + + Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm"; + } else if (op.As()) { + return "cinn_block_reduce_min" + + Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm"; + } else if (op.As()) { + return "cinn_block_reduce_all_internal_shm"; + } else if (op.As()) { + return "cinn_block_reduce_any_internal_shm"; + } else { + LOG(FATAL) << "Reduce type: " << op << " Not supported yet!"; + } + return ""; +} + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/pe/reduction.h b/paddle/cinn/hlir/pe/reduction.h index ceb82e8f6fe0b..a3a5f02915ef9 100644 --- a/paddle/cinn/hlir/pe/reduction.h +++ b/paddle/cinn/hlir/pe/reduction.h @@ -467,6 +467,11 @@ std::vector TwoStepBlockReduceAny( const std::vector& axes, const bool keep_dim, const std::string& output_name = "T_Reduce_Any_out"); + +std::string CrossThreadReduceExternalFuncName(const ir::Expr& op, + const ir::Expr& tensor); + +std::string Type2StrForReduce(common::Type type); } // namespace pe } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/ir/CMakeLists.txt b/paddle/cinn/ir/CMakeLists.txt index f0df58ca8dd6b..ff714c24496bb 100644 --- a/paddle/cinn/ir/CMakeLists.txt +++ b/paddle/cinn/ir/CMakeLists.txt @@ -24,3 +24,4 @@ add_subdirectory(op) add_subdirectory(test) add_subdirectory(utils) add_subdirectory(schedule) +add_subdirectory(group_schedule) diff --git a/paddle/cinn/ir/group_schedule/CMakeLists.txt b/paddle/cinn/ir/group_schedule/CMakeLists.txt new file mode 100644 index 0000000000000..e43f56553c496 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/CMakeLists.txt @@ -0,0 +1,4 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS st_shape_group_scheduler.cc) +gather_srcs(cinnapi_src SRCS dy_shape_group_scheduler.cc) diff --git a/paddle/cinn/ir/group_schedule/base_group_scheduler.h b/paddle/cinn/ir/group_schedule/base_group_scheduler.h new file mode 100644 index 0000000000000..a72bfc3f53766 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/base_group_scheduler.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/common/target.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule_block_graph.h" + +namespace cinn { +namespace ir { + +using SymbolicCondition = Expr; + +/** + * The base class used for scheduling fusion groups. + */ +class GroupScheduler { + public: + GroupScheduler(ir::IRSchedule* ir_sch, + const std::unordered_set& output_tensor_names, + const common::Target& target) + : ir_sch_(ir_sch), + output_tensor_names_(output_tensor_names), + target_(target) { + schedule_block_graph_ = std::make_unique(*ir_sch_); + } + + virtual ~GroupScheduler() = default; + + virtual void Schedule() = 0; + + virtual std::vector> GetIRs() = 0; + + protected: + ir::IRSchedule* ir_sch_; + const std::unordered_set& output_tensor_names_; + const common::Target& target_; + // Graph in units of ScheduleBlockNode, each node corresponds to a + // ScheduleBlock in IR. + std::unique_ptr schedule_block_graph_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc new file mode 100644 index 0000000000000..6d346ec2ea828 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h" + +namespace cinn { +namespace ir { + +void DynamicShapeGroupScheduler::Schedule() { + // Fake schedule for test + int max_spacial_numel = 1; + ScheduleBlockNode* node = schedule_block_graph_->EndPoints()[0]; + ir::Expr block_realize = node->Block(); + std::vector loops = ir_sch_->GetLoops(block_realize); + ir::Expr extent = loops[0].As()->extent; + + ir::Expr condition1 = ir::LE::Make(extent, Expr(1024)); + std::unique_ptr new_ir_sch1 = + std::make_unique(*ir_sch_); + ScheduleBlockGraph sbg1(*new_ir_sch1); + sbg1.NodesWalk([&](ir::ScheduleBlockNode* node) { + new_ir_sch1->Bind(ir_sch_->GetLoops(node->Block())[0], "threadIdx.x"); + }); + ir_schs_.emplace_back(condition1, std::move(new_ir_sch1)); + + ir::Expr condition2 = ir::GT::Make(extent, Expr(1024)); + std::unique_ptr new_ir_sch2 = + std::make_unique(*ir_sch_); + ScheduleBlockGraph sbg2(*new_ir_sch2); + sbg2.NodesWalk([&](ir::ScheduleBlockNode* node) { + new_ir_sch2->Bind(ir_sch_->GetLoops(node->Block())[0], "threadIdx.x"); + }); + ir_schs_.emplace_back(condition2, std::move(new_ir_sch2)); +} + +std::vector> +DynamicShapeGroupScheduler::GetIRs() { + std::vector> irs; + for (auto& sch_pair : ir_schs_) { + irs.emplace_back(sch_pair.first, + sch_pair.second->GetModule().GetExprs()[0]); + } + return irs; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h new file mode 100644 index 0000000000000..2d9129a6a6db2 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h @@ -0,0 +1,43 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" + +namespace cinn { +namespace ir { + +/** + * The class used for scheduling fusion groups with dynamic shape. + * Note: Currently only CUDA backend is supported. + */ +class DynamicShapeGroupScheduler : public GroupScheduler { + public: + DynamicShapeGroupScheduler( + ir::IRSchedule* ir_sch, + const std::unordered_set& output_tensor_names, + const common::Target& target) + : GroupScheduler(ir_sch, output_tensor_names, target) {} + + void Schedule() override; + + std::vector> GetIRs() override; + + private: + std::vector>> + ir_schs_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc new file mode 100644 index 0000000000000..8c2ae6a6799c9 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc @@ -0,0 +1,1199 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h" +#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h" +#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h" +#include "paddle/cinn/common/cas.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/ir/tensor.h" +#include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/ir/utils/ir_nodes_collector.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" + +namespace cinn { +namespace ir { + +static const std::unordered_set + kProhibitScheduleExternalFuncNames = { +#define CINN_NVGPU_FUNC2STRING(str) #str +#define CINN_NVGPU_FUNC_TYPE(FUNC, TYPE) \ + CINN_NVGPU_FUNC2STRING(cinn_nvgpu_##FUNC##TYPE) + +#define GEN_FUNC_NAME(_, impl) \ + _(impl, gt_num) \ + _(impl, lt_num) \ + _(impl, index_add) \ + _(impl, next_smallest) + +#define GEN_FUNC_NAME_WITH_TYPE(_, ...) \ + _(__VA_ARGS__, _bool), _(__VA_ARGS__, _fp16), _(__VA_ARGS__, _fp32), \ + _(__VA_ARGS__, _fp64), _(__VA_ARGS__, _uint8), _(__VA_ARGS__, _int8), \ + _(__VA_ARGS__, _int16), _(__VA_ARGS__, _int32), _(__VA_ARGS__, _int64), + + GEN_FUNC_NAME(GEN_FUNC_NAME_WITH_TYPE, CINN_NVGPU_FUNC_TYPE) +#undef GEN_FUNC_NAME +}; + +bool IsProhibitScheduleExternCallBlock(ir::Expr block) { + ir::ScheduleBlockRealize* sch_block_realize = + block.As(); + CHECK_NOTNULL(sch_block_realize); + ir::ScheduleBlock* sch_block = + sch_block_realize->schedule_block.As(); + CHECK_NOTNULL(sch_block); + + auto find_call = ir::ir_utils::CollectIRNodesWithoutTensor( + sch_block->body, [&](const Expr* x) { return x->As(); }); + for (ir::Expr call : find_call) { + ir::Call* call_node = call.As(); + if (call.As() && kProhibitScheduleExternalFuncNames.count( + call.As()->name) != 0) { + return true; + } + } + return false; +} + +// Find loops with same extents of 2 ScheduleBlock +std::vector> FindSameOuterLoops( + ir::ScheduleBlockNode* source_node, ir::ScheduleBlockNode* target_node) { + std::vector src_ctrl_stmts = source_node->ControlStmts(); + std::vector tgt_ctrl_stmts = target_node->ControlStmts(); + std::vector> same_loops; + int min_stmt_size = std::min(src_ctrl_stmts.size(), tgt_ctrl_stmts.size()); + for (int i = 0; i < min_stmt_size; ++i) { + if (src_ctrl_stmts[i].As() && tgt_ctrl_stmts[i].As() && + ir::GetLoopExtent(src_ctrl_stmts[i]) == + GetLoopExtent(tgt_ctrl_stmts[i])) { + same_loops.push_back( + std::make_tuple(src_ctrl_stmts[i], tgt_ctrl_stmts[i])); + } else { + break; + } + } + + return same_loops; +} + +std::unordered_set GetReduceLoopVarNames(ir::Expr block) { + ir::ScheduleBlockRealize* schedule_block_realize = + block.As(); + ir::ScheduleBlock* schedule_block = + schedule_block_realize->schedule_block.As(); + std::vector iter_values = schedule_block_realize->iter_values; + std::vector iter_vars = schedule_block->iter_vars; + std::unordered_set reduce_loop_var_names; + for (int i = 0; i < iter_vars.size(); ++i) { + if (iter_vars[i]->is_reduce_axis) { + ir::ir_utils::CollectIRNodesWithoutTensor( + iter_values[i], [&](const ir::Expr* x) { + if (x->as_var()) { + reduce_loop_var_names.insert(x->as_var_ref()->name); + } + return false; + }); + } + } + return reduce_loop_var_names; +} + +std::unordered_set GetReduceVarNames(ir::Expr block) { + ir::ScheduleBlockRealize* schedule_block_realize = + block.As(); + ir::ScheduleBlock* schedule_block = + schedule_block_realize->schedule_block.As(); + std::vector iter_vars = schedule_block->iter_vars; + std::unordered_set reduce_var_names; + for (int i = 0; i < iter_vars.size(); ++i) { + if (iter_vars[i]->is_reduce_axis) { + reduce_var_names.insert(iter_vars[i]->name); + } + } + return reduce_var_names; +} + +void StaticShapeGroupScheduler::Schedule() { + feasible_conditions_.emplace_back( + &StaticShapeGroupScheduler::IsKeepGraphDependency); + DoLoopAlignment(); + DoComputeInline(); +#ifdef CINN_WITH_CUDA + OptimizeReduction(); +#endif + DoHorizontalLoopFusion(); + DoVerticalLoopFusion(); +#ifdef CINN_WITH_CUDA + BindCudaAxis(); + AllocateStorage(); +#endif +} + +std::vector> +StaticShapeGroupScheduler::GetIRs() { + return {{Expr(1), ir_sch_->GetModule().GetExprs()[0]}}; +} + +NodePriority StaticShapeGroupScheduler::CalculateNodePriority( + const ir::ScheduleBlockNode* node) const { + bool has_loop_binded = false; + std::unordered_set reduce_loop_var_names = + GetReduceLoopVarNames(node->Block()); + + int64_t reduce_score = 1; + double score = 1; + for (Expr expr : node->ControlStmts()) { + ir::For* for_node = expr.As(); + if (for_node != nullptr) { + score *= ir::GetLoopExtent(expr); + } + if (reduce_loop_var_names.count(for_node->loop_var->name) != 0) { + reduce_score *= ir::GetLoopExtent(expr); + } + if (for_node->is_binded()) { + has_loop_binded = true; + } + } + if (reduce_score > 1) { + score *= (reduce_score * std::log2(reduce_score)); + } + + VLOG(6) << "The priority score of node " << node->id() << " is " << score; + VLOG(6) << "The node has_loop_binded: " << has_loop_binded; + return NodePriority{has_loop_binded, score}; +} + +ir::ScheduleBlockNode* StaticShapeGroupScheduler::FindGlobalMasterNode() const { + NodePriority max{false, std::numeric_limits::min()}; + ir::ScheduleBlockNode* master = nullptr; + auto FindMaster = [&](ir::ScheduleBlockNode* node) { + NodePriority priority = CalculateNodePriority(node); + VLOG(6) << "The priority score of node " << node->id() << " is " + << priority.score + << ", has_loop_binded: " << priority.has_loop_binded; + if (max < priority) { + max = priority; + master = node; + } + }; + + schedule_block_graph_->NodesWalk(FindMaster); + CHECK(master) << "Cannot find global master node"; + VLOG(6) << "Find the global master node: " << master->id(); + return master; +} + +std::unordered_set StaticShapeGroupScheduler::OutputTensorNames() + const { + std::unordered_set output_tensor_names{output_tensor_names_}; + for (ir::ScheduleBlockNode* node : schedule_block_graph_->EndPoints()) { + output_tensor_names.insert(node->id()); + } + return output_tensor_names; +} + +void StaticShapeGroupScheduler::DoLoopAlignment() { + VLOG(5) << "[Start LoopAlignment] func body: " + << ir_sch_->GetModule().GetExprs().front(); + ir::ScheduleBlockNode* global_master = FindGlobalMasterNode(); + ir::Expr master_block = global_master->Block(); + std::vector original_master_loop_extents; + std::vector spacial_master_loop_extents; + std::vector original_master_loop_order; + std::vector recover_loop_order; + + std::vector master_iter_values = + master_block.As()->iter_values; + std::vector master_iter_vars = + master_block.As() + ->schedule_block.As() + ->iter_vars; + std::vector master_loops = ir_sch_->GetLoops(master_block); + + std::unordered_set reduce_var_names = + GetReduceVarNames(master_block); + if (!reduce_var_names.empty()) { + std::set reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor( + master_block, + [&](const ir::Expr* x) { + bool find_reduce_var = false; + if (x->As()) { + int i = 0; + for (ir::Expr index : x->As()->indices) { + if (index.as_var() && + reduce_var_names.count(index.as_var_ref()->name) > 0) { + find_reduce_var = true; + } + ++i; + } + } + return find_reduce_var; + }, + /* uniq_target = */ true); + CHECK_EQ(reduce_loads.size(), 1); + + std::vector indices = + reduce_loads.begin()->As()->indices; + for (ir::Expr index : indices) { + CHECK_NOTNULL(index.as_var()); + int idx = 0; + bool is_reduce_var = false; + for (const ir::Var& iter_var : master_iter_vars) { + if (iter_var->name == index.as_var_ref()->name) { + is_reduce_var = iter_var->is_reduce_axis; + break; + } + ++idx; + } + std::vector loop_vars_in_order; + ir::ir_utils::CollectIRNodesInOrder( + master_iter_values[idx], [&](const ir::Expr* x) { + if (x->as_var()) { + loop_vars_in_order.push_back(x->as_var_ref()); + } + return false; + }); + for (const ir::Var& loop_var : loop_vars_in_order) { + for (int i = 0; i < master_loops.size(); ++i) { + if (master_loops[i].As()->loop_var->name == loop_var->name) { + original_master_loop_order.push_back(i); + int extent = ir::GetLoopExtent(master_loops[i]); + original_master_loop_extents.push_back(extent); + if (!is_reduce_var) { + spacial_master_loop_extents.push_back(extent); + } + } + } + } + } + + for (int i = 0; i < original_master_loop_order.size(); ++i) { + for (int j = 0; j < original_master_loop_order.size(); ++j) { + if (original_master_loop_order[j] == i) { + recover_loop_order.push_back(j); + break; + } + } + } + CHECK_EQ(original_master_loop_order.size(), recover_loop_order.size()); + } else { + for (int i = 0; i < master_loops.size(); ++i) { + original_master_loop_extents.push_back( + ir::GetLoopExtent(master_loops[i])); + spacial_master_loop_extents.push_back(ir::GetLoopExtent(master_loops[i])); + original_master_loop_order.push_back(i); + recover_loop_order.push_back(i); + } + } + + int total_master_loop_extents = 1; + int total_spacial_loop_extents = 1; + for (int extent : original_master_loop_extents) { + total_master_loop_extents *= extent; + } + for (int extent : spacial_master_loop_extents) { + total_spacial_loop_extents *= extent; + } + + auto LoopAlignmentFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return false; + } + + if (node == global_master) { + return false; + } + + for (ir::Expr expr : node->ControlStmts()) { + if (expr.As() != nullptr && + (expr.As()->for_type() == ir::ForType::GPUBlock || + expr.As()->for_type() == ir::ForType::GPUThread)) { + return false; + } + if (expr.As()->body.As() && + expr.As()->body.As()->stmts.size() > 1) { + return false; + } + } + + VLOG(6) << "try to align loops of block: " << node->id() + << " with block: " << global_master->id(); + + // 1. Fuse source loops + ir::Expr source_loop = ir_sch_->Fuse(node->ControlStmts()); + int total_source_extent = ir::GetLoopExtent(source_loop); + + // 2. Split source loop to align with the target loops + std::vector target_loop_extents; + if (total_source_extent < total_spacial_loop_extents) { + int cur_extent = 1; + for (int extent : spacial_master_loop_extents) { + cur_extent *= extent; + if (cur_extent == total_source_extent) { + target_loop_extents.push_back(extent); + break; + } else if (cur_extent > total_source_extent) { + target_loop_extents.push_back(-1); + break; + } else { + target_loop_extents.push_back(extent); + } + } + } else if (total_source_extent == total_spacial_loop_extents) { + target_loop_extents = spacial_master_loop_extents; + } else if (total_source_extent < total_master_loop_extents) { + target_loop_extents = spacial_master_loop_extents; + target_loop_extents.push_back(-1); + } else if (total_source_extent == total_master_loop_extents) { + target_loop_extents = original_master_loop_extents; + } + std::vector source_loops; + if (target_loop_extents.size() > 0 && + target_loop_extents[0] < total_source_extent) { + source_loops = ir_sch_->Split(source_loop, target_loop_extents); + } else { + source_loops = {source_loop}; + } + + // 3. Rerorder loops to match the target loops + if (total_source_extent == total_master_loop_extents) { + ir_sch_->Reorder(node->id(), recover_loop_order); + } + + return true; + }; + + schedule_block_graph_->DFSTopoWalk(LoopAlignmentFunc); + VLOG(5) << "[After LoopAlignment] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void StaticShapeGroupScheduler::DoComputeInline() { + VLOG(5) << "[Start DoComputeInline] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + std::unordered_set no_inline_output_names = OutputTensorNames(); + auto_schedule::AutoInline inliner(target_, no_inline_output_names); + + auto InlineFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + VLOG(6) << "try ComputeInline on: " << node->id() + << ", before ComputeInline, func body: " + << ir_sch_->GetModule().GetExprs().front(); + ir::Expr schedule_block = node->Block(); + inliner.Apply(ir_sch_, schedule_block); + VLOG(6) << "try ComputeInline on: " << node->id() + << ", after ComputeInline, func body: " + << ir_sch_->GetModule().GetExprs().front(); + }; + + schedule_block_graph_->DFSTopoWalk(InlineFunc); + schedule_block_graph_->Update(*ir_sch_); + VLOG(5) << "[After DoComputeInline] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void StaticShapeGroupScheduler::DoHorizontalLoopFusion() { + VLOG(5) << "[Start DoHorizontalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + std::vector end_nodes = + schedule_block_graph_->EndPoints(); + std::reverse(end_nodes.begin(), end_nodes.end()); + ir::ScheduleBlockNode* master_node = end_nodes.front(); + CHECK_NOTNULL(master_node); + for (int i = 1; i < end_nodes.size(); ++i) { + if (IsProhibitScheduleExternCallBlock(end_nodes[i]->Block())) { + continue; + } + VLOG(6) << "try to fuse loop of " << end_nodes[i]->id() << " to " + << master_node->id(); + std::vector>&& same_loops = + FindSameOuterLoops(end_nodes[i], master_node); + if (same_loops.size() == 0) { + continue; + } + ir::Expr target_loop = std::get<1>(same_loops.back()); + VLOG(6) << "target_loop: " << target_loop; + ir_sch_->SimpleComputeAt(end_nodes[i]->Block(), target_loop); + VLOG(6) << "after fuse: " << ir_sch_->GetModule().GetExprs().front(); + } + + VLOG(5) << "[After DoHorizontalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void StaticShapeGroupScheduler::DoVerticalLoopFusion() { + VLOG(5) << "[Start DoVerticalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); + UpdateBlockOrder(); + + auto FindMaster = + [&](ir::ScheduleBlockNode* node) -> std::vector { + std::vector masters = node->Consumers(); + std::sort( + masters.begin(), + masters.end(), + [&](const ir::ScheduleBlockNode* a, const ir::ScheduleBlockNode* b) { + return this->CalculateNodePriority(b) < + this->CalculateNodePriority(a); + }); + return masters; + }; + + auto ComputeAtFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + std::vector masters = FindMaster(node); + if (masters.size() == 0) { + return; + } + ir::Expr target_loop; + bool find_target_loop = false; + // Collect infomation of original loops + std::vector original_ctrl_stmts = node->ControlStmts(); + int64_t original_total_loop_extent = 1; + std::vector> original_loop_infos; + std::unordered_set original_loop_node_ptrs; + for (ir::Expr stmt : original_ctrl_stmts) { + if (stmt.As()) { + int extent = ir::GetLoopExtent(stmt); + original_total_loop_extent *= extent; + std::string thread_axis = ""; + ir::ForType target_for_type = stmt.As()->for_type(); + if (target_for_type == ir::ForType::GPUBlock) { + thread_axis += "blockIdx."; + } else if (target_for_type == ir::ForType::GPUThread) { + thread_axis += "threadIdx."; + } else { + original_loop_infos.push_back(std::make_pair(thread_axis, extent)); + continue; + } + int offset = stmt.As()->bind_info().offset; + thread_axis += ('x' + offset); + original_loop_infos.push_back(std::make_pair(thread_axis, extent)); + original_loop_node_ptrs.insert(stmt.ptr()); + } + } + + std::unordered_set src_reduce_loop_var_names = + GetReduceLoopVarNames(node->Block()); + for (ir::ScheduleBlockNode* master : masters) { + // Find the target loop candidates; + std::vector target_loop_candidates; + int64_t total_loop_extent = 1; + std::unordered_set tgt_reduce_loop_var_names = + GetReduceLoopVarNames(master->Block()); + std::vector> same_loops = + FindSameOuterLoops(node, master); + for (const std::tuple& same_loop : + same_loops) { + ir::Expr source_loop = std::get<0>(same_loop); + ir::Expr target_loop = std::get<1>(same_loop); + bool is_src_loop_reduce = + src_reduce_loop_var_names.count( + source_loop.As()->loop_var->name) > 0; + bool is_tgt_loop_reduce = + tgt_reduce_loop_var_names.count( + target_loop.As()->loop_var->name) > 0; + if (source_loop.ptr() != target_loop.ptr() && !is_src_loop_reduce && + !is_tgt_loop_reduce) { + target_loop_candidates.push_back(target_loop); + } + } + // Find the target loop with the highest priority and passing the + // feasibility condition check + for (std::vector::reverse_iterator iter = + target_loop_candidates.rbegin(); + iter != target_loop_candidates.rend(); + ++iter) { + ir::Expr candidate_loop = *iter; + if (candidate_loop.As() && + this->MeetConditions(node->Block(), candidate_loop, 0)) { + target_loop = candidate_loop; + find_target_loop = true; + break; + } + } + if (find_target_loop) { + VLOG(6) << "try to fuse loop of " << node->id() << " to " + << master->id(); + break; + } + } + + // Do schedule + if (find_target_loop) { + ir_sch_->SimpleComputeAt(node->Block(), target_loop); + VLOG(6) << "after compute at: " << ir_sch_->GetModule().GetExprs()[0]; + std::vector new_stmts = node->ControlStmts(); + for (int idx = 0; idx < original_loop_infos.size(); ++idx) { + if (original_loop_infos[idx].first.empty()) { + continue; + } + if (idx < new_stmts.size()) { + CHECK(new_stmts[idx].As()); + if (new_stmts[idx].As()->is_serial()) { + ir_sch_->Bind(new_stmts[idx], original_loop_infos[idx].first); + } + } else { + ir::Expr unit_loop = ir_sch_->AddUnitLoop(node->Block()); + ir_sch_->Bind(unit_loop, original_loop_infos[idx].first); + } + } + VLOG(6) << "after loop info copy: " << ir_sch_->GetModule().GetExprs()[0]; + // Update block and control stmts order after schedule. + this->UpdateBlockOrder(); + } else { + LOG(INFO) << "Cannot find a loop of masters to ComputeAt, do not merge.\n" + << "The schedule block: " << node->Block(); + } + }; + + schedule_block_graph_->DFSTopoWalk(ComputeAtFunc); + VLOG(5) << "[After DoVerticalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void StaticShapeGroupScheduler::BindCudaAxis() { + if (target_.arch != Target::Arch::NVGPU) return; + VLOG(5) << "[Start BindCudaAxis] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + auto_schedule::AutoBind binder(target_); + + auto BindFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + VLOG(6) << "try bind cuda axis on: " << node->id() + << ", before bind, func body: " + << ir_sch_->GetModule().GetExprs().front(); + binder.Apply(ir_sch_, node->id()); + VLOG(6) << "try bind cuda axis on: " << node->id() + << ", after bind, func body: " + << ir_sch_->GetModule().GetExprs().front(); + }; + + schedule_block_graph_->DFSTopoWalk(BindFunc); + + VLOG(5) << "[After BindCudaAxis] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +struct Range { + int min; + int max; +}; + +std::ostream& operator<<(std::ostream& os, const Range& x) { + os << "(" << x.min << ", " << x.max << ")"; + return os; +} + +// TODO(BiynXu): After implementing auxiliary data structures such as IntegerSet +// and MultiDimIntegerSet, re implement this function to simplify these ugly +// codes. +void StaticShapeGroupScheduler::AllocateStorage() { + if (target_.arch != Target::Arch::NVGPU) return; + VLOG(5) << "[Start AllocateStorage] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + // Record ir::For using index structure: > + std::unordered_map> + for_map; + std::unordered_set sync_mark; + + // function to update for_map + auto UpdateVarNameToForMap = [&](ir::Expr root) { + std::vector all_blocks = ir_sch_->GetAllBlocks(); + for (const ir::Expr& block : all_blocks) { + std::string block_name = block.As() + ->schedule_block.As() + ->name; + std::vector for_expr = ir_sch_->GetLoops(block); + for (ir::Expr for_expr : for_expr) { + for_map[block_name][for_expr.As()->loop_var->name] = for_expr; + VLOG(6) << "for_map.insert: <" << block_name << ", " + << for_expr.As()->loop_var->name << ">"; + } + } + }; + + // function to analyze and flatten indices to one dim of load_or_store node + auto AnalyzeIndiceValue = [](ir::Expr load_or_store, + ir::Expr block) -> ir::Expr { + std::vector indices; + ir::Tensor tensor; + if (load_or_store.As()) { + indices = load_or_store.As()->indices; + tensor = load_or_store.As()->tensor.as_tensor_ref(); + } else { + indices = load_or_store.As()->indices; + tensor = load_or_store.As()->tensor.as_tensor_ref(); + } + std::vector iter_vars = + block.As() + ->schedule_block.As() + ->iter_vars; + std::vector iter_values = + block.As()->iter_values; + struct VarHash { + size_t operator()(const ir::Var& var) const { + std::string name = var->name; + return std::hash()(name); + } + }; + std::vector strides; + int extent = 1; + for (int idx = tensor->shape.size() - 1; idx >= 0; --idx) { + strides.insert(strides.begin(), extent); + tensor->shape[idx] = common::AutoSimplify(tensor->shape[idx]); + CHECK(tensor->shape[idx].is_constant()) + << "Shape of tensor: " << tensor << " is not constant"; + extent *= tensor->shape[idx].get_constant(); + } + ir::Expr flatten_indice(0); + for (int idx = 0; idx < indices.size(); ++idx) { + flatten_indice = flatten_indice + ir::Expr(strides[idx]) * indices[idx]; + } + flatten_indice = common::AutoSimplify(flatten_indice); + for (int idx = 0; idx < iter_vars.size(); ++idx) { + optim::ReplaceVarWithExpr( + &flatten_indice, iter_vars[idx], iter_values[idx]); + } + flatten_indice = common::AutoSimplify(flatten_indice); + VLOG(6) << "flatten_indice of " << load_or_store << " : " << flatten_indice; + return flatten_indice; + }; + + enum class CudaBindInfo : int { + kCudaBlock, + kCudaThread, + kSerial, + kCudaThreadAndSerial, + }; + + // function to calculate the range of the specified CUDA axis in a indice + // expression + auto CalculateRange = [&for_map](ir::Expr indice_value, + const CudaBindInfo& bind_info, + const std::string& block_name) { + ir::Expr copy_for_upper_bound = ir::ir_utils::IRCopy(indice_value); + ir::Expr copy_for_lower_bound = ir::ir_utils::IRCopy(indice_value); + std::set var_set = ir::ir_utils::CollectIRNodesWithoutTensor( + indice_value, [](const ir::Expr* x) { return x->as_var(); }); + for (ir::Expr var : var_set) { + std::string name = var.as_var_ref()->name; + CHECK(for_map.find(block_name) != for_map.end()); + CHECK(for_map[block_name].find(name) != for_map[block_name].end()); + ir::Expr for_expr = for_map[block_name][name]; + if (bind_info == CudaBindInfo::kCudaBlock) { + if (for_expr.As()->is_gpu_block_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } else if (bind_info == CudaBindInfo::kCudaThread) { + if (for_expr.As()->is_gpu_thread_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } else if (bind_info == CudaBindInfo::kSerial) { + if (!for_expr.As()->is_gpu_thread_binded() && + !for_expr.As()->is_gpu_block_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } else if (bind_info == CudaBindInfo::kCudaThreadAndSerial) { + if (!for_expr.As()->is_gpu_block_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } + } + VLOG(6) << "lower_bound before simplify of " << indice_value << " = " + << copy_for_lower_bound; + copy_for_lower_bound = + common::AutoSimplify(common::AutoSimplify(copy_for_lower_bound)); + VLOG(6) << "upper_bound before simplify of " << indice_value << " = " + << copy_for_upper_bound; + copy_for_upper_bound = + common::AutoSimplify(common::AutoSimplify(copy_for_upper_bound)); + VLOG(6) << "lower_bound of " << indice_value << " = " + << copy_for_lower_bound; + VLOG(6) << "upper_bound of " << indice_value << " = " + << copy_for_upper_bound; + return Range{static_cast(copy_for_lower_bound.get_constant()), + static_cast(copy_for_upper_bound.get_constant())}; + }; + + // function to calculate the coefficient and range of the specified for_type + // in a indice expression + auto GetCoefficientAndRange = [&for_map](ir::Expr indice_value, + const ir::ForType& for_type, + const std::string& block_name) { + std::vector> coef_and_ranges(3); + std::vector indice_copies; + for (int i = 0; i < 3; ++i) { + indice_copies.push_back(ir::ir_utils::IRCopy(indice_value)); + } + std::set var_set = ir::ir_utils::CollectIRNodesWithoutTensor( + indice_value, [](const ir::Expr* x) { return x->as_var(); }); + std::unordered_set visited_var_names; + for (ir::Expr var : var_set) { + std::string name = var.as_var_ref()->name; + if (visited_var_names.count(name) > 0) { + continue; + } + visited_var_names.insert(name); + CHECK(for_map.find(block_name) != for_map.end()); + CHECK(for_map[block_name].find(name) != for_map[block_name].end()); + ir::Expr for_expr = for_map[block_name][name]; + for (int i = 0; i < 3; ++i) { + if (for_type == for_expr.As()->for_type() && + for_expr.As()->bind_info().offset == i && + for_expr.As()->extent.get_constant() > 1) { + optim::ReplaceVarWithExpr( + &(indice_copies[i]), var.as_var_ref(), ir::Expr(1)); + coef_and_ranges[i].second.min = + for_expr.As()->min.get_constant(); + coef_and_ranges[i].second.max = + for_expr.As()->min.get_constant() + + for_expr.As()->extent.get_constant(); + } else { + optim::ReplaceVarWithExpr( + &(indice_copies[i]), var.as_var_ref(), ir::Expr(0)); + } + } + } + for (int i = 0; i < 3; ++i) { + VLOG(6) << "before simplify [" << i << "], the coefficient of " + << indice_value << " = " << indice_copies[i] << ", range = (" + << coef_and_ranges[i].second.min << ", " + << coef_and_ranges[i].second.max << ")"; + indice_copies[i] = common::AutoSimplify(indice_copies[i]); + VLOG(6) << "after simplify [" << i << "], the coefficient of " + << indice_value << " = " << indice_copies << ", range = (" + << coef_and_ranges[i].second.min << ", " + << coef_and_ranges[i].second.max << ")"; + coef_and_ranges[i].first = + static_cast(indice_copies[i].get_constant()); + } + return coef_and_ranges; + }; + + // Determine whether the indice of a pair of Store and Load cross CUDA threads + auto IsCrossThread = [&](ir::Expr store_indice_value, + ir::Expr load_indice_value, + const std::string& store_block_name, + const std::string& load_block_name) { + Range store_thread_overall_range = CalculateRange( + store_indice_value, CudaBindInfo::kCudaThread, store_block_name); + Range load_thread_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kCudaThread, load_block_name); + Range store_serial_overall_range = CalculateRange( + store_indice_value, CudaBindInfo::kSerial, store_block_name); + Range load_serial_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kSerial, load_block_name); + auto store_thread_coefficient_and_range = GetCoefficientAndRange( + store_indice_value, ir::ForType::GPUThread, store_block_name); + auto load_thread_coefficient_and_range = GetCoefficientAndRange( + load_indice_value, ir::ForType::GPUThread, load_block_name); + VLOG(6) << "store_block_name: " << store_block_name + << ", load_block_name: " << load_block_name; + VLOG(6) << "store_indice_value: " << store_indice_value + << ", load_indice_value: " << load_indice_value; + VLOG(6) << "store_thread_overall_range = " << store_thread_overall_range; + VLOG(6) << "load_thread_overall_range = " << load_thread_overall_range; + VLOG(6) << "store_serial_overall_range = " << store_serial_overall_range; + VLOG(6) << "load_serial_overall_range = " << load_serial_overall_range; + VLOG(6) << "store_thread_coefficient_and_range[0] = <" + << store_thread_coefficient_and_range[0].first << ", " + << store_thread_coefficient_and_range[0].second << ">"; + VLOG(6) << "load_thread_coefficient_and_range[0] = <" + << load_thread_coefficient_and_range[0].first << ", " + << load_thread_coefficient_and_range[0].second << ">"; + VLOG(6) << "store_thread_coefficient_and_range[1] = <" + << store_thread_coefficient_and_range[1].first << ", " + << store_thread_coefficient_and_range[1].second << ">"; + VLOG(6) << "load_thread_coefficient_and_range[1] = <" + << load_thread_coefficient_and_range[1].first << ", " + << load_thread_coefficient_and_range[1].second << ">"; + VLOG(6) << "store_thread_coefficient_and_range[2] = <" + << store_thread_coefficient_and_range[2].first << ", " + << store_thread_coefficient_and_range[2].second << ">"; + VLOG(6) << "load_thread_coefficient_and_range[2] = <" + << load_thread_coefficient_and_range[2].first << ", " + << load_thread_coefficient_and_range[2].second << ">"; + return !(store_thread_overall_range.min <= load_thread_overall_range.min && + store_thread_overall_range.max >= load_thread_overall_range.max && + store_serial_overall_range.min <= load_serial_overall_range.min && + store_serial_overall_range.max >= load_serial_overall_range.max && + (store_thread_coefficient_and_range[0].first == + load_thread_coefficient_and_range[0].first || + load_thread_coefficient_and_range[0].first == 0) && + store_thread_coefficient_and_range[0].second.min <= + load_thread_coefficient_and_range[0].second.min && + store_thread_coefficient_and_range[0].second.max >= + load_thread_coefficient_and_range[0].second.max && + (store_thread_coefficient_and_range[1].first == + load_thread_coefficient_and_range[1].first || + load_thread_coefficient_and_range[1].first == 0) && + store_thread_coefficient_and_range[1].second.min <= + load_thread_coefficient_and_range[1].second.min && + store_thread_coefficient_and_range[1].second.max >= + load_thread_coefficient_and_range[1].second.max && + (store_thread_coefficient_and_range[2].first == + load_thread_coefficient_and_range[2].first || + load_thread_coefficient_and_range[2].first == 0) && + store_thread_coefficient_and_range[2].second.min <= + load_thread_coefficient_and_range[2].second.min && + store_thread_coefficient_and_range[2].second.max >= + load_thread_coefficient_and_range[2].second.max); + }; + + // Determine whether the indice of a pair of Store and Load cross CUDA block + auto IsCrossBlock = [&](ir::Expr store_indice_value, + ir::Expr load_indice_value, + const std::string& store_block_name, + const std::string& load_block_name) { + Range store_block_overall_range = CalculateRange( + store_indice_value, CudaBindInfo::kCudaBlock, store_block_name); + Range load_block_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kCudaBlock, load_block_name); + Range store_thread_and_serial_overall_range = + CalculateRange(store_indice_value, + CudaBindInfo::kCudaThreadAndSerial, + store_block_name); + Range load_thread_and_serial_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kCudaThreadAndSerial, load_block_name); + auto store_block_coefficient_and_range = GetCoefficientAndRange( + store_indice_value, ir::ForType::GPUBlock, store_block_name); + auto load_block_coefficient_and_range = GetCoefficientAndRange( + load_indice_value, ir::ForType::GPUBlock, load_block_name); + VLOG(6) << "store_block_name: " << store_block_name + << ", load_block_name: " << load_block_name; + VLOG(6) << "store_indice_value: " << store_indice_value + << ", load_indice_value: " << load_indice_value; + VLOG(6) << "store_block_overall_range = " << store_block_overall_range; + VLOG(6) << "load_block_overall_range = " << load_block_overall_range; + VLOG(6) << "store_thread_and_serial_overall_range = " + << store_thread_and_serial_overall_range; + VLOG(6) << "load_thread_and_serial_overall_range = " + << load_thread_and_serial_overall_range; + VLOG(6) << "store_block_coefficient_and_range[0] = <" + << store_block_coefficient_and_range[0].first << ", " + << store_block_coefficient_and_range[0].second << ">"; + VLOG(6) << "load_block_coefficient_and_range[0] = <" + << load_block_coefficient_and_range[0].first << ", " + << load_block_coefficient_and_range[0].second << ">"; + VLOG(6) << "store_block_coefficient_and_range[1] = <" + << store_block_coefficient_and_range[1].first << ", " + << store_block_coefficient_and_range[1].second << ">"; + VLOG(6) << "load_block_coefficient_and_range[1] = <" + << load_block_coefficient_and_range[1].first << ", " + << load_block_coefficient_and_range[1].second << ">"; + VLOG(6) << "store_block_coefficient_and_range[2] = <" + << store_block_coefficient_and_range[2].first << ", " + << store_block_coefficient_and_range[2].second << ">"; + VLOG(6) << "load_block_coefficient_and_range[2] = <" + << load_block_coefficient_and_range[2].first << ", " + << load_block_coefficient_and_range[2].second << ">"; + return !(store_block_overall_range.min <= load_block_overall_range.min && + store_block_overall_range.max >= load_block_overall_range.max && + store_thread_and_serial_overall_range.min <= + load_thread_and_serial_overall_range.min && + store_thread_and_serial_overall_range.max >= + load_thread_and_serial_overall_range.max && + (store_block_coefficient_and_range[0].first == + load_block_coefficient_and_range[0].first || + load_block_coefficient_and_range[0].first == 0) && + store_block_coefficient_and_range[0].second.min <= + load_block_coefficient_and_range[0].second.min && + store_block_coefficient_and_range[0].second.max >= + load_block_coefficient_and_range[0].second.max && + (store_block_coefficient_and_range[1].first == + load_block_coefficient_and_range[1].first || + load_block_coefficient_and_range[1].first == 0) && + store_block_coefficient_and_range[1].second.min <= + load_block_coefficient_and_range[1].second.min && + store_block_coefficient_and_range[1].second.max >= + load_block_coefficient_and_range[1].second.max && + (store_block_coefficient_and_range[2].first == + load_block_coefficient_and_range[2].first || + load_block_coefficient_and_range[2].first == 0) && + store_block_coefficient_and_range[2].second.min <= + load_block_coefficient_and_range[2].second.min && + store_block_coefficient_and_range[2].second.max >= + load_block_coefficient_and_range[2].second.max); + }; + + // function to set storage of each tensor + auto SetStorage = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + ir::MemoryType memory_type = ir::MemoryType::GPULocal; + ir::Expr cur_block = node->Block(); + ir::Expr root_block = ir_sch_->GetRootBlock(cur_block); + UpdateVarNameToForMap(root_block); + std::vector consumer_blocks = + ir::GetConsumers(cur_block, root_block); + // find store and corresponding load nodes + ir::Expr find_store = + *ir::ir_utils::CollectIRNodesWithoutTensor( + cur_block, + [&](const ir::Expr* x) { return x->As(); }, + true) + .begin(); + ir::Expr store_indice_value = AnalyzeIndiceValue(find_store, cur_block); + std::vector> loads_and_blocks; + for (const ir::Expr& consumer_block : consumer_blocks) { + ir::ir_utils::CollectIRNodesWithoutTensor( + consumer_block, [&](const Expr* x) { + if (x->As() && (x->As()->name() == + find_store.As()->name())) { + loads_and_blocks.push_back(std::make_tuple(*x, consumer_block)); + } + return false; + }); + } + // Traverse load nodes to check if there are loads that cross cuda blocks or + // threads + for (const auto& load_and_block : loads_and_blocks) { + ir::Expr load = std::get<0>(load_and_block); + ir::Expr consumer_block = std::get<1>(load_and_block); + std::string consumer_block_name = + consumer_block.As() + ->schedule_block.As() + ->name; + ir::Expr load_indice_value = AnalyzeIndiceValue(load, consumer_block); + if (IsCrossBlock(store_indice_value, + load_indice_value, + node->id(), + consumer_block_name)) { + // TODO(BiynXu): Return error information to the front-end instead of + // terminating the program. + LOG(FATAL) << "Fusion requires synchronization across blocks, but " + "currently we do not support it."; + break; + } else if (IsCrossThread(store_indice_value, + load_indice_value, + node->id(), + consumer_block_name)) { + memory_type = ir::MemoryType::GPUShared; + } + } + // Set output node to global + std::unordered_set output_names = OutputTensorNames(); + if (output_names.count(node->id()) > 0) { + memory_type = ir::MemoryType::Auto; + } + // Set the reduce_init tensor and the real tensor to the same memory + if (ir::IsReduceInitTensorName(node->id())) { + ir::Expr block = + ir_sch_->GetBlock(ir::GetOriginalReduceTensorName(node->id())); + memory_type = ir::GetTensor(block)->buffer->memory_type; + } + // Do schedule + if (memory_type == ir::MemoryType::Auto) { + VLOG(6) << "Set store tensor of block " << node->id() << " to global"; + } else if (memory_type == ir::MemoryType::GPUShared) { + VLOG(6) << "Set store tensor of block " << node->id() << " to shared"; + ir_sch_->SetBuffer(cur_block, "shared"); + std::vector loops = ir_sch_->GetLoops(cur_block); + if (sync_mark.count(ir::GetOriginalReduceTensorName(node->id())) == 0) { + ir_sch_->SyncThreads(loops.back(), true); + sync_mark.insert(ir::GetOriginalReduceTensorName(node->id())); + } + } else if (memory_type == ir::MemoryType::GPULocal) { + VLOG(6) << "Set store tensor of block " << node->id() << " to register"; + ir_sch_->SetBuffer(cur_block, "local"); + } + }; + schedule_block_graph_->DFSTopoWalk(SetStorage); + VLOG(5) << "[After AllocateStorage] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void StaticShapeGroupScheduler::OptimizeReduction() { + VLOG(5) << "[Start OptimizeReduction] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + auto_schedule::ReductionFactoring rf(target_); + + auto ReductionFactoring = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + VLOG(6) << "try ReductionFactoring on: " << node->id() + << ", before ReductionFactoring, func body: " + << ir_sch_->GetModule().GetExprs().front(); + rf.Apply(node->id(), ir_sch_); + VLOG(6) << "try ReductionFactoring on: " << node->id() + << ", after ReductionFactoring, func body: " + << ir_sch_->GetModule().GetExprs().front(); + }; + + schedule_block_graph_->DFSTopoWalk(ReductionFactoring); + schedule_block_graph_->Update(*ir_sch_); + + VLOG(5) << "[After OptimizeReduction] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void StaticShapeGroupScheduler::UpdateBlockOrder() { + ir::Expr root_block = ir_sch_->GetRootBlock(ir_sch_->GetAllBlocks()[0]); + ir::BlockOrderConstructor block_order_constructor; + blocks_order_with_ctrl_stmt_ = block_order_constructor(&root_block); +} + +bool StaticShapeGroupScheduler::IsKeepGraphDependency(Expr schedule_block, + Expr target_loop, + int insert_pos) const { + // Assuming inserting the schedule_block into the target_loop, + // obtain the transformed upstream and downstream blocks. + std::unordered_set blocks_above; + std::unordered_set blocks_below; + bool is_below = false; + bool find_target_loop = false; + int pos_count = -1; + std::map, ir::Expr>::const_iterator iter; + for (iter = blocks_order_with_ctrl_stmt_.begin(); + iter != blocks_order_with_ctrl_stmt_.end(); + ++iter) { + if (iter->second.get() == schedule_block.get()) { + continue; + } + if (iter->second.get() == target_loop.get()) { + find_target_loop = true; + } + if (find_target_loop) { + ++pos_count; + } + if (pos_count == insert_pos) { + is_below = true; + } + if (iter->second.As()) { + std::string block_id = iter->second.As() + ->schedule_block.As() + ->name; + if (is_below) { + blocks_below.insert(block_id); + } else { + blocks_above.insert(block_id); + } + } + } + + // Obtain real upstream and downstream nodes + std::string src_id = schedule_block.As() + ->schedule_block.As() + ->name; + ir::ScheduleBlockNode* node = schedule_block_graph_->RetrieveNode(src_id); + std::unordered_set upstream_ids = node->UpstreamNodes(); + std::unordered_set downstream_ids = node->DownstreamNodes(); + + // Check that the transformed upstream and downstream blocks + // still meet the relationship between the + // original upstream and downstream nodes. + for (const std::string& id : upstream_ids) { + if (blocks_above.count(id) == 0) { + VLOG(6) << "[Breaking Graph Level Dependency] ScheduleBlock: " << src_id + << " cannot be insert into target loop at insert_pos: " + << insert_pos << " because its upstream block: " << id + << " will appear downstream."; + VLOG(6) << "The target loop:\n" << target_loop; + return false; + } + } + for (const std::string& id : downstream_ids) { + if (blocks_below.count(id) == 0) { + VLOG(6) << "[Breaking Graph Level Dependency] ScheduleBlock: " << src_id + << " cannot be insert into target loop at insert_pos: " + << insert_pos << " because its downstream block: " << id + << " will appear upstream."; + VLOG(6) << "The target loop:\n" << target_loop; + return false; + } + } + VLOG(6) << "[Meet Graph Level Dependency] ScheduleBlock: " << src_id + << " can be insert into target loop at insert_pos: " << insert_pos; + VLOG(6) << "The target loop:\n" << target_loop; + return true; +} + +bool StaticShapeGroupScheduler::MeetConditions(Expr schedule_block, + Expr target_loop, + int insert_pos) const { + for (const auto& condition_func : feasible_conditions_) { + if (!(this->*condition_func)(schedule_block, target_loop, insert_pos)) { + return false; + } + } + return true; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h new file mode 100644 index 0000000000000..b2b89c392bdc0 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h @@ -0,0 +1,163 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" + +namespace cinn { +namespace ir { + +// The priority of the ScheduleBlockNode, +// prioritizing whether it has been bound to the cuda axis, +// and secondly considering the amount of calculated data. +struct NodePriority { + bool has_loop_binded; + double score; + + bool operator<(const NodePriority& other) const { + if (has_loop_binded ^ other.has_loop_binded) { + return !has_loop_binded; + } else { + return score < other.score; + } + } +}; + +/** + * The class used for scheduling fusion groups with static shape. + * Its responsibility is to perform loop alignment, + * automatic inline, automatic loop fusion, + * and optimize the storage location of intermediate variables. + * Note: Currently only CUDA backend is supported. + */ +class StaticShapeGroupScheduler : public GroupScheduler { + public: + StaticShapeGroupScheduler( + ir::IRSchedule* ir_sch, + const std::unordered_set& output_tensor_names, + const common::Target& target) + : GroupScheduler(ir_sch, output_tensor_names, target) {} + + void Schedule() override; + + std::vector> GetIRs() override; + + private: + // Automatically align loops for each ScheduleBlock. + void DoLoopAlignment(); + + // Automatically inline some ScheduleBlock which meets the conditions. + void DoComputeInline(); + + // Make every effort to automatically merge the loops of the horizontal + // relationship ScheduleBlockNode. + void DoHorizontalLoopFusion(); + + // Make every effort to automatically merge the loops of the vertical + // relationship ScheduleBlockNode. + void DoVerticalLoopFusion(); + + // Automatically bind cuda axis on loops. + void BindCudaAxis(); + + // Automatically allocate storage locations for variables to optimize IO. + void AllocateStorage(); + + // Automatically optimize the reductive calculation + void OptimizeReduction(); + + // Evaluate the priority of ScheduleBlockNode. + // The node where the performance bottleneck is located + // has a higher priority, while the node with a lower priority + // needs to compromise and align loops with the node with the highest + // priority. + NodePriority CalculateNodePriority(const ir::ScheduleBlockNode* node) const; + + // Find the highest priority ScheduleBlockNode, + // other nodes need to align the loop with it. + ir::ScheduleBlockNode* FindGlobalMasterNode() const; + + // Obtain the latest order of ScheduleBlock and the control structures + // throughout the entire IR. + void UpdateBlockOrder(); + + // Get output tensor names of group. + std::unordered_set OutputTensorNames() const; + + /** + * @brief Determine whether the graph level dependency is still maintained + * after the schedule_block is placed in the insert position of target_loop. + * @param schedule_block The src schedule_block to be replaced. + * @param target_loop The target loop to be insert into the schedule_block. + * @param insert_pos The insert position of new schedule_block in the + * target_loop. + */ + bool IsKeepGraphDependency(Expr schedule_block, + Expr target_loop, + int insert_pos) const; + + /** + * @brief Determine whether all feasible conditions are met + * after the schedule_block is placed in the insert position of target_loop. + * @param schedule_block The src schedule_block to be replaced. + * @param target_loop The target loop to be insert into the schedule_block. + * @param insert_pos The insert position of new schedule_block in the + * target_loop. + */ + bool MeetConditions(Expr schedule_block, + Expr target_loop, + int insert_pos) const; + + private: + /** + * @brief Interface of feasibility condition. + * @param schedule_block The src schedule_block to be replaced. + * @param target_loop The target loop to be insert into the schedule_block. + * @param insert_pos The insert position of new schedule_block in the + * target_loop. + */ + using FeasibleCondition = bool (StaticShapeGroupScheduler::*)( + Expr schedule_block, Expr target_loop, int insert_pos) const; + // All feasible conditions. + std::vector feasible_conditions_; + + /** + * The order of blocks and their control statements, + * only For, IfThenElse and ScheduleBlock is considered. + * + * Example: + * for0: + * for1: + * block0 + * block1 + * block2 + * for2: + * block3 + * block4 + * + * the result is: + * [0]: for0 + * [0, 0]: for1 + * [0, 0, 0]: block0 + * [0, 0, 1]: block1 + * [0, 1]: block2 + * [0, 2]: for2 + * [0, 2, 0]: block3 + * [0, 2, 1]: block4 + */ + std::map, ir::Expr> blocks_order_with_ctrl_stmt_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/schedule/factorize_reduction.h b/paddle/cinn/ir/schedule/factorize_reduction.h index 0973d123fd40c..4075feb93599e 100644 --- a/paddle/cinn/ir/schedule/factorize_reduction.h +++ b/paddle/cinn/ir/schedule/factorize_reduction.h @@ -33,7 +33,7 @@ namespace ir { Tensor CreateRFTensor(const Tensor& original_tensor, const Expr& rf_loop, int rf_axis) { - std::string name = original_tensor->name + "_rf"; + std::string name = common::UniqName(original_tensor->name + "_rf"); std::vector new_shape = original_tensor->shape; new_shape.insert(new_shape.begin() + rf_axis, rf_loop.As()->extent); Tensor rf_tensor = _Tensor_::Make(name, @@ -80,19 +80,23 @@ class ReduceBlockCreater { ->schedule_block.As() ->name; if (is_rf_block_) { - new_update_block_name += "_rf"; + new_update_block_name = rf_tensor_->name; } std::string new_init_block_name = ir::GenReduceInitTensorNameOf(new_update_block_name); VLOG(5) << "new_init_block_name = " << new_init_block_name; - Expr init_value = rf_tensor_->GetReduceInitVal(); - const std::vector& domain = rf_tensor_->domain_without_reduce_axis(); + const ir::Tensor& real_tensor = + is_rf_block_ + ? rf_tensor_ + : original_update_stmt_.As()->tensor.as_tensor_ref(); + Expr init_value = real_tensor->GetReduceInitVal(); + const std::vector& domain = real_tensor->domain_without_reduce_axis(); ir::Tensor init_tensor = lang::Compute( domain, [=](const std::vector& axis) { return init_value; }, new_init_block_name); - init_tensor->Bind(rf_tensor_->buffer); + init_tensor->Bind(real_tensor->buffer); Expr init_stmt = ir::Store::Make( init_tensor, init_value, new_update_stmt_.As()->indices); new_init_sch_block_ = ScheduleBlock::Make( @@ -299,6 +303,12 @@ class RFBlockCreater : public ReduceBlockCreater { REPLACE_RF_TENSOR(Mul) REPLACE_RF_TENSOR(Max) REPLACE_RF_TENSOR(Min) + REPLACE_RF_TENSOR(And) + REPLACE_RF_TENSOR(Or) + REPLACE_RF_TENSOR(LT) + REPLACE_RF_TENSOR(LE) + REPLACE_RF_TENSOR(GT) + REPLACE_RF_TENSOR(GE) #undef REPLACE_RF_TENSOR new_update_stmt_ = @@ -388,6 +398,12 @@ class RBBlockCreater : public ReduceBlockCreater { REPLACE_RF_TENSOR(Mul) REPLACE_RF_TENSOR(Max) REPLACE_RF_TENSOR(Min) + REPLACE_RF_TENSOR(And) + REPLACE_RF_TENSOR(Or) + REPLACE_RF_TENSOR(LT) + REPLACE_RF_TENSOR(LE) + REPLACE_RF_TENSOR(GT) + REPLACE_RF_TENSOR(GE) #undef REPLACE_RF_TENSOR Expr original_store_tensor = original_update_stmt_.As()->tensor; diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index 24f97b6e03d1e..2baebcbacc61b 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -2109,7 +2109,7 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, CHECK_EQ(iter.as_var_ref()->name, loop_vars[idx]->name) << "loops is not the same order with tensor!"; } else { - CHECK(iter.As()); + CHECK(iter.As()) << iter.node_type() << " is not IntImm"; CHECK_EQ(iter.as_int32(), 0); } } @@ -2640,6 +2640,13 @@ void IRSchedule::SetBuffer(Expr& block, {})); } +Expr IRSchedule::AddUnitLoop(const Expr& block) { + Expr ret = impl_->AddUnitLoop(block); + trace_.Append(ScheduleDesc::Step( + "AddUnitLoop", {{"block", std::vector({block})}}, {}, {ret})); + return ret; +} + Expr IRSchedule::Reorder(const std::vector& loops) { Expr ret = impl_->Reorder(loops); trace_.Append(ScheduleDesc::Step("Reorder", {{"loops", loops}}, {}, {ret})); diff --git a/paddle/cinn/ir/schedule/ir_schedule.h b/paddle/cinn/ir/schedule/ir_schedule.h index 4c5fc1d10f1b6..b33afd03a799a 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.h +++ b/paddle/cinn/ir/schedule/ir_schedule.h @@ -244,7 +244,7 @@ class IRSchedule { */ void SyncThreads(const Expr& ir_node, bool after_node = true); - /*! + /** * \brief Set a tensor's buffer type(memory_type) * \param block The ScheduleBlockRealize corresponding to an unique tensor. * \param memory_type The memory type we want to set. Should be "local", @@ -254,6 +254,13 @@ class IRSchedule { const std::string& memory_type, bool fixed = false); // NOLINT + /** + * \brief Create a new unit loop on top of the block. + * @param block The block to be added the new loop. + * @return The new unit loop. + */ + Expr AddUnitLoop(const Expr& block); + /** * \brief Reorder the loops in the order of vector. * @param loops The loops to be reordered. diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index 7a2daa3106612..db378eba74194 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -367,8 +367,16 @@ IterRange GetAccessedRange(const Expr& index, Expr indice_extent; Expr mod_extent(0); - if (indice_min.As() && indice_min.As()->b().is_constant()) + if (indice_min.As() && indice_min.As()->b().is_constant()) { + Expr mod_right_min = indice_min.As()->a(); + Expr mod_right_max = indice_max.As()->a(); + Expr mod_right_extent = + common::AutoSimplify(mod_right_max - mod_right_min + 1); mod_extent = indice_min.As()->b(); + if (mod_right_extent.get_constant() < mod_extent.get_constant()) { + mod_extent = mod_right_extent; + } + } if (indice_min == indice_max) { if (common::is_zero(mod_extent)) { @@ -875,7 +883,7 @@ std::vector GetProducers(const Expr& block, const Expr& root) { ->name; ir::ir_utils::CollectIRNodesWithoutTensor( compute_body, [&producer_tensor_names, &block_name](const Expr* x) { - auto* load = x->As(); + const ir::Load* load = x->As(); if (load) { producer_tensor_names.insert(load->tensor.as_tensor()->name); if (load->tensor.as_tensor()->name == block_name) { @@ -884,6 +892,22 @@ std::vector GetProducers(const Expr& block, const Expr& root) { } return true; } + const ir::Store* store = x->As(); + if (store) { + std::set call_nodes = + ir::ir_utils::CollectIRNodesWithoutTensor( + store->value, + [](const ir::Expr* x) { return x->As(); }); + for (ir::Expr call : call_nodes) { + const std::vector& read_args = + call.As()->read_args; + for (const ir::Expr& arg : read_args) { + if (arg.as_tensor()) { + producer_tensor_names.insert(arg.as_tensor_ref()->name); + } + } + } + } return false; }); @@ -936,13 +960,23 @@ std::vector GetConsumers(const Expr& block, const Expr& root) { auto block_body = i.As() ->schedule_block.As() ->body; - auto find_load = ir::ir_utils::CollectIRNodesWithoutTensor( + auto find_load_or_call = ir::ir_utils::CollectIRNodesWithoutTensor( block_body, [&](const Expr* x) { + if (x->As()) { + const std::vector& read_args = + x->As()->read_args; + for (const ir::Expr& arg : read_args) { + if (arg.as_tensor() && + arg.as_tensor_ref()->name == block_tensor) { + return true; + } + } + } return x->As() && x->As()->tensor.as_tensor_ref()->name == block_tensor; }); - if (!find_load.empty()) consumers.emplace_back(i); + if (!find_load_or_call.empty()) consumers.emplace_back(i); } return consumers; } diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.h b/paddle/cinn/ir/schedule/ir_schedule_util.h index 9c9418b4d577e..edd202c6093d6 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.h +++ b/paddle/cinn/ir/schedule/ir_schedule_util.h @@ -436,9 +436,11 @@ IterRange RangeUnion(const IterRange& range1, const IterRange& range2); * \param loop The loop where we will insert the block under it * @param root The root of the whole AST. * \param required_blocks vector of ScheduleBlockRealize nodes that require the - * block \param is_store_provided Whether Store nodes of the block provide the + * block + * \param is_store_provided Whether Store nodes of the block provide the * tensor, true means it is in compute_at case, otherwise false means in - * reverse_compuate_at case \return Each index's range of block's tensor. + * reverse_compuate_at case + * \return Each index's range and can_keep_loop flag of block's tensor. * Indicating the buffer region being required. */ std::vector CalculateRequiredRegions( diff --git a/paddle/cinn/ir/schedule/schedule_desc.cc b/paddle/cinn/ir/schedule/schedule_desc.cc index e0d5f4ab21701..c9a26dfa1643d 100644 --- a/paddle/cinn/ir/schedule/schedule_desc.cc +++ b/paddle/cinn/ir/schedule/schedule_desc.cc @@ -422,6 +422,12 @@ CINN_BUILD_STEP_KIND(SetBuffer) .SetApplyFn( APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SetBuffer))); +CINN_BUILD_STEP_KIND(AddUnitLoop) + .Inputs({"block"}) + .SetApplyFn(APPLY_FUNC_UNIFORM( + FREE_FUNCTION_CONVERTER(static_cast( + &IRSchedule::AddUnitLoop)))); + CINN_BUILD_STEP_KIND(Reorder).Inputs({"loops"}).SetApplyFn( APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>( diff --git a/paddle/cinn/ir/test/CMakeLists.txt b/paddle/cinn/ir/test/CMakeLists.txt index e503f5ebfd964..f51ac68d43639 100644 --- a/paddle/cinn/ir/test/CMakeLists.txt +++ b/paddle/cinn/ir/test/CMakeLists.txt @@ -19,3 +19,9 @@ cinn_cc_test(test_ir_compare SRCS ir_compare_test.cc DEPS cinncore) cinn_cc_test(test_ir_copy SRCS ir_copy_test.cc DEPS cinncore) cinn_cc_test(test_schedule_block_graph SRCS schedule_block_graph_test.cc DEPS cinncore) + +if(WITH_CUDA) + cinn_cc_test( + test_static_shape_group_scheduler SRCS st_shape_group_scheduler_test.cc + DEPS cinncore decomposer_test_helper) +endif() diff --git a/paddle/cinn/ir/test/collect_ir_nodes_test.cc b/paddle/cinn/ir/test/collect_ir_nodes_test.cc index d380b4475e37d..859a35a5c0fa9 100644 --- a/paddle/cinn/ir/test/collect_ir_nodes_test.cc +++ b/paddle/cinn/ir/test/collect_ir_nodes_test.cc @@ -42,15 +42,15 @@ TEST(CollectIRNodes, basic) { auto C = Compute( {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); - auto stages = CreateStages({C}); + ast_gen_ius::TensorGroup tensor_group({C}); - auto fn = Lower("fn", stages, {A, B, C}); + auto fn = LowerToAst("fn", {A, B, C}, &tensor_group); LOG(INFO) << "fn:\n" << fn; auto tensors = CollectIRNodes(fn, [](const Expr* x) { return x->as_tensor(); }); - ASSERT_EQ(tensors.size(), 5UL); + ASSERT_EQ(tensors.size(), 3UL); auto fn_body = fn.As()->body; LOG(INFO) << "fn.body:\n" << fn_body; diff --git a/paddle/cinn/ir/test/schedule_block_graph_test.cc b/paddle/cinn/ir/test/schedule_block_graph_test.cc index 20c7f03b4d235..78c809dc117d4 100644 --- a/paddle/cinn/ir/test/schedule_block_graph_test.cc +++ b/paddle/cinn/ir/test/schedule_block_graph_test.cc @@ -20,6 +20,8 @@ #include "paddle/cinn/hlir/framework/op_lowering.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace ir { @@ -95,6 +97,7 @@ frontend::Program CreateReduceProgram() { } TEST(ScheduleBlockGraph, elementwise) { + Context::Global().ResetNameId(); frontend::Program program = CreateElementwiseProgram(); IRSchedule ir_sch = MakeIRSchedule(&program); LOG(INFO) << GetIR(ir_sch); @@ -136,23 +139,72 @@ TEST(ScheduleBlockGraph, elementwise) { #ifdef CINN_WITH_CUDA TEST(ScheduleBlockGraph, reduce) { - frontend::Program program = CreateReduceProgram(); + if (FLAGS_cinn_new_group_scheduler) { + Context::Global().ResetNameId(); + frontend::Program program = CreateReduceProgram(); + IRSchedule ir_sch = MakeIRSchedule(&program); + ScheduleBlockGraph sbg(ir_sch); + LOG(INFO) << GetIR(ir_sch); + LOG(INFO) << sbg.Visualize(); + CHECK_EQ(sbg.BlockIdsInOrder().size(), 5); + CHECK_EQ(sbg.nodes().size(), 5); + + ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_2__reduce_init"); + CHECK(v_reduce_init); + CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0); + CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3); + + ScheduleBlockNode* v = sbg.RetrieveNode("var_2"); + CHECK(v); + CHECK_EQ(v->UpstreamNodes().size(), 2); + CHECK_EQ(v->DownstreamNodes().size(), 2); + + std::vector reverse_dfs_topo_order_ids; + sbg.DFSTopoWalk( + [&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) { + reverse_dfs_topo_order_ids.push_back(node->id()); + }); + for (const std::string& id : reverse_dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(reverse_dfs_topo_order_ids.size(), 5); + + std::vector dfs_topo_order_ids; + sbg.DFSTopoWalk( + [&dfs_topo_order_ids](const ScheduleBlockNode* node) { + dfs_topo_order_ids.push_back(node->id()); + }, + false); + for (const std::string& id : dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(dfs_topo_order_ids.size(), 5); + } +} + +TEST(ScheduleBlockGraph, arg_max) { + Context::Global().ResetNameId(); + frontend::NetBuilder builder("net_builder"); + auto x = builder.CreateInput(Float(32), {8, 16}, "X"); + auto y = builder.Argmax(x, 0); + frontend::Program program = builder.Build(); + IRSchedule ir_sch = MakeIRSchedule(&program); LOG(INFO) << GetIR(ir_sch); ScheduleBlockGraph sbg(ir_sch); LOG(INFO) << sbg.Visualize(); - CHECK_EQ(sbg.BlockIdsInOrder().size(), 8); - CHECK_EQ(sbg.nodes().size(), 8); + CHECK_EQ(sbg.BlockIdsInOrder().size(), 3); + CHECK_EQ(sbg.nodes().size(), 3); - ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_48__reduce_init"); - CHECK(v_reduce_init); - CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0); - CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3); + ScheduleBlockNode* v0_idx = sbg.RetrieveNode("var_0_index"); + CHECK(v0_idx); + CHECK_EQ(v0_idx->UpstreamNodes().size(), 1); + CHECK_EQ(v0_idx->DownstreamNodes().size(), 1); - ScheduleBlockNode* v = sbg.RetrieveNode("var_48"); - CHECK(v); - CHECK_EQ(v->UpstreamNodes().size(), 5); - CHECK_EQ(v->DownstreamNodes().size(), 2); + ScheduleBlockNode* v0 = sbg.RetrieveNode("var_0"); + CHECK(v0); + CHECK_EQ(v0->UpstreamNodes().size(), 2); + CHECK_EQ(v0->DownstreamNodes().size(), 0); std::vector reverse_dfs_topo_order_ids; sbg.DFSTopoWalk([&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) { @@ -161,7 +213,7 @@ TEST(ScheduleBlockGraph, reduce) { for (const std::string& id : reverse_dfs_topo_order_ids) { LOG(INFO) << id; } - CHECK_EQ(reverse_dfs_topo_order_ids.size(), 8); + CHECK_EQ(reverse_dfs_topo_order_ids.size(), 3); std::vector dfs_topo_order_ids; sbg.DFSTopoWalk( @@ -172,7 +224,7 @@ TEST(ScheduleBlockGraph, reduce) { for (const std::string& id : dfs_topo_order_ids) { LOG(INFO) << id; } - CHECK_EQ(dfs_topo_order_ids.size(), 8); + CHECK_EQ(dfs_topo_order_ids.size(), 3); } #endif diff --git a/paddle/cinn/ir/test/st_shape_group_scheduler_test.cc b/paddle/cinn/ir/test/st_shape_group_scheduler_test.cc new file mode 100644 index 0000000000000..22f64849a8f7a --- /dev/null +++ b/paddle/cinn/ir/test/st_shape_group_scheduler_test.cc @@ -0,0 +1,767 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" + +#include + +#include "paddle/cinn/common/target.h" +#include "paddle/cinn/frontend/decomposer/test_helper.h" +#include "paddle/cinn/hlir/framework/op_lowering.h" + +PD_DECLARE_bool(cinn_new_group_scheduler); + +namespace cinn { +namespace ir { + +using frontend::NetBuilder; +using frontend::RunDecomposer; + +void Compile(NetBuilder* net_builder) { + auto program = net_builder->Build(); + auto target = common::DefaultTarget(); + RunDecomposer(&program, target); + + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); + hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); + CHECK_EQ(graph->fusion_groups.size(), 1); + + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = graph->GetMutableAttrs< + absl::flat_hash_map>("infershape"); + + auto op_lowerer = + hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target); + for (auto& fusion_group : graph->fusion_groups) { + std::vector lowered_funcs = + op_lowerer.Lower(fusion_group, + /* apply_op_schedule = */ true, + /* apply_group_schedule = */ false); + CHECK_EQ(lowered_funcs.size(), 1); + VLOG(1) << "without group schedule, lowered_func: " + << lowered_funcs.front(); + + FLAGS_cinn_new_group_scheduler = true; + lowered_funcs = op_lowerer.Lower(fusion_group, + /* apply_op_schedule = */ true, + /* apply_group_schedule = */ true); + CHECK_EQ(lowered_funcs.size(), 1); + VLOG(1) << "after group schedule, lowered_func: " << lowered_funcs.front(); + } +} + +void CheckAccuracy(NetBuilder* net_builder, + const std::vector& input_names) { + FLAGS_cinn_new_group_scheduler = true; + auto program = net_builder->Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + + auto scope = BuildScope(target, graph); + hlir::framework::CompilationContext context(graph, scope, target); + hlir::framework::GraphCompiler gc(context); + + for (size_t i = 0; i < input_names.size(); ++i) { + scope->Var(input_names[i]); + auto tensor = scope->GetTensor(input_names[i]); + + std::vector vec; + frontend::InitRandomVector( + &vec, tensor->shape().numel(), 0.0f, 1.0f); + frontend::CopyFromVector(vec, tensor, target); + } + + auto runtime_program = gc.Build(); + runtime_program->Execute(); +} + +// Each unittest below tests a single reduce, +// these unittests are only used to observe the generated IR and debug. +// Accuracy testing is guaranteed by Python unittests named +// test_reduce_op_xxx.py. +TEST(GROUP_SCHEDULER, last_reduce_only_1) { + NetBuilder net_builder("last_reduce_only_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128, 64, 32}, "A"); + auto B = net_builder.ReduceSum(A, {2}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, last_reduce_only_2) { + NetBuilder net_builder("last_reduce_only_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {1024}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, last_reduce_only_3) { + NetBuilder net_builder("last_reduce_only_3"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {512, 256}, "A"); + auto B = net_builder.ReduceSum(A, {1}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, non_last_reduce_only_1) { + NetBuilder net_builder("non_last_reduce_only_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {10, 10, 10}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}, /* keep_dim = */ true); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, non_last_reduce_only_2) { + NetBuilder net_builder("non_last_reduce_only_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {64, 32, 16, 8, 4}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2, 3}, /* keep_dim = */ true); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, shuffle_reduce_only_1) { + NetBuilder net_builder("shuffle_reduce_only_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {32, 32, 32, 32}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2, 3}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, shuffle_reduce_only_2) { + NetBuilder net_builder("shuffle_reduce_only_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {32, 64, 56, 56}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2, 3}); + }; + + CreateModel(); + Compile(&net_builder); +} + +// Each of the following unittest tests a basic pattern composed of multiple +// basic op. And apply accuracy checks to ensure that the results of fusion +// groups and independently running each op are consistent. +TEST(GROUP_SCHEDULER, elementwise_1) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_1"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_2) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_2"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Cast(C, "float16"); + auto E = net_builder.Cast(C, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_3) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_3"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Cast(C, "float16"); + auto E = net_builder.Cast(C, "float16"); + auto F = net_builder.Cast(D, "float32"); + auto G = net_builder.Cast(E, "float32"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_4) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_4"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Cast(C, "float16"); + auto E = net_builder.Cast(C, "float16"); + auto F = net_builder.Add(D, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_broadcast) { + NetBuilder net_builder("elementwise_broadcast"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128}, "A"); + auto B = net_builder.CreateInput(Float(32), {128}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.BroadcastTo(C, {128, 128}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_broadcast) { + NetBuilder net_builder("elementwise_double_broadcast"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128}, "A"); + auto B = net_builder.CreateInput(Float(32), {128}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.BroadcastTo(C, {128, 128}); + auto E = net_builder.BroadcastTo(C, {128, 128}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, non_last_reduce_elementwise_1) { + int h = 128, w = 128; + NetBuilder net_builder("non_last_reduce_elementwise_1"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.Cast(B, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, last_reduce_elementwise) { + NetBuilder net_builder("last_reduce_elementwise"); + std::vector input_names = {"A", "C"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128, 64}, "A"); + auto B = net_builder.ReduceSum(A, {1}); + auto C = net_builder.CreateInput(Float(32), {128}, "C"); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_1) { + NetBuilder net_builder("keep_dim_reduce_elementwise"); + std::vector input_names = {"A", "C"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 112, 112}, "A"); + auto B = net_builder.CreateInput(Float(32), {1, 64, 1, 1}, "B"); + auto C = net_builder.ReduceSum(A, {0, 2, 3}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_2) { + NetBuilder net_builder("keep_dim_reduce_elementwise_2"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 112, 112}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 64, 1, 1}, "B"); + auto C = net_builder.ReduceSum(A, {2, 3}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_3) { + NetBuilder net_builder("keep_dim_reduce_elementwise_3"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 2048}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 64, 1}, "B"); + auto C = net_builder.ReduceSum(A, {2}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_4) { + NetBuilder net_builder("keep_dim_reduce_elementwise_4"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 2048}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 1, 2048}, "B"); + auto C = net_builder.ReduceSum(A, {1}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_5) { + NetBuilder net_builder("keep_dim_reduce_elementwise_5"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 16, 1024}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 1, 16, 1}, "B"); + auto C = net_builder.ReduceSum(A, {1, 3}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_non_last_reduce) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_non_last_reduce"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_last_reduce) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_last_reduce"); + std::vector input_names = {"A", "C"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_non_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.Cast(E, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(C, {1}); + auto F = net_builder.Cast(E, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_non_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_double_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.ReduceSum(C, {0}); + auto G = net_builder.Add(E, F); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, double_non_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("double_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h * 2, w}, "B"); + auto E = net_builder.ReduceSum(A, {0}); + auto F = net_builder.ReduceSum(B, {0}); + auto G = net_builder.Add(E, F); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, triple_non_last_reduce) { + int h = 128, w = 1024; + NetBuilder net_builder("triple_non_last_reduce"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.ReduceSum(A, {0}); + auto D = net_builder.ReduceSum(A, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_1) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h * w}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.BroadcastTo(B, {h * w}, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_2) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}); + auto C = net_builder.BroadcastTo(B, {h, w}, {1}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_3) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_3"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_reduce_broadcast) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_reduce_broadcast"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + auto D = net_builder.ReduceSum(C, {1, 2}); + auto E = net_builder.BroadcastTo(D, {h, h, w}, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_elementwise) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_elementwise"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + auto D = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto E = net_builder.BroadcastTo(D, {h, h, w}, {1, 2}); + auto F = net_builder.Add(C, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_1) { + NetBuilder net_builder("elementwise_double_reduce_elementwise_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {32, 32}, "A"); + auto B = net_builder.CreateInput(Float(32), {32, 32}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}, false); + auto E = net_builder.ReduceSum(C, {1}, false); + auto F = net_builder.Add(D, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_2) { + NetBuilder net_builder("elementwise_double_reduce_elementwise_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {1, 1000}, "A"); + auto B = net_builder.CreateInput(Float(32), {1, 1000}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}, false); + auto E = net_builder.ReduceSum(C, {1}, false); + auto F = net_builder.Add(D, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +// Each of following unittests tests a group composed of typical operators +TEST(GROUP_SCHEDULER, layernorm) { + int h = 32, w = 1024; + NetBuilder net_builder("layernorm"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + // x + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + // x * x + auto B = net_builder.Multiply(A, A); + // sum x + auto C = net_builder.ReduceSum(A, {1}); + // sum x*x + auto D = net_builder.ReduceSum(B, {1}); + // constant w + auto E = net_builder.FillConstant({h}, 1024.0f, "E"); + // mean + auto F = net_builder.Divide(C, E); + auto FF = net_builder.BroadcastTo(F, {h, w}, {0}); + // mean x*x + auto G = net_builder.Divide(D, E); + // mean * mean + auto H = net_builder.Multiply(F, F); + // var^2 + auto I = net_builder.Subtract(G, H); + // eps + auto J = net_builder.FillConstant({h}, 1e-10f, "J"); + // eps + delta + auto K = net_builder.Add(I, J); + // var + auto L = net_builder.Sqrt(K); + auto LL = net_builder.BroadcastTo(L, {h, w}, {0}); + // x - mean + auto M = net_builder.Subtract(A, FF); + // /var + auto N = net_builder.Divide(M, LL); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, softmax) { + int h = 32, w = 1024; + NetBuilder net_builder("softmax"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + // softmax + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + // reduce max + auto B = net_builder.ReduceMax(A, {1}); + // broadcast + auto C = net_builder.BroadcastTo(B, {h, w}, {0}); + // x - max(x) + auto D = net_builder.Subtract(A, C); + // exp(x) + auto E = net_builder.Exp(D); + // reduce sum + auto F = net_builder.ReduceSum(E, {1}); + // broadcast + auto G = net_builder.BroadcastTo(F, {h, w}, {0}); + // exp(x)/sum(exp(x)) + auto H = net_builder.Divide(E, G); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 03b8c95b74173..7c30a6e565c43 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -23,7 +23,8 @@ gather_srcs( lower_intrin.cc cast_bool_to_int8.cc var_mod_simplify.cc - remove_schedule_block.cc) + remove_schedule_block.cc + replace_cross_thread_reduction.cc) if(WITH_CUDA) gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc) @@ -55,3 +56,5 @@ cinn_cc_test(test_cast_simplify SRCS cast_simplify_test.cc DEPS cinncore) cinn_cc_test(test_remove_schedule_block SRCS remove_schedule_block_test.cc DEPS cinncore) cinn_cc_test(test_unroll_loops SRCS unroll_loops_test.cc DEPS cinncore) +cinn_cc_test(test_replace_cross_thread_reduction SRCS + replace_cross_thread_reduction_test.cc DEPS cinncore) diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index 7d6dfe60744ab..238a28ab4da1d 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -29,6 +29,7 @@ #include "paddle/cinn/optim/map_extern_call.h" #include "paddle/cinn/optim/remove_schedule_block.h" #include "paddle/cinn/optim/replace_const_param_to_integer.h" +#include "paddle/cinn/optim/replace_cross_thread_reduction.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/optim/unroll_loops.h" @@ -49,6 +50,7 @@ Expr Optimize(Expr e, ReplaceConstParamToInteger(&copied); // Simplify already contains CastSimplify Simplify(&copied); + ReplaceCrossThreadReduction(&copied); UnrollLoop(&copied); VLOG(4) << "After Optimize UnrollLoop:" << copied; @@ -85,6 +87,7 @@ Expr Optimize(Expr e, ir::Module Optimize(const ir::Module& module, const Target& target) { auto copied = ir::ir_utils::IRCopy(Expr(module)); + ReplaceCrossThreadReduction(&copied); UnrollLoop(&copied); VectorizeLoops(&copied, Target()); VLOG(10) << "After VectorizeLoops:" << copied.as_module_ref(); diff --git a/paddle/cinn/optim/replace_cross_thread_reduction.cc b/paddle/cinn/optim/replace_cross_thread_reduction.cc new file mode 100644 index 0000000000000..2524874bace60 --- /dev/null +++ b/paddle/cinn/optim/replace_cross_thread_reduction.cc @@ -0,0 +1,189 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * This file implements the strategy to remove the unnecessary nested block. + */ +#pragma once +#include "paddle/cinn/optim/replace_cross_thread_reduction.h" +#include + +#include "paddle/cinn/common/common.h" +#include "paddle/cinn/hlir/pe/reduction.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/compute.h" + +namespace cinn { +namespace optim { +namespace { + +struct BufferCmp { + bool operator()(const ir::Buffer& a, const ir::Buffer& b) const { + if (a->name == b->name) return false; + return true; + } +}; + +thread_local std::set shm_buffer_; +struct CrossThreadReductionReplacer : public ir::IRMutator<> { + void operator()(ir::Expr* expr) { Visit(expr); } + + private: + bool CanReplace(const ir::ScheduleBlockRealize* block_realize) { + const ir::ScheduleBlock* schedule_block = + block_realize->schedule_block.As(); + CHECK_NOTNULL(schedule_block); + + if (block_realize->schedule_block.As()->name.substr( + 0, 4) == "root") { + return false; + } + + const std::vector& iter_values = block_realize->iter_values; + const std::vector& iter_vars = schedule_block->iter_vars; + ir::Expr body = schedule_block->body; + + std::unordered_set reduce_var_names; + for (int i = 0; i < iter_values.size(); ++i) { + if (!iter_vars[i]->is_reduce_axis) { + continue; + } + ir::ir_utils::CollectIRNodesWithoutTensor( + iter_values[i], [&](const ir::Expr* x) { + if (x->as_var()) { + reduce_var_names.insert(x->as_var()->name); + } + return false; + }); + } + + std::vector thread_binded_reduce_loop_indices; + for (int i = 0; i < cur_loops_.size(); ++i) { + if (reduce_var_names.count(cur_loops_[i].As()->loop_var->name) > + 0) { + if (cur_loops_[i].As()->is_gpu_thread_binded()) { + if (ir::GetLoopExtent(cur_loops_[i]) > 1024) { + return false; + } + thread_binded_reduce_loop_indices.push_back(i); + } + } + } + if (thread_binded_reduce_loop_indices.size() == 0 || + thread_binded_reduce_loop_indices.back() != cur_loops_.size() - 1) { + return false; + } + for (int i = 1; i < thread_binded_reduce_loop_indices.size(); ++i) { + if (thread_binded_reduce_loop_indices[i - 1] + 1 != + thread_binded_reduce_loop_indices[i]) { + return false; + } + } + + return true; + } + + void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + void Visit(const ir::_LoweredFunc_* expr, ir::Expr* op) override { + ir::IRMutator<>::Visit(expr, op); + if (std::find_if(op->as_lowered_func()->temp_bufs.begin(), + op->as_lowered_func()->temp_bufs.end(), + [&](const ir::Buffer& buf) -> bool { + for (auto& tmp_buf : shm_buffer_) { + if (buf->name == tmp_buf->name) return true; + } + return false; + }) == op->as_lowered_func()->temp_bufs.end()) + op->as_lowered_func()->temp_bufs.insert( + op->as_lowered_func()->temp_bufs.end(), + shm_buffer_.begin(), + shm_buffer_.end()); + shm_buffer_.clear(); + } + + void Visit(const ir::ScheduleBlockRealize* expr, ir::Expr* op) override { + if (!CanReplace(expr)) { + VLOG(6) << "Can't replace cross thread reduction: " << *op; + IRMutator::Visit(expr, op); + return; + } + VLOG(6) << "Can replace cross thread reduction: " << *op; + + const ir::ScheduleBlock* schedule_block = + expr->schedule_block.As(); + CHECK_NOTNULL(schedule_block); + ir::Expr original_update_body = schedule_block->body; + ir::Expr original_update_stmt; + CHECK(original_update_body.As() || + original_update_body.As()); + if (original_update_body.As()) { + CHECK_EQ(original_update_body.As()->stmts.size(), 1); + original_update_stmt = original_update_body.As()->stmts[0]; + } else if (original_update_body.As()) { + original_update_stmt = original_update_body; + } + +#define REPLACE_TO_EXTERNAL_CALL(Op) \ + if (original_update_stmt.As()->value.As()) { \ + auto* node = original_update_stmt.As()->value.As(); \ + CHECK(node); \ + auto& operand = node->b(); \ + std::string reduce_func_name = \ + hlir::pe::CrossThreadReduceExternalFuncName( \ + original_update_stmt.As()->value, \ + operand.As()->tensor); \ + auto tmp_dtype = operand.As()->tensor.as_tensor()->type(); \ + auto tmp_buffer = ir::_Buffer_::Make( \ + "shm32_" + hlir::pe::Type2StrForReduce(tmp_dtype) + "_reduce", \ + {ir::Expr(32)}); \ + tmp_buffer->dtype = tmp_dtype; \ + tmp_buffer->memory_type = ir::MemoryType::GPUShared; \ + shm_buffer_.insert(tmp_buffer); \ + original_update_stmt.As()->value = \ + lang::CallExtern(reduce_func_name, {node->b(), tmp_buffer}); \ + } + + REPLACE_TO_EXTERNAL_CALL(ir::Add) + REPLACE_TO_EXTERNAL_CALL(ir::Mul) + REPLACE_TO_EXTERNAL_CALL(ir::Max) + REPLACE_TO_EXTERNAL_CALL(ir::Min) + REPLACE_TO_EXTERNAL_CALL(ir::And) + REPLACE_TO_EXTERNAL_CALL(ir::Or) +#undef REPLACE_TO_EXTERNAL_CALL + + VLOG(6) << "Replace cross thread reduction: " << *op; + + IRMutator::Visit(expr, op); + } + + void Visit(const ir::For* expr, ir::Expr* op) override { + cur_loops_.push_back(*op); + IRMutator::Visit(expr, op); + cur_loops_.pop_back(); + } + + private: + std::vector cur_loops_; +}; + +} // namespace + +void ReplaceCrossThreadReduction(Expr* e) { CrossThreadReductionReplacer()(e); } + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_cross_thread_reduction.h b/paddle/cinn/optim/replace_cross_thread_reduction.h new file mode 100644 index 0000000000000..5bc0d2828d6b2 --- /dev/null +++ b/paddle/cinn/optim/replace_cross_thread_reduction.h @@ -0,0 +1,33 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * This file implements the strategy to remove the unnecessary nested block. + */ +#pragma once +#include + +#include "paddle/cinn/common/common.h" +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Replace cross thread reduction to external call. + */ +void ReplaceCrossThreadReduction(Expr* e); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_cross_thread_reduction_test.cc b/paddle/cinn/optim/replace_cross_thread_reduction_test.cc new file mode 100644 index 0000000000000..fb8c0d185ed11 --- /dev/null +++ b/paddle/cinn/optim/replace_cross_thread_reduction_test.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/replace_cross_thread_reduction.h" + +#include + +#include +#include + +#include "paddle/cinn/cinn.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/utils/string.h" + +namespace cinn { +namespace optim { + +TEST(CrossThreadReductionReplacer, basic) { +#ifdef CINN_WITH_CUDA + Context::Global().ResetNameId(); + Placeholder A("A", {Expr(64), Expr(128)}); + Target target = common::DefaultNVGPUTarget(); + Module::Builder builder("reduce_sum", target); + Var reduce_j(128, "reduce_j"); + ir::Tensor B = Compute( + {Expr(64)}, + [&](Var i) { return lang::ReduceSum(A(i, reduce_j), {reduce_j}); }, + "B"); + ast_gen_ius::TensorGroup tensor_group({A, B}); + auto func = lang::LowerToAst("reduce_sum", {A, B}, &tensor_group); + VLOG(6) << "original func\n" << func; + + ir::ModuleExpr mod_expr({func->body}); + ir::IRSchedule ir_sch(mod_expr); + + ir_sch.Bind(ir_sch.GetLoops("B")[0], "blockIdx.x"); + ir_sch.Bind(ir_sch.GetLoops("B")[1], "threadIdx.x"); + + ir::Expr new_func = ir_sch.GetModule().GetExprs()[0]; + VLOG(6) << "After Bind: " << new_func; + + ReplaceCrossThreadReduction(&new_func); + VLOG(6) << "After ReplaceCrossThreadReduction: " << new_func; + + EXPECT_EQ(utils::GetStreamCnt(new_func), utils::Trim(R"ROC({ + ScheduleBlock(root) + { + thread_bind[blockIdx.x] for (i, 0, 64) + { + ScheduleBlock(B__reduce_init) + { + i0 = axis.bind(i) + B__reduce_init[i0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_j, 0, 128) + { + ScheduleBlock(B) + { + i0_0, i1 = axis.bind(i, reduce_j) + B[i0_0] = cinn_block_reduce_sum_fp32_internal_shm(A[i0_0, i1], _Buffer_(shm32__fp32_reduce)) + } + } + } + } +} +)ROC")); +#endif +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index a7e4dc6e1de1a..98acd9576913d 100644 --- a/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -474,11 +474,11 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left tmp_val = __shfl_sync(mask, tmp_val, 0, 32); \ return tmp_val; \ } else { \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 16, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 8, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 4, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 2, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 1, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 16, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 8, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 4, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 2, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 1, 32)); \ return tmp_val; \ } \ } @@ -530,25 +530,22 @@ __device__ inline float cinn_warp_reduce_avg_fp32(const float *buf, int offset, #define CINN_BLOCK_REDUCE_INTERNAL_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \ int warp_id = threadIdx.x / 32; \ - __shared__ TYPE tmp[32]; \ - if (warp_id == 0) { \ - tmp[threadIdx.x] = init_value; \ - } \ TYPE tmp_val = cinn_warp_shuffle_internal(value); \ if (blockDim.x <= 32) { \ return tmp_val; \ } \ + __shared__ TYPE tmp[32]; \ + if (warp_id == 0) { \ + tmp[threadIdx.x] = init_value; \ + } \ __syncthreads(); \ - if (threadIdx.x % 32 == 0) { \ + if ((threadIdx.x & 31) == 0) { \ tmp[warp_id] = tmp_val; \ } \ __syncthreads(); \ if (warp_id == 0) { \ tmp_val = tmp[threadIdx.x]; \ - tmp_val = cinn_warp_shuffle_internal(tmp_val); \ - if (threadIdx.x == 0) { \ - tmp[0] = tmp_val; \ - } \ + tmp[threadIdx.x] = cinn_warp_shuffle_internal(tmp_val); \ } \ __syncthreads(); \ return tmp[0]; @@ -575,13 +572,57 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO) #undef CINN_BLOCK_REDUCE_INTERNAL_IMPL #undef CINN_BLOCK_REDUCE_INTERNAL_MACRO +#define CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \ + int warp_id = threadIdx.x / 32; \ + TYPE tmp_val = cinn_warp_shuffle_internal(value); \ + if (blockDim.x <= 32) { \ + return tmp_val; \ + } \ + if (warp_id == 0) { \ + shm[threadIdx.x] = init_value; \ + } \ + __syncthreads(); \ + if ((threadIdx.x & 31) == 0) { \ + shm[warp_id] = tmp_val; \ + } \ + __syncthreads(); \ + if (warp_id == 0) { \ + tmp_val = shm[threadIdx.x]; \ + shm[threadIdx.x] = cinn_warp_shuffle_internal(tmp_val); \ + } \ + __syncthreads(); \ + return shm[0]; + +#define CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE##_internal_shm(const DTYPE value, DTYPE* shm) { \ + CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(DTYPE, value, (DTYPE)(INITIAL_VALUE), cinn_warp_shuffle_##REDUCE_TYPE##_internal); \ + } + +EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) +EXPAND_REDUCE_INT64_MARCO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) +EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) +EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) +EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) + +#ifdef CINN_CUDA_BF16 +EXPAND_REDUCE_BF16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) +#endif + +#ifdef CINN_CUDA_FP16 +EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO) +#endif + +#undef CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL +#undef CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO + #define CINN_BLOCK_REDUCE_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE(const DTYPE *buf, int offset, int extend) { \ + __shared__ DTYPE shm[32]; \ DTYPE tmp_val = (DTYPE)(INITIAL_VALUE); \ for (int i = threadIdx.x; i < extend; i += blockDim.x) { \ tmp_val = cinn_##REDUCE_TYPE(tmp_val, buf[offset + i]); \ } \ - return cinn_block_reduce_##REDUCE_TYPE##_internal(tmp_val); \ + return cinn_block_reduce_##REDUCE_TYPE##_internal_shm(tmp_val,shm); \ } EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_IMPL) diff --git a/paddle/cinn/runtime/cuda/cublas_util.h b/paddle/cinn/runtime/cuda/cublas_util.h index 8cd45e4538b89..edb3d60e8a1a3 100644 --- a/paddle/cinn/runtime/cuda/cublas_util.h +++ b/paddle/cinn/runtime/cuda/cublas_util.h @@ -70,6 +70,27 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype, reinterpret_cast(C), ldc); } else if (dtype == CUDA_R_16F) { +#if CUDA_VERSION >= 11000 + return cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + &alpha, + A, + CUDA_R_16F, + lda, + B, + CUDA_R_16F, + ldb, + &beta, + C, + CUDA_R_16F, + ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#else common::float16 alpha_fp16{alpha}; common::float16 beta_fp16{beta}; return cublasHgemm(handle, @@ -86,6 +107,7 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype, reinterpret_cast(&beta_fp16), reinterpret_cast<__half *>(C), ldc); +#endif } else if (dtype == CUDA_R_16BF) { #if CUDA_VERSION >= 11000 return cublasGemmEx(handle, @@ -174,6 +196,31 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, strideC, batchCount); } else if (dtype == CUDA_R_16F) { +#if CUDA_VERSION >= 11000 + return cublasGemmStridedBatchedEx(handle, + transa, + transb, + m, + n, + k, + &alpha, + A, + CUDA_R_16F, + lda, + strideA, + B, + CUDA_R_16F, + ldb, + strideB, + &beta, + C, + CUDA_R_16F, + ldc, + strideC, + batchCount, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#else common::float16 alpha_fp16{alpha}; common::float16 beta_fp16{beta}; return cublasHgemmStridedBatched( @@ -195,6 +242,7 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, ldc, strideC, batchCount); +#endif } else if (dtype == CUDA_R_16BF) { #if CUDA_VERSION >= 11000 return cublasGemmStridedBatchedEx(handle, @@ -279,6 +327,28 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype, ldc, batchCount); } else if (dtype == CUDA_R_16F) { +#if CUDA_VERSION >= 11000 + return cublasGemmBatchedEx(handle, + transa, + transb, + m, + n, + k, + &alpha, + A, + CUDA_R_16F, + lda, + B, + CUDA_R_16F, + ldb, + &beta, + C, + CUDA_R_16F, + ldc, + batchCount, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#else __half alpha_fp16{alpha}; __half beta_fp16{beta}; return cublasHgemmBatched(handle, @@ -296,6 +366,7 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype, reinterpret_cast<__half **>(C), ldc, batchCount); +#endif } else if (dtype == CUDA_R_16BF) { #if CUDA_VERSION >= 11000 return cublasGemmBatchedEx(handle, diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc index 9f7bdefcd2fcc..15fcb4030e89b 100644 --- a/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc +++ b/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc @@ -110,6 +110,24 @@ CINN_REGISTER_HELPER(cuda_intrinsics_reduce) { #undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL +#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER( \ + cinn_block_reduce_##REDUCE_TYPE##_internal_shm, target) \ + .SetRetType() \ + .AddInputType() \ + .AddInputType() \ + .End(); + + EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) + +#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL + #define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \ .SetRetType() \ diff --git a/paddle/cinn/runtime/cuda/cuda_util.cc b/paddle/cinn/runtime/cuda/cuda_util.cc index 6fb82ccb8a05a..1b27e5f103cba 100644 --- a/paddle/cinn/runtime/cuda/cuda_util.cc +++ b/paddle/cinn/runtime/cuda/cuda_util.cc @@ -162,6 +162,8 @@ void cinn_call_cublas(void *v_args, int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3); int k = trans_a ? a3 : a4; + VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k; + cublasOperation_t trans_op_l = trans_o ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T) : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N); @@ -245,7 +247,7 @@ void cinn_call_cublas(void *v_args, int batch = std::max(a2, b2); VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = " << stride_l << ", stride_r = " << stride_r - << ", batch = " << batch; + << ", batch = " << batch << ", dtype = " << cuda_dtype; cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched", cinn::utils::EventType::kInstruction); CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype, diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 3d3801fa675fb..cad18f4084a5d 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -61,6 +61,10 @@ PD_DEFINE_bool(general_fusion_merge_pass, BoolFromEnv("FLAGS_general_fusion_merge_pass", true), "Whether to use general fusion_merge pass."); +PD_DEFINE_bool(cinn_new_group_scheduler, + BoolFromEnv("FLAGS_cinn_new_group_scheduler", false), + "Whether to use new group scheduler."); + PD_DEFINE_bool(cinn_use_common_subexpression_elimination, BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false), diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/utils/attribute_util.h deleted file mode 100644 index 474bc09e2c64c..0000000000000 --- a/paddle/cinn/utils/attribute_util.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/cinn/common/type.h" -#include "paddle/cinn/utils/type_defs.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/builtin_type.h" - -namespace cinn { -namespace utils { - -using NewIR_AttributeMap = std::unordered_map; - -Attribute ConvertAttribute(const ::pir::Attribute& src_attr) { - Attribute dst_attr; - if (src_attr.isa<::pir::BoolAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data(); - } else if (src_attr.isa<::pir::FloatAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::FloatAttribute>().data(); - } else if (src_attr.isa<::pir::Int32Attribute>()) { - dst_attr = src_attr.dyn_cast<::pir::Int32Attribute>().data(); - } else if (src_attr.isa<::pir::StrAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::StrAttribute>().AsString(); - } else if (src_attr.isa<::pir::Int64Attribute>()) { - dst_attr = src_attr.dyn_cast<::pir::Int64Attribute>().data(); - } else if (src_attr.isa<::pir::DoubleAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::DoubleAttribute>().data(); - } else if (src_attr.isa()) { - auto& arr = src_attr.dyn_cast() - .data() - .GetData(); - std::vector val(arr.begin(), arr.end()); - dst_attr = val; - } else if (src_attr.isa()) { - auto dtype = src_attr.dyn_cast().data(); - dst_attr = phi::DataTypeToString(dtype); - } else { - LOG(FATAL) << "unknown Attribute: " << src_attr; - } - - return dst_attr; -} - -AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) { - AttributeMap dst_attrs; - for (auto& item : src_attrs) { - VLOG(4) << "deal with " << item.first; - if (item.first == ::pir::kStopGradientAttrName) { - continue; - } else if (item.second.isa()) { - auto is_cpu = - item.second.dyn_cast().data() == - phi::CPUPlace(); - dst_attrs["force_cpu"] = is_cpu; - } else { - dst_attrs[item.first] = std::move(ConvertAttribute(item.second)); - } - } - VLOG(4) << "dst_attrs.size(): " << dst_attrs.size(); - return dst_attrs; -} - -#define CASE_TYPE(src, dst) \ - else if (type.isa<::pir::src>()) return common::dst(); - -common::Type ConvertIRType(::pir::Type type) { - if (type.isa<::pir::BFloat16Type>()) return common::BF16(); - CASE_TYPE(Float16Type, F16) - CASE_TYPE(Float32Type, F32) - CASE_TYPE(Float64Type, F64) - CASE_TYPE(Int8Type, I8) - CASE_TYPE(UInt8Type, UI8) - CASE_TYPE(Int16Type, I16) - CASE_TYPE(Int32Type, I32) - CASE_TYPE(Int64Type, I64) - CASE_TYPE(IndexType, I32) - CASE_TYPE(BoolType, UI1) - - LOG(FATAL) << "unknown ir::Type " << type; -} - -} // namespace utils -} // namespace cinn diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc index e6c31d06e21c2..58309ced6d1c7 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.cc +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc @@ -76,6 +76,7 @@ OperatorDistAttr& OperatorDistAttr::operator=( std::swap(this->stream_priority_, tmp.stream_priority_); std::swap(this->scheduling_priority_, tmp.scheduling_priority_); std::swap(this->annotated_, tmp.annotated_); + std::swap(this->run_time_us_, tmp.run_time_us_); // Note: Make sure all tensor dist attr has the same process_mesh set_process_mesh(this->process_mesh_); return *this; @@ -125,6 +126,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) { set_events_to_wait(dist_attr.events_to_wait()); set_scheduling_priority(dist_attr.scheduling_priority()); set_annotated(dist_attr.annotated()); + set_run_time_us(dist_attr.run_time_us()); } void OperatorDistAttr::set_input_dist_attrs( diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.h b/paddle/fluid/distributed/auto_parallel/dist_attr.h index 347c7fc05dfa0..b2acb93a6e17d 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.h +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.h @@ -229,6 +229,9 @@ class OperatorDistAttr { return key + "_" + std::to_string(id_++); } + double run_time_us() const { return this->run_time_us_; } + void set_run_time_us(const double& us) { this->run_time_us_ = us; } + private: static std::vector fields_; std::map input_dist_attrs_; @@ -245,6 +248,8 @@ class OperatorDistAttr { int stream_priority_ = 0; // lower value, higher priority int64_t scheduling_priority_ = 0; // lower value, higher priority std::map annotated_; + double run_time_us_ = -1.0; // stores the actual run time (us) of relevant + // op, negative value means invalid. }; inline std::ostream& operator<<(std::ostream& os, const OperatorDistAttr& obj) { diff --git a/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt index b0beaad0f6b1f..954af0cc852a0 100644 --- a/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt +++ b/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt @@ -13,6 +13,12 @@ cc_test( SRCS dist_attr_test.cc DEPS phi proto_desc) -cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi) +cc_test( + dist_mapper_test + SRCS dist_mapper_test.cc + DEPS phi) -cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rules) +cc_test( + spmd_rule_test + SRCS spmd_rule_test.cc + DEPS spmd_rules) diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index 215f55f2d1883..a2267e1f6cebd 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -18,7 +18,7 @@ endif() if(WITH_NCCL OR WITH_RCCL) cc_library( process_group_nccl - SRCS process_group_nccl.cc nccl_tools.cc common.cc + SRCS process_group_nccl.cc common.cc DEPS process_group phi place diff --git a/paddle/fluid/distributed/collective/bkcl_tools.cc b/paddle/fluid/distributed/collective/bkcl_tools.cc index ba5afbbf1feb5..7e95eb8b748eb 100644 --- a/paddle/fluid/distributed/collective/bkcl_tools.cc +++ b/paddle/fluid/distributed/collective/bkcl_tools.cc @@ -14,8 +14,6 @@ #include "paddle/fluid/distributed/collective/bkcl_tools.h" -#include "paddle/fluid/distributed/collective/types.h" - namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/collective/bkcl_tools.h b/paddle/fluid/distributed/collective/bkcl_tools.h index 533498cd8e119..19d321080d47a 100644 --- a/paddle/fluid/distributed/collective/bkcl_tools.h +++ b/paddle/fluid/distributed/collective/bkcl_tools.h @@ -14,14 +14,15 @@ #pragma once -#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/distributed/types.h" namespace paddle { namespace distributed { using XPUContext = phi::XPUContext; +using phi::distributed::ReduceOp; #define BKCLCHECK(cmd) \ do { \ diff --git a/paddle/fluid/distributed/collective/custom_ccl_tools.cc b/paddle/fluid/distributed/collective/custom_ccl_tools.cc index ccafcf12a6c26..15e8b680b7805 100644 --- a/paddle/fluid/distributed/collective/custom_ccl_tools.cc +++ b/paddle/fluid/distributed/collective/custom_ccl_tools.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" -#include "paddle/fluid/distributed/collective/types.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/collective/custom_ccl_tools.h b/paddle/fluid/distributed/collective/custom_ccl_tools.h index 95557079a8252..4fb336e929065 100644 --- a/paddle/fluid/distributed/collective/custom_ccl_tools.h +++ b/paddle/fluid/distributed/collective/custom_ccl_tools.h @@ -22,7 +22,6 @@ #include -#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/collective_helper.h" @@ -30,10 +29,13 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/backends/device_guard.h" #include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/core/distributed/types.h" namespace paddle { namespace distributed { +using phi::distributed::ReduceOp; + phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction); } // namespace distributed diff --git a/paddle/fluid/distributed/collective/nccl_tools.h b/paddle/fluid/distributed/collective/nccl_tools.h deleted file mode 100644 index 135aadd2a2414..0000000000000 --- a/paddle/fluid/distributed/collective/nccl_tools.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/distributed/collective/types.h" - -#ifdef PADDLE_WITH_RCCL -#include -#include "paddle/phi/backends/dynload/rccl.h" -#else -#include -#include "paddle/phi/backends/dynload/nccl.h" -#endif - -namespace paddle { -namespace distributed { - -#define NCCL_CHECK(cmd) \ - do { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) { \ - printf("Failed, NCCL error %s:%d '%s'\n", \ - __FILE__, \ - __LINE__, \ - phi::dynload::ncclGetErrorString(r)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -ncclRedOp_t ToNCCLRedType(ReduceOp reduction); - -std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index e643348eeed0d..8767dfa60cf18 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -20,10 +20,10 @@ #include #include -#include "paddle/fluid/distributed/collective/types.h" -#include "paddle/fluid/eager/utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/types.h" +#include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" @@ -32,24 +32,18 @@ constexpr auto kWaitTimeout = std::chrono::milliseconds(0); namespace paddle { namespace distributed { +using phi::distributed::AllreduceOptions; +using phi::distributed::BarrierOptions; +using phi::distributed::BroadcastOptions; +using phi::distributed::CommType; +using phi::distributed::GatherOptions; +using phi::distributed::GetPartialTensor; +using phi::distributed::ReduceOp; +using phi::distributed::ReduceOptions; +using phi::distributed::ReduceScatterOptions; +using phi::distributed::ScatterOptions; constexpr int kIgnoreId = -1; -enum class CommType : std::uint8_t { - BROADCAST = 0, - ALLREDUCE = 1, - ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce - REDUCE = 3, - ALLGATHER = 4, - GATHER = 5, - SCATTER = 6, - REDUCE_SCATTER = 7, - ALLTOALL = 8, - SEND = 9, - RECV = 10, - BARRIER = 11, - UNKNOWN = 100, -}; - class ProcessGroup { public: class Task { @@ -95,6 +89,15 @@ class ProcessGroup { int GetSize() const { return size_; } + int GetGid() const { return gid_; } + + std::string GetGroupMessage() const { + return std::string("rank_in_group: ") + std::to_string(rank_) + + std::string(", nranks: ") + std::to_string(size_) + + std::string(", gid: ") + std::to_string(gid_) + + std::string(", backend: ") + GetBackendName(); + } + virtual std::string GetBackendName() const = 0; virtual phi::DeviceContext* GetDeviceContext( @@ -294,7 +297,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor UNUSED, const BroadcastOptions& opts UNUSED, bool sync_op UNUSED, - bool use_calc_stream UNUSED) { + bool use_calc_stream) { PADDLE_THROW( phi::errors::Unimplemented("ProcessGroup%s does not support broadcast " "with sync_op and use_calc_stream flag.", @@ -412,68 +415,57 @@ class ProcessGroup { // legacy APIs // TODO(liyurui): This API will be moved later virtual std::shared_ptr AllReduce( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const AllreduceOptions& UNUSED = AllreduceOptions()) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support allreduce", GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const AllreduceOptions& options = AllreduceOptions()) { + return AllReduce(outputs.data(), inputs.front(), options, false); } virtual std::shared_ptr AllReduce( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const AllreduceOptions& UNUSED, - bool) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support allreduce with sync_op flag", - GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const AllreduceOptions& options, + bool sync_op) { + return AllReduce(outputs.data(), inputs.front(), options, sync_op); } // TODO(sunyilun): methods below will be removed later virtual std::shared_ptr Broadcast( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const BroadcastOptions& UNUSED = BroadcastOptions()) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support broadcast", GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const BroadcastOptions& options = BroadcastOptions()) { + return Broadcast(outputs.data(), inputs.front(), options, false); } virtual std::shared_ptr Broadcast( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const BroadcastOptions& UNUSED, - bool) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support broadcast with sync_op flag", - GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const BroadcastOptions& options, + bool sync_op) { + return Broadcast(outputs.data(), inputs.front(), options, sync_op); } virtual std::shared_ptr Send( - std::vector&, int) { // NOLINT - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support send", GetBackendName())); + std::vector& tensors, int dst_rank) { // NOLINT + return Send(tensors.front(), dst_rank, false); } virtual std::shared_ptr Recv( - std::vector&, int) { // NOLINT - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support recv", GetBackendName())); + std::vector& tensors, int src_rank) { // NOLINT + return Recv(&tensors.front(), src_rank, false); } virtual std::shared_ptr AllGather( - std::vector&, // NOLINT - std::vector&) { // NOLINT - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support all_gather", GetBackendName())); + std::vector& in_tensors, // NOLINT + std::vector& out_tensors) { // NOLINT + return AllGather(out_tensors.data(), in_tensors.front(), false); } virtual std::shared_ptr AllGather( - std::vector&, // NOLINT - std::vector&, // NOLINT - bool) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support all_gather with sync_op flag", - GetBackendName())); + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + bool sync_op) { + return AllGather(out_tensors.data(), in_tensors.front(), sync_op); } virtual std::shared_ptr AllToAll( @@ -484,19 +476,17 @@ class ProcessGroup { } virtual std::shared_ptr Reduce( - std::vector&, // NOLINT - std::vector&, // NOLINT - const ReduceOptions& opts UNUSED) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support reduce", GetBackendName())); + std::vector& ins, // NOLINT + std::vector& outs, // NOLINT + const ReduceOptions& opts) { + return Reduce(outs.data(), ins.front(), opts, false); } virtual std::shared_ptr Scatter( - std::vector&, // NOLINT - std::vector&, // NOLINT - const ScatterOptions&) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support scatter", GetBackendName())); + std::vector& ins, // NOLINT + std::vector& outs, // NOLINT + const ScatterOptions& opts) { + return Scatter(outs.data(), ins.front(), opts, false); } protected: diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 4331041c4f043..81f52bc97f334 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/distributed/collective/bkcl_tools.h" #include "paddle/fluid/distributed/collective/common.h" -#include "paddle/fluid/distributed/collective/utils.h" +#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/phi/api/lib/utils/allocator.h" diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 64dce7b4c6b11..1313d19a2bbfa 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/enforce.h" @@ -32,6 +31,8 @@ PD_DECLARE_bool(use_stream_safe_cuda_allocator); namespace paddle { namespace distributed { +using phi::distributed::CheckSizeOnEachRank; +using phi::distributed::GetPointerByOffset; static std::mutex g_unfinished_xccl_task_events_mutex; static std::list> g_unfinished_xccl_task_events; diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index 13970b2e349a0..a3fb060376597 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -22,6 +22,7 @@ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" +#include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/device_manager.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 89f5dcb222e63..8877224eb7674 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -15,21 +15,27 @@ #include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/distributed/collective/common.h" -#include "paddle/fluid/distributed/collective/nccl_tools.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/comm_task_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_task.h" +#include "paddle/phi/core/distributed/nccl_tools.h" +#include "paddle/phi/core/distributed/trace_utils.h" +#include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/utils/data_type.h" -#include "paddle/phi/core/distributed/comm_context_manager.h" - +PHI_DECLARE_bool(benchmark); +PHI_DECLARE_bool(benchmark_nccl); PHI_DECLARE_bool(nccl_blocking_wait); -PD_DECLARE_bool(use_stream_safe_cuda_allocator); +PHI_DECLARE_bool(use_stream_safe_cuda_allocator); +PHI_DECLARE_bool(enable_async_trace); // set this flag to `true` and recompile to enable dynamic checks constexpr bool FLAGS_enable_nccl_dynamic_check = false; @@ -38,6 +44,17 @@ constexpr int64_t kWaitBlockTImeout = 10; namespace paddle { namespace distributed { +using phi::distributed::CheckSizeOnEachRank; +using phi::distributed::GetTraceEndKey; +using phi::distributed::GetTraceStartKey; +using phi::distributed::IsP2POP; +using phi::distributed::NCCLDTypeToString; +using phi::distributed::NCCLRedTypeToString; +using phi::distributed::SerializeNCCLUniqueId; +using phi::distributed::ToNCCLRedType; + +uint64_t ProcessGroupNCCL::s_group_call_counter = 0; + ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place, int rank, CommType comm_type, @@ -60,7 +77,7 @@ void ProcessGroupNCCL::NCCLTask::UpdateWaitChain( bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { // Warning here when use calc stream but also invoke waiting explicitly. if (UseCalcStream()) { - VLOG(3) << "Warning: The communication is on calc stream, wait here is " + VLOG(5) << "Warning: The communication is on calc stream, wait here is " "useless."; return true; } @@ -80,7 +97,7 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { // If we use the work to do barrier, we should block cpu #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); -#else +#else // PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif } @@ -94,20 +111,40 @@ ProcessGroupNCCL::ProcessGroupNCCL( const std::shared_ptr& store, int rank, int size, - int gid) - : ProcessGroupWithStream(rank, size, gid), store_(store) {} + int gid, + int64_t timeout) + : ProcessGroupWithStream(rank, size, gid), + store_(store), + pg_timeout_(timeout) { + LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_; +} void ProcessGroupNCCL::GroupStart() { NCCL_CHECK(phi::dynload::ncclGroupStart()); + ++s_group_call_counter; } -void ProcessGroupNCCL::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); } +void ProcessGroupNCCL::GroupEnd() { + NCCL_CHECK(phi::dynload::ncclGroupEnd()); + --s_group_call_counter; + // NOTE: This is to sync the calc stream and comm stream for debug using + // batch_isend_irecv + if (FLAGS_benchmark || FLAGS_benchmark_nccl) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + } +} phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( const Place& place) const { return GetDeviceContext(place, /*use_calc_stream*/ false); } +// NOTE(shenliang03): GetDeviceContext is only used for collective, it can't +// be used for p2p op. phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( const Place& place, bool use_calc_stream) const { const std::string& key = GetKeyFromPlace(place); @@ -146,9 +183,21 @@ std::shared_ptr ProcessGroupNCCL::AllGather( // numel > 0 indicates the tensor need to be sliced const phi::DenseTensor& in_tensor_maybe_partial = numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor; - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclAllGather] " + << "sendbuff: " << in_tensor_maybe_partial.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor_maybe_partial.numel() + << ", datatype: " + << NCCLDTypeToString( + phi::ToNCCLDataType(in_tensor_maybe_partial.dtype())) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", offset: " << offset + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream); }, in_tensor_maybe_partial, @@ -163,9 +212,21 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( const AllreduceOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclAllReduce] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", redop: " + << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op)) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + comm_context->AllReduce( out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream); }, @@ -191,9 +252,15 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( // simply be covered by static checks. Factors are set to 0 here to skip the // shape check. Its shape check will be done by dynamic checks with // FLAGS_enable_nccl_dynamic_check. - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + phi::distributed::CommStaticCheck::CheckShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + /*out_size_factor*/ 0, + /*in_size_factor*/ 0); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckShape( *out_tensor, @@ -203,13 +270,27 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( size_, comm_context->GetNcclComm()); } - int64_t in_row_size = in_tensor.numel() / in_dim[0], out_row_size = out_tensor->numel() / out_dim[0]; int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; phi::DenseTensor input_partial, output_partial; - comm_context->GroupStart(); + VLOG(3) << "[AllToAll] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", out_size_each_rank: " + << string::join_strings(out_size_each_rank, ',') + << ", in_size_each_rank: " + << string::join_strings(in_size_each_rank, ',') + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + + GroupStart(); for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); @@ -221,7 +302,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( comm_context->Recv(&output_partial, out_numel, i, stream); out_offset += out_numel; } - comm_context->GroupEnd(); + GroupEnd(); }, in_tensor, CommType::ALLTOALL, @@ -241,6 +322,9 @@ std::shared_ptr ProcessGroupNCCL::Barrier( phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); phi::DenseTensor barrier_tensor{allocator.get(), meta}; + VLOG(3) << "[Barrier] " + << "barrier opt: " << opts.device_id; + auto task = AllReduce(&barrier_tensor, barrier_tensor, {}, @@ -257,10 +341,21 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { int root = opts.source_rank + opts.source_root; - auto comm_context = this->GetCommContext(); + + VLOG(3) << "[ncclBroadcast] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", root: " << root + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->Broadcast(out_tensor, in_tensor, root, stream); }, in_tensor, @@ -275,9 +370,21 @@ std::shared_ptr ProcessGroupNCCL::Reduce( const ReduceOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclReduce] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", redop: " + << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op)) + << ", root: " << opts.root_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->Reduce(out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), @@ -296,9 +403,20 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter( const ReduceScatterOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclReduceScatter] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", redop: " + << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op)) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->ReduceScatter( out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream); }, @@ -320,9 +438,8 @@ std::shared_ptr ProcessGroupNCCL::Scatter( /*dst_rank*/ opts.root_rank, /*cur_rank*/ rank_, size_); - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckShape( *out_tensor, @@ -331,18 +448,30 @@ std::shared_ptr ProcessGroupNCCL::Scatter( comm_context->GetNcclComm()); } + VLOG(3) << "[Scatter] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", root: " << opts.root_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + int64_t numel = in_tensor.numel() / size_; if (rank_ == opts.root_rank) { int64_t offset = 0; phi::DenseTensor partial_tensor; - comm_context->GroupStart(); + GroupStart(); for (auto i = 0; i < size_; i++) { partial_tensor = GetPartialTensor(in_tensor, offset, numel); comm_context->Send(partial_tensor, numel, i, stream); offset += numel; } comm_context->Recv(out_tensor, numel, opts.root_rank, stream); - comm_context->GroupEnd(); + GroupEnd(); } else { comm_context->Recv(out_tensor, numel, opts.root_rank, stream); } @@ -385,8 +514,8 @@ std::shared_ptr ProcessGroupNCCL::Gather( "root world size [%d] is less than root rank [%d]", size_, opts.root_rank)); - auto gather_func = [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + auto gather_func = [&](phi::distributed::NCCLCommContext* comm_context, + gpuStream_t stream) { // shape check if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckGatherShape( @@ -398,7 +527,17 @@ std::shared_ptr ProcessGroupNCCL::Gather( comm_context->GetNcclComm()); } - comm_context->GroupStart(); + VLOG(3) << "[Gather] " + << "sendbuff: " << in_tensor.data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", root: " << opts.root_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << GetGroupMessage(); + + GroupStart(); // root receive from all devices if (rank_ == opts.root_rank) { for (auto i = 0; i < size_; i++) { @@ -408,9 +547,9 @@ std::shared_ptr ProcessGroupNCCL::Gather( } // send to root comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream); - comm_context->GroupEnd(); + GroupEnd(); }; - return RunFnInNCCLEnv( + return Collective( gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream); } @@ -428,11 +567,25 @@ std::shared_ptr ProcessGroupNCCL::Recv( tensor = &partial_tensor; } - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); - comm_context->Recv(tensor, tensor->numel(), src_rank, stream); + return Point2Point( + [&](phi::distributed::NCCLCommContext* comm_context, + gpuStream_t stream, + int rank_in_group) { + VLOG(3) << "[ncclRecv] " + << "recvbuff: " << tensor->data() + << ", count: " << tensor->numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(tensor->dtype())) + << ", src_in_group: " << src_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + + comm_context->Recv(tensor, tensor->numel(), rank_in_group, stream); }, + src_rank, *tensor, CommType::RECV, sync_op, @@ -450,14 +603,29 @@ std::shared_ptr ProcessGroupNCCL::Send( const phi::DenseTensor& tensor_maybe_partial = numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor; - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Point2Point( + [&](phi::distributed::NCCLCommContext* comm_context, + gpuStream_t stream, + int rank_in_group) { + VLOG(3) << "[ncclSend] " + << "sendbuff: " << tensor_maybe_partial.data() + << ", count: " << tensor_maybe_partial.numel() << ", datatype: " + << NCCLDTypeToString( + phi::ToNCCLDataType(tensor_maybe_partial.dtype())) + << ", dst_in_group: " << dst_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + comm_context->Send(tensor_maybe_partial, tensor_maybe_partial.numel(), - dst_rank, + rank_in_group, stream); }, + dst_rank, tensor_maybe_partial, CommType::SEND, sync_op, @@ -474,84 +642,133 @@ std::shared_ptr ProcessGroupNCCL::CreateTask( place, rank, comm_type, is_sync, use_calc_stream); } -void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) { - const std::string key = - "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0"; - if (rank_ == 0) { - std::vector nccl_id_wrapper( - reinterpret_cast(nccl_id), - reinterpret_cast(nccl_id) + NCCL_UNIQUE_ID_BYTES); - store_->set(key, nccl_id_wrapper); +void ProcessGroupNCCL::GetStoreKey(const std::string& place_key, + CommType comm_type, + std::string* store_key) { + bool is_batch_p2p = s_group_call_counter > 0; + bool is_p2p_op = IsP2POP(comm_type, is_batch_p2p); + + if (!is_p2p_op) { + *store_key = "nccl_ids/" + std::to_string(gid_) + "/0"; } else { - const auto& nccl_id_wrapper = store_->get(key); - std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size()); + *store_key = "nccl_ids/" + std::to_string(gid_) + "/" + place_key; } } void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, - const std::string& place_key) { - if (!place_to_comm_ctx_.empty()) { - VLOG(3) << "Warning: Tensors from multiple devices are not supported yet."; + const std::string& place_key, + const std::string& store_key, + CommType comm_type, + int p2p_rank) { + VLOG(3) << "init nccl rank_in_group: " << rank_ << ", nranks: " << size_ + << ", gid: " << gid_ << ", place key: " << place_key + << ", store_key: " << store_key; + + for (size_t i = 0; i < s_group_call_counter; ++i) { + NCCL_CHECK(phi::dynload::ncclGroupEnd()); } - VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << place_key; + bool is_batch_p2p = s_group_call_counter > 0; + bool is_p2p_op = IsP2POP(comm_type, is_batch_p2p); + + int num_ranks = is_p2p_op ? 2 : GetSize(); + int rank = is_p2p_op ? p2p_rank : GetRank(); + NCCL_CHECK(phi::dynload::ncclGroupStart()); + + phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank}); phi::distributed::CommContextManager::CreateNCCLCommContext( - store_, std::to_string(gid_), rank_, size_); + store_, store_key, rank_, size_, "", &p2p_opts); + + NCCL_CHECK(phi::dynload::ncclGroupEnd()); + + auto nccl_comm_ctx = this->GetCommContext(&store_key); + VLOG(3) << "Get nccl comm: " << nccl_comm_ctx->GetNcclComm() + << " for place_key: " << place_key << " on rank_in_group: " << rank + << " nranks: " << num_ranks << " gid: " << gid_; - auto* calc_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); auto comm_ctx = std::make_unique(place); - auto nccl_comm_ctx = this->GetCommContext(); comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm()); + auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + place_to_calc_event_.emplace( place_key, platform::DeviceEvent(place, platform::GenerateDeviceEventFlag())); place_to_calc_ctx_.emplace(place_key, calc_ctx); place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); - // TODO(sunyilun): for compatibility, will be removed later - std::vector comm_ctx_wrapper{ - place_to_comm_ctx_[place_key].get()}; - places_to_ctx_.emplace(place_key, comm_ctx_wrapper); + for (size_t i = 0; i < s_group_call_counter; ++i) { + NCCL_CHECK(phi::dynload::ncclGroupStart()); + } } -void ProcessGroupNCCL::SyncCalcStream(const Place& place) { - const std::string& key = GetKeyFromPlace(place); - auto& calc_event = place_to_calc_event_.at(key); - const auto* calc_ctx = place_to_calc_ctx_.at(key); - const auto* comm_ctx = place_to_comm_ctx_.at(key).get(); +void ProcessGroupNCCL::SyncCalcStream(const Place& place, + const std::string& place_key) { + auto& calc_event = place_to_calc_event_.at(place_key); + const auto* calc_ctx = place_to_calc_ctx_.at(place_key); + const auto* comm_ctx = place_to_comm_ctx_.at(place_key).get(); calc_event.Record(calc_ctx); calc_event.Wait(platform::Place2DeviceType(place), comm_ctx); } -std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( - std::function fn, +std::shared_ptr ProcessGroupNCCL::Collective( + std::function fn, const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream) { + comm_seq_++; const auto& place = tensor.place(); const auto& key = GetKeyFromPlace(place); platform::CUDADeviceGuard cuda_guard(place); + std::string store_key; + GetStoreKey(key, comm_type, &store_key); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLEnvCache(place, key); + CreateNCCLEnvCache(place, key, store_key, comm_type); } if (!use_calc_stream) { - SyncCalcStream(place); + SyncCalcStream(place, key); } auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); const auto* calc_ctx = place_to_calc_ctx_.at(key); const auto& comm_ctx = place_to_comm_ctx_.at(key); + auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - fn(nccl_stream); + + auto nccl_comm_ctx = this->GetCommContext(&store_key); + + if (!FLAGS_enable_async_trace) { + fn(nccl_comm_ctx, nccl_stream); + } else { + auto comm_task = + std::make_shared(place, + rank_, + size_, + gid_, + comm_seq_, + tensor.numel(), + sync_op, + use_calc_stream, + nccl_comm, + nccl_stream, + comm_type, + pg_timeout_); + comm_task->StartRecord(); + fn(nccl_comm_ctx, nccl_stream); + comm_task->EndRecord(); + comm_task->SetStore(store_); + + auto& comm_task_manager = phi::distributed::CommTaskManager::GetInstance(); + comm_task_manager.CommTaskEnqueue(std::move(comm_task)); + } if (!use_calc_stream) { if (FLAGS_use_stream_safe_cuda_allocator) { @@ -564,443 +781,145 @@ std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( task->SetBlockCPUInWait(); task->Wait(); } - return task; -} -// TODO(sunyilun): methods below will be removed later -void SyncDefaultStream(const std::vector& places, - platform::DeviceEvent& nccl_event, // NOLINT - std::vector& dev_ctx) { // NOLINT - for (size_t i = 0; i < places.size(); ++i) { - auto* default_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(places[i])); - nccl_event.Record(default_ctx); - nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]); - } -} - -std::shared_ptr ProcessGroupNCCL::CreateTask( - std::vector places, - int rank, - CommType comm_type, - const std::vector& inputs) { - return std::make_shared( - places, rank, comm_type, inputs); -} - -ProcessGroupNCCL::NCCLTask::NCCLTask( - const std::vector& places, - int rank, - CommType CommType, - const std::vector& inputs) - : TaskStream(rank, inputs, CommType), - comm_event_(places[0], platform::GenerateDeviceEventFlag()), - task_place_(places[0]) {} - -// create NCCLManager cache for places_key -void ProcessGroupNCCL::CreateNCCLManagerCache( - const std::string& places_key, const std::vector& places) { - PADDLE_ENFORCE_EQ(places_key.empty(), - false, - phi::errors::PreconditionNotMet( - "Not able to create/get the NCCL Communicator since " - "the GPU place are not known")); - - ncclUniqueId nccl_id; - if (rank_ == 0) { - NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id)); + if (sync_op) { + task->Wait(); } - BroadcastUniqueNCCLID(&nccl_id); - - VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << places_key - << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id); - std::vector> dev_ctx; - dev_ctx.resize(places.size()); - - std::vector dev_ctx_raw; - dev_ctx_raw.resize(places.size()); - - GroupStart(); - - for (size_t i = 0; i < places.size(); ++i) { - platform::CUDADeviceGuard guard(places[i]); - - dev_ctx[i] = std::make_unique(places[i]); - ncclComm_t nccl_comm; - NCCL_CHECK(phi::dynload::ncclCommInitRank( - &nccl_comm, GetSize(), nccl_id, GetRank())); - dev_ctx[i]->set_nccl_comm(nccl_comm); - dev_ctx_raw[i] = dev_ctx[i].get(); + if (FLAGS_benchmark || FLAGS_benchmark_nccl) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif } - GroupEnd(); - - // TODO(sunyilun): for compatibility, will be removed later - place_to_calc_event_.emplace( - places_key, - platform::DeviceEvent(places[0], platform::GenerateDeviceEventFlag())); - place_to_calc_ctx_.emplace( - places_key, - static_cast( - platform::DeviceContextPool::Instance().Get(places[0]))); - place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0])); - - // These caches will be useful to process sync/wait/communicate - places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw)); + return task; } -template -std::shared_ptr ProcessGroupNCCL::Collective( - std::vector& inputs, - std::vector& outputs, - Fn fn, - CommType op_type) { - const auto places = GetPlaceList(inputs); - const auto key = GetKeyFromPlaces(places); - - { - std::lock_guard lock(mutex_); - if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLManagerCache(key, places); - } - } - - SyncDefaultStream( - places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); - - auto task = CreateTask(places, rank_, op_type, inputs); +std::shared_ptr ProcessGroupNCCL::Point2Point( + std::function + fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + const auto& place = tensor.place(); - // construct uninitialize guard for device - platform::CUDADeviceGuard cuda_guard; + int p2p_rank = 0; + int p2p_target_rank = 0; + bool is_batch_p2p = s_group_call_counter > 0; + std::string key = ""; - { - platform::NCCLGroupGuard nccl_guard; - for (size_t i = 0; i < inputs.size(); ++i) { - cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream(); - fn(inputs[i], - outputs[i], - places_to_ctx_.at(key)[i]->nccl_comm(), - nccl_stream); - } + if (is_batch_p2p) { + key = GetKeyFromPlace(place); + p2p_rank = rank_; + p2p_target_rank = peer; + } else { + int low_rank = rank_ < peer ? rank_ : peer; + int high_rank = rank_ < peer ? peer : rank_; + key = std::to_string(low_rank) + "->" + std::to_string(high_rank); + p2p_rank = rank_ < peer ? 0 : 1; + p2p_target_rank = 1 - p2p_rank; } - if (FLAGS_use_stream_safe_cuda_allocator) { - for (size_t i = 0; i < inputs.size(); ++i) { - cuda_guard.SetDevice(places[i]); - memory::RecordStream(inputs[i].Holder(), - places_to_ctx_.at(key)[i]->stream()); - } - } + platform::CUDADeviceGuard cuda_guard(place); + + std::string store_key; + GetStoreKey(key, comm_type, &store_key); - for (size_t i = 0; i < inputs.size(); ++i) { - cuda_guard.SetDevice(places[i]); - task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateNCCLEnvCache(place, key, store_key, comm_type, p2p_rank); } - return task; -} -template -std::shared_ptr ProcessGroupNCCL::PointToPoint( - std::vector& tensors, - Fn fn, - int dst_rank, - CommType op_type) { - const auto places = GetPlaceList(tensors); - const auto key = GetKeyFromPlaces(places); - - { - std::lock_guard lock(mutex_); - if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLManagerCache(key, places); - } + if (!use_calc_stream) { + SyncCalcStream(place, key); } - SyncDefaultStream( - places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); + auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); - auto task = CreateTask(places, rank_, op_type, tensors); + auto nccl_comm = comm_ctx->nccl_comm(); + auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - // construct uninitialize guard for device - platform::CUDADeviceGuard cuda_guard; + auto comm_task = + std::make_shared(place, + rank_, + size_, + gid_, + comm_seq_, + tensor.numel(), + sync_op, + use_calc_stream, + nccl_comm, + nccl_stream, + comm_type); + + auto nccl_comm_ctx = this->GetCommContext(&store_key); + + if (!FLAGS_enable_async_trace) { + fn(nccl_comm_ctx, nccl_stream, p2p_target_rank); + } else { + comm_task->StartRecord(); + fn(nccl_comm_ctx, nccl_stream, p2p_target_rank); + comm_task->EndRecord(); + comm_task->SetStore(store_); - { - platform::NCCLGroupGuard nccl_guard; - for (size_t i = 0; i < tensors.size(); ++i) { - cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream(); - fn(tensors[i], - places_to_ctx_.at(key)[i]->nccl_comm(), - nccl_stream, - dst_rank); - } + auto& comm_task_manager = phi::distributed::CommTaskManager::GetInstance(); + comm_task_manager.CommTaskEnqueue(std::move(comm_task)); } - if (FLAGS_use_stream_safe_cuda_allocator) { - for (size_t i = 0; i < tensors.size(); ++i) { - cuda_guard.SetDevice(places[i]); - memory::RecordStream(tensors[i].Holder(), - places_to_ctx_.at(key)[i]->stream()); + if (!use_calc_stream) { + if (FLAGS_use_stream_safe_cuda_allocator) { + memory::RecordStream(tensor.Holder(), nccl_stream); } + task->UpdateWaitChain(*comm_ctx); } - for (size_t i = 0; i < tensors.size(); ++i) { - cuda_guard.SetDevice(places[i]); - task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); + if (FLAGS_enable_nccl_dynamic_check) { + task->SetBlockCPUInWait(); + task->Wait(); } - return task; -} - -std::shared_ptr ProcessGroupNCCL::AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - comm_context->AllReduce( - &output, input, ToNCCLRedType(opts.reduce_op), stream); - }, - CommType::ALLREDUCE); -} -std::shared_ptr ProcessGroupNCCL::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - const auto root = - opts.source_rank * in_tensors.size() + opts.source_root; - auto comm_context = this->GetCommContext(); - comm_context->Broadcast(&output, input, root, stream); - }, - CommType::BROADCAST); -} - -void CheckTensorsInDifferentDevices( - const std::vector& tensors, const size_t num_devices) { - PADDLE_ENFORCE_EQ( - tensors.empty(), - false, - phi::errors::InvalidArgument("Tensor list must be nonempty.")); - PADDLE_ENFORCE_LE( - tensors.size(), - num_devices, - phi::errors::InvalidArgument( - "Tensor list mustn't be larger than the number of available GPUs.")); - - std::set used_devices; - - for (const auto& t : tensors) { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(t.place()), - true, - phi::errors::InvalidArgument("Tensors must be CUDA and dense tensor.")); - - const auto inserted = used_devices.insert(t.place()).second; - PADDLE_ENFORCE_EQ(inserted, - true, - phi::errors::InvalidArgument( - "Tensors must be on distinct GPU devices.")); + if (sync_op) { + task->Wait(); } -} -std::shared_ptr ProcessGroupNCCL::Send( - std::vector& tensors, int dst_rank) { - CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); - - auto task = PointToPoint( - tensors, - [&](phi::DenseTensor& input, - ncclComm_t comm, - const gpuStream_t& stream, - int dst_rank) { - auto comm_context = this->GetCommContext(); - comm_context->Send(input, input.numel(), dst_rank, stream); - }, - dst_rank, - CommType::SEND); - return task; -} + if (!is_batch_p2p && (FLAGS_benchmark || FLAGS_benchmark_nccl)) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + } -std::shared_ptr ProcessGroupNCCL::Recv( - std::vector& tensors, int src_rank) { - CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); - - auto task = PointToPoint( - tensors, - [&](phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream, - int src_rank) { - auto comm_context = this->GetCommContext(); - comm_context->Recv(&output, output.numel(), src_rank, stream); - }, - src_rank, - CommType::RECV); return task; } -std::shared_ptr ProcessGroupNCCL::AllGather( - std::vector& in_tensors, - std::vector& out_tensors) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - phi::errors::InvalidArgument("All outputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - comm_context->AllGather(&output, input, stream); - }, - CommType::ALLGATHER); -} - -std::shared_ptr ProcessGroupNCCL::AllToAll( - std::vector& in_tensors, - std::vector& out_tensors) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - size_t offset = 0; - size_t count = input.numel() / size_; - auto comm_context = this->GetCommContext(); - comm_context->GroupStart(); - for (auto i = 0; i < size_; i++) { - auto input_data = GetPartialTensor(input, offset, count); - comm_context->Send(input_data, count, i, stream); - auto output_data = GetPartialTensor(output, offset, count); - comm_context->Recv(&output_data, count, i, stream); - offset += count; - } - comm_context->GroupEnd(); - }, - CommType::ALLTOALL); -} - -std::shared_ptr ProcessGroupNCCL::Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - comm_context->Reduce(&output, - input, - ToNCCLRedType(opts.reduce_op), - opts.root_rank, - stream); - }, - CommType::REDUCE); -} - -std::shared_ptr ProcessGroupNCCL::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - size_t offset = 0; - size_t count = input.numel() / size_; - if (rank_ == opts.root_rank) { - comm_context->GroupStart(); - for (auto i = 0; i < size_; i++) { - auto input_data = reinterpret_cast( - GetPointerByOffset(input.data(), offset, input.dtype())); - comm_context->Send(*input_data, count, i, stream); - offset += count; - } - comm_context->Recv(&output, count, opts.root_rank, stream); - comm_context->GroupEnd(); - } else { - comm_context->Recv(&output, count, opts.root_rank, stream); - } - }, - CommType::SCATTER); -} - std::shared_ptr ProcessGroupNCCL::CreateProcessGroupNCCL( const std::shared_ptr& store, int rank, int size, - int gid) { + int gid, + int64_t timeout) { auto process_group = - std::make_shared(store, rank, size, gid); + std::make_shared(store, rank, size, gid, timeout); ProcessGroupIdMap::GetInstance().emplace(gid, process_group); return process_group; } -phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() { +phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext( + const std::string* key) { + std::string store_key = std::to_string(this->gid_); + if (key && !key->empty()) { + store_key = *key; + } const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); auto comm_context = static_cast( - comm_context_manager.Get(std::to_string(this->gid_))); + comm_context_manager.Get(store_key)); PADDLE_ENFORCE_NE(comm_context, nullptr, phi::errors::Unavailable("NCCLCommContext is nullptr")); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index b4f90dea77761..96c907e622b17 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -71,12 +71,14 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { const std::shared_ptr& store, int rank, int size, - int gid); + int gid, + int64_t timeout); ProcessGroupNCCL(const std::shared_ptr& store, int rank, int size, - int gid); + int gid, + int64_t timeout = 30 * 60 * 1000); std::string GetBackendName() const override { return "NCCL"; } @@ -170,42 +172,6 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ncclComm_t NCCLComm(const Place& place) const; - // TODO(liyurui): This API will be moved later - std::shared_ptr AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& = AllreduceOptions()) override; - - // TODO(sunyilun): methods below will be removed later - std::shared_ptr Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& = BroadcastOptions()) override; - - std::shared_ptr Send( - std::vector& tensors, int dst_rank) override; - - std::shared_ptr Recv( - std::vector& tensors, int src_rank) override; - - std::shared_ptr AllGather( - std::vector& in_tensors, - std::vector& out_tensors) override; - - std::shared_ptr AllToAll( - std::vector& in_tensors, - std::vector& out_tensors) override; - - std::shared_ptr Reduce( - std::vector& tensors, - std::vector& out_tensors, - const ReduceOptions& opts) override; - - std::shared_ptr Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts) override; - private: std::shared_ptr CreateTask(const Place& place, int rank, @@ -213,44 +179,36 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream); - void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id); + void GetStoreKey(const std::string& place_key, + CommType comm_type, + std::string* store_key); - void CreateNCCLEnvCache(const Place& place, const std::string& place_key); + void CreateNCCLEnvCache(const Place& place, + const std::string& place_key, + const std::string& store_key, + CommType comm_type, + int p2p_rank = 0); - void SyncCalcStream(const Place& place); + void SyncCalcStream(const Place& place, const std::string& place_key); - std::shared_ptr RunFnInNCCLEnv( - std::function fn, + std::shared_ptr Collective( + std::function fn, const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream); - // TODO(sunyilun): methods below will be removed later - std::shared_ptr CreateTask( - std::vector places, - int rank, - CommType op_type, - const std::vector& inputs); - - template - std::shared_ptr Collective( - std::vector& inputs, // NOLINT - std::vector& outputs, // NOLINT - Fn fn, - CommType op_type); - - template - std::shared_ptr PointToPoint( - std::vector& tensors, // NOLINT - Fn fn, - int dst_rank, - CommType op_type); - - void CreateNCCLManagerCache(const std::string& places_key, - const std::vector& places); + std::shared_ptr Point2Point( + std::function + fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream); - phi::distributed::NCCLCommContext* GetCommContext(); + phi::distributed::NCCLCommContext* GetCommContext( + const std::string* key = nullptr); private: std::shared_ptr store_; @@ -261,9 +219,13 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { std::unordered_map> place_to_comm_ctx_; + uint64_t comm_seq_{0}; + // TODO(sunyilun): attrs below will be removed later std::mutex mutex_; - std::unordered_map> places_to_ctx_; + static uint64_t s_group_call_counter; + // default 30 minutes + int64_t pg_timeout_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/collective/utils.h b/paddle/fluid/distributed/collective/utils.h deleted file mode 100644 index 90149f88bbc4f..0000000000000 --- a/paddle/fluid/distributed/collective/utils.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace paddle { -namespace distributed { - -inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor, - int64_t offset, - int64_t numel) { - phi::DenseTensor tensor_flattened; - tensor_flattened.ShareDataWith(tensor); - tensor_flattened.Resize({tensor.numel()}); - return tensor_flattened.Slice(offset, offset + numel); -} - -inline void* GetPointerByOffset(void* raw_pointer, - size_t offset, - phi::DataType type) { - if (type == phi::DataType::FLOAT32) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::FLOAT64) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::FLOAT16) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT32) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT64) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT8) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::UINT8) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::BOOL) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::BFLOAT16) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Datatype %s in NCCL is not supported.", type)); - } - return nullptr; -} - -inline void CheckSizeOnEachRank(const phi::DDim& tensor_dim, - const std::vector& size_on_each_rank, - int world_size) { - int length_size_on_each_rank = size_on_each_rank.size(); - PADDLE_ENFORCE_EQ( - length_size_on_each_rank, - world_size, - phi::errors::InvalidArgument( - "The length of size_on_each_rank must be equal to world_size.")); - - int64_t sum_size_on_each_rank = - std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0); - PADDLE_ENFORCE_EQ( - sum_size_on_each_rank, - tensor_dim[0], - phi::errors::InvalidArgument( - "The sum of size_on_each_rank must be equal to tensor's dim[0].")); -} -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index 0dd44c2318eec..aaae976133025 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -1,157 +1,115 @@ set_source_files_properties( table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( table_test - SRCS - table_test.cc - DEPS - common_table - table - ps_framework_proto - ${COMMON_DEPS} - ${RPC_DEPS}) + SRCS table_test.cc + DEPS common_table table ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS}) set_source_files_properties( dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( dense_table_test - SRCS - dense_table_test.cc - DEPS - common_table - table - ps_framework_proto - ${COMMON_DEPS} - ${RPC_DEPS}) + SRCS dense_table_test.cc + DEPS common_table table ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS}) set_source_files_properties( barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( barrier_table_test - SRCS - barrier_table_test.cc - DEPS - common_table - table - ps_framework_proto - ${COMMON_DEPS}) + SRCS barrier_table_test.cc + DEPS common_table table ps_framework_proto ${COMMON_DEPS}) set_source_files_properties( brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( brpc_service_dense_sgd_test - SRCS - brpc_service_dense_sgd_test.cc - DEPS - scope - ps_service - table - ps_framework_proto - ${COMMON_DEPS}) + SRCS brpc_service_dense_sgd_test.cc + DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) set_source_files_properties( brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( brpc_service_sparse_sgd_test - SRCS - brpc_service_sparse_sgd_test.cc - DEPS - scope - ps_service - table - ps_framework_proto - ${COMMON_DEPS}) + SRCS brpc_service_sparse_sgd_test.cc + DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) set_source_files_properties( brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( brpc_utils_test - SRCS - brpc_utils_test.cc - DEPS - brpc_utils - scope - phi - sendrecv_rpc - ps_service - ${COMMON_DEPS} - ${RPC_DEPS}) + SRCS brpc_utils_test.cc + DEPS brpc_utils + scope + phi + sendrecv_rpc + ps_service + ${COMMON_DEPS} + ${RPC_DEPS}) set_source_files_properties( graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( graph_node_test - SRCS - graph_node_test.cc - DEPS - scope - ps_service - table - ps_framework_proto - ${COMMON_DEPS}) + SRCS graph_node_test.cc + DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) set_source_files_properties( graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( graph_node_split_test - SRCS - graph_node_split_test.cc - DEPS - scope - ps_service - table - ps_framework_proto - ${COMMON_DEPS}) + SRCS graph_node_split_test.cc + DEPS scope ps_service table ps_framework_proto ${COMMON_DEPS}) set_source_files_properties( graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( graph_table_sample_test - SRCS - graph_table_sample_test.cc - DEPS - table - ps_framework_proto - ${COMMON_DEPS}) + SRCS graph_table_sample_test.cc + DEPS table ps_framework_proto ${COMMON_DEPS}) set_source_files_properties( feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old( +cc_test( feature_value_test - SRCS - feature_value_test.cc - DEPS - table - common_table - sendrecv_rpc - ${COMMON_DEPS}) + SRCS feature_value_test.cc + DEPS table common_table sendrecv_rpc ${COMMON_DEPS}) set_source_files_properties( sparse_sgd_rule_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS - ${COMMON_DEPS} table) +cc_test( + sparse_sgd_rule_test + SRCS sparse_sgd_rule_test.cc + DEPS ${COMMON_DEPS} table) set_source_files_properties( ctr_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} - table) +cc_test( + ctr_accessor_test + SRCS ctr_accessor_test.cc + DEPS ${COMMON_DEPS} table) set_source_files_properties( ctr_dymf_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old(ctr_dymf_accessor_test SRCS ctr_dymf_accessor_test.cc DEPS - ${COMMON_DEPS} table) +cc_test( + ctr_dymf_accessor_test + SRCS ctr_dymf_accessor_test.cc + DEPS ${COMMON_DEPS} table) set_source_files_properties( memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS - ${COMMON_DEPS} table) +cc_test( + memory_sparse_table_test + SRCS memory_sparse_table_test.cc + DEPS ${COMMON_DEPS} table) set_source_files_properties( memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test_old(memory_sparse_geo_table_test SRCS memory_geo_table_test.cc DEPS - ${COMMON_DEPS} table) +cc_test( + memory_sparse_geo_table_test + SRCS memory_geo_table_test.cc + DEPS ${COMMON_DEPS} table) diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index 96210c16dd9ef..f948e050387bc 100755 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -75,3 +75,8 @@ cc_library( generated_op autograd_meta hook_utils) + +# FIXME(Aurelius84): It seems utils library is depended in cycle, but +# CMake only find it twice to deal cycle depend problem. If it is still +# not found, ld error will be raised. +set_target_properties(utils PROPERTIES LINK_INTERFACE_MULTIPLICITY 3) diff --git a/paddle/fluid/eager/accumulation/accumulation_node.cc b/paddle/fluid/eager/accumulation/accumulation_node.cc index c2c09444aab2f..c15739385dd43 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.cc +++ b/paddle/fluid/eager/accumulation/accumulation_node.cc @@ -113,6 +113,24 @@ static void CopyOrAddTensor(paddle::Tensor* tensor, &tensor_values); } } + } else if (LIKELY(t.is_dist_tensor())) { + PADDLE_ENFORCE( + tensor->is_dist_tensor(), + paddle::platform::errors::Fatal("A DistTensor can only do gradient " + "merge with another DistTensor.")); + PADDLE_ENFORCE(!t.is_custom_device(), + paddle::platform::errors::Fatal( + "DistTensor doesn't support custom device.")); + auto t_dist = + std::dynamic_pointer_cast(t.impl()); + paddle::Tensor t_values( + std::make_shared(t_dist->value())); + auto tensor_dist = + std::dynamic_pointer_cast( + tensor->impl()); + paddle::Tensor tensor_values( + std::make_shared(tensor_dist->value())); + paddle::imperative::TensorAdd(t_values, &tensor_values); } else { // TODO(jiabin): Support Other TensorBase later // TODO(zhanlve): Replace SelectedRowsAddTensor with diff --git a/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h b/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h index 8302af3169ee0..5ff677b143d60 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h +++ b/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" paddle::Tensor add_n_ad_func(const std::vector& x); @@ -49,6 +50,11 @@ sync_batch_norm__ad_func(const paddle::Tensor& x, std::string data_layout, bool use_global_stats, bool trainable_statistics); + +paddle::Tensor reshard_ad_function( + const paddle::Tensor& tensor, + const phi::distributed::TensorDistAttr dist_attr); + namespace sparse { std::tuple grad_node; + + // Set grad_node before API Call + if (require_any_grad) { + paddle::platform::RecordEvent node_creation_record_event( + "reshard node_creation", + paddle::platform::TracerEventType::Communication, + 1); + + // Node Construction + grad_node = + std::shared_ptr(new ReshardGradNode(1, 1)); // NOLINT + + // Set TensorWrappers for Forward Inputs if needed + grad_node->SetTensorWrapperNoNeedBufferInput(input); + } + + // Forward API Call + // reshard_func(input, api_result, dist_attr); + auto dist_out_ptr = paddle::reshard(input, dist_attr); + auto api_result = paddle::Tensor(dist_out_ptr); + + // Get Outputs + auto& out = api_result; + + // Get Output AutoGradMeta + egr::AutogradMeta* out_autograd_meta = egr::EagerUtils::autograd_meta(&out); + + // Set grad_node after API call + if (require_any_grad) { + egr::EagerUtils::PassStopGradient(false, out_autograd_meta); + + // SetGradOutMeta & SetEdges + grad_node->SetGradOutMeta(input, 0); + // SetOutRank & SetHistory & SetGradInMeta + if (out_autograd_meta) { + egr::EagerUtils::SetOutRankWithSlot(out_autograd_meta, 0); + egr::EagerUtils::SetHistory(out_autograd_meta, grad_node); + } + grad_node->SetGradInMeta(out, 0); + } + + return out; +#else + PADDLE_THROW(phi::errors::Unavailable( + "Reshard is not supported in this version of Paddle. Try to recompile it " + "with WITH_DISTRIBTUE=ON and reinstall this package.")); + return paddle::Tensor(); +#endif +} diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt b/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt index efdcaa70131e6..7072c5568ab06 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt @@ -3,4 +3,5 @@ set(eager_manual_nodes ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc + ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc index ed83bb29714ff..2e4489fdcc12e 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc @@ -74,6 +74,12 @@ MultiplyGradNode::operator()( // Runtime check if we need next grad bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph; + // Set DistAttr of Out Tensor for semi-auto parallel + if (IsRunAutoParallel()) { + egr::EagerUtils::SetGradOutputDistAttr( + out_metas, {0, 1}, api_output_0, api_output_1); + } + // Inplace Check // Inplace Strategy diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h index 8f63f4fdfeb61..bc6d1d9f1a1b6 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h @@ -396,6 +396,53 @@ class SyncBatchNormGradNode : public egr::GradNodeBase { bool trainable_statistics_; }; +class ReshardGradNode : public egr::GradNodeBase { + public: + ReshardGradNode() : egr::GradNodeBase() { + VLOG(3) << " Construct ReshardGrad Node."; + } + + ReshardGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { + VLOG(3) << " Construct ReshardGrad Node, bwd_in_slot_num: " + << bwd_in_slot_num << ", bwd_out_slot_num: " << bwd_out_slot_num; + } + + ~ReshardGradNode() override { VLOG(3) << " Destruct ReshardGrad Node."; } + + virtual paddle::small_vector, + egr::kSlotSmallVectorSize> + operator()(paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, // NOLINT + bool create_graph = false, + bool is_new_grad = false) override; + + void ClearTensorWrappers() override { + input_.clear(); + SetIsTensorWrappersCleared(true); + } + + std::string name() override { return "ReshardGradNode"; } + + std::shared_ptr Copy() const override { + { + auto copied_node = + std::shared_ptr(new ReshardGradNode(*this)); + return copied_node; + } + } + + // SetTensorWrapperX + // Only input's meta is needed. + void SetTensorWrapperNoNeedBufferInput(const paddle::Tensor& input) { + input_ = egr::TensorWrapper(input, true); + } + + private: + // TensorWrappers + egr::TensorWrapper input_; +}; + namespace sparse { class SyncBatchNormGradNode : public egr::GradNodeBase { public: diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc new file mode 100644 index 0000000000000..2df60f6097704 --- /dev/null +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/imperative/tracer.h" + +paddle::small_vector, + egr::kSlotSmallVectorSize> // NOLINT +ReshardGradNode::operator()( + paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, + bool create_graph, + bool is_new_grad) { +#ifdef PADDLE_WITH_DISTRIBUTE + VLOG(3) << "Running AD API GRAD: " + << "reshard_grad"; + + // Apply Gradient Hooks + auto hooked_grad = ApplyGradientHooks(grads); + + // Collect GradIn Tensors, Attrs and Recovered TensorWrappers + auto input = egr::EagerUtils::RecoverTensorWrapper(&this->input_); + const auto& dist_attr = + std::static_pointer_cast(input.impl()) + ->dist_attr(); + auto& grad_out = hooked_grad[0][0]; + // Prepare Grad function call + + const auto& out_metas = OutputMeta(); + paddle::small_vector, egr::kSlotSmallVectorSize> + returns(1); + + out_metas[0].size() == 0 ? returns[0].resize(1) + : returns[0].resize(out_metas[0].size()); + + auto& grad_input = returns[0][0]; + + VLOG(5) << "Running C++ API: " + << "reshard_func"; + + if (VLOG_IS_ON(3)) { + const char* INPUT_PRINT_TEMPLATE = "{ Input: [%s]} "; + + std::string input_str = ""; + const char* TENSOR_OUT_GRAD_TEMPLATE = " \n( out_grad , [%s]), "; + std::string input_out_grad_str = paddle::string::Sprintf( + TENSOR_OUT_GRAD_TEMPLATE, egr::EagerUtils::TensorStr(grad_out)); + input_str += input_out_grad_str; + const char* TENSOR_X_TEMPLATE = " \n( x , [%s]), "; + std::string input_x_str = paddle::string::Sprintf( + TENSOR_X_TEMPLATE, egr::EagerUtils::TensorStr(input)); + input_str += input_x_str; + VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str); + } + + // Backward call reshard_func function + auto dist_out_ptr = paddle::reshard(grad_out, dist_attr); + grad_input.set_impl(dist_out_ptr); + + VLOG(5) << "Finish C++ API: reshard_func"; + VLOG(6) << "gradnode_ptr = " << this; + + if (VLOG_IS_ON(4)) { + const char* INPUT_PRINT_TEMPLATE = "{ Input: [%s], \n Output: [%s] } "; + std::string input_str = ""; + std::string output_str = ""; + const char* TENSOR_OUT_GRAD_TEMPLATE = " \n( out_grad , [%s]), "; + std::string input_out_grad_str = paddle::string::Sprintf( + TENSOR_OUT_GRAD_TEMPLATE, egr::EagerUtils::TensorStr(grad_out)); + input_str += input_out_grad_str; + const char* TENSOR_X_TEMPLATE = " \n( x , [%s]), "; + std::string input_x_str = paddle::string::Sprintf( + TENSOR_X_TEMPLATE, egr::EagerUtils::TensorStr(input)); + input_str += input_x_str; + const char* TENSOR_X_GRAD_TEMPLATE = " \n ( input_grad , [%s]), "; + std::string output_x_grad_str = paddle::string::Sprintf( + TENSOR_X_GRAD_TEMPLATE, egr::EagerUtils::TensorStr(grad_input)); + output_str += output_x_grad_str; + VLOG(4) << paddle::string::Sprintf( + INPUT_PRINT_TEMPLATE, input_str, output_str); + } + + return returns; +#else + PADDLE_THROW(phi::errors::Unavailable( + "ReshardGrad is not supported in this version of Paddle. Try to " + "recompile it with WITH_DISTRIBTUE=ON and reinstall this package.")); + return paddle::small_vector, + egr::kSlotSmallVectorSize>(1); +#endif +} diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 5fb5c99f1c09f..c2613dffa201d 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -1877,6 +1877,18 @@ static std::pair GenerateForwardFunctionContents( trace_op_body_str += trace_op_str; trace_op_body_str += "\n"; + // [Generation] Log memory infomation + const char* LOG_MEMORY_INFO_TEMPLATE = + " // Log memory information\n" + " " + "paddle::memory::LogDeviceMemoryStats(egr::Controller::Instance()." + "GetExpectedPlace(), \"%s\");\n"; + std::string log_memory_info_str = + paddle::string::Sprintf(LOG_MEMORY_INFO_TEMPLATE, op_type); + + trace_op_body_str += log_memory_info_str; + trace_op_body_str += "\n"; + VLOG(6) << "Generated AttrMap & TraceOp"; // [Generation] Convert output VarBase to Vector/Tensor @@ -2968,6 +2980,7 @@ static std::string GenerateDygraphHFileIncludes() { "#pragma once\n" "#include \"glog/logging.h\"\n" "#include \"paddle/fluid/eager/autograd_meta.h\"\n" + "#include \"paddle/fluid/memory/stats.h\"\n" "#include \"paddle/phi/api/all.h\"\n" "#include \"paddle/fluid/eager/utils.h\"\n" "#include \"paddle/fluid/imperative/tracer.h\"\n" diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 073c5588b1eb8..ad6bef79a9f90 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -270,6 +270,8 @@ class {} : public egr::GradNodeBase {{ {} // Forward API Call +{} + // Log memory infomation {} // Check NaN and Inf if needed {} @@ -320,6 +322,8 @@ class {} : public egr::GradNodeBase {{ // Before log info {} // Forward API Call +{} + // Log memory infomation {} // Check NaN and Inf if needed {} @@ -412,6 +416,7 @@ class {} : public egr::GradNodeBase {{ #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/utils/utils.h" #include "paddle/phi/core/flags.h" +#include "paddle/fluid/memory/stats.h" #include "paddle/phi/api/lib/data_transform.h" PHI_DECLARE_bool(check_nan_inf); {} @@ -575,6 +580,18 @@ class {} : public egr::GradNodeBase {{ }} """ +FILL_ZERO_GRAD_TEMPLATE_BACKWARD = """ + egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]); +""" + +FILL_ZERO_PLAIN_GRAD_TEMPLATE_BACKWARD = """ + egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]); +""" + +FILL_ZERO_OPTIONAL_PLAIN_GRAD_TEMPLATE_BACKWARD = """ + egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]); +""" + inplace_optional_out_type_map = { "Tensor": "paddle::optional&", "std::vector": "paddle::optional>&", @@ -1054,7 +1071,11 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): or IsVectorTensorType(atype) or (name in self.optional_inputs) ): - set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});" + if for_backward is False: + set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});" + else: + set_tensor_wrappers = f"{indent}if({name}_optional) grad_node->SetTensorWrapper{name}(*{name}_optional);" + else: need_pre_contiguous_set.add(name) set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name}_tmp);" @@ -1133,7 +1154,10 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): ) if is_optional: - set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + if for_backward is False: + set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + else: + set_grad_out_meta = f"{indent}if({name}_optional.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}_optional.get_ptr()), {pos});" else: if ( is_special_forward_api @@ -1742,6 +1766,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str_tmp});" dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n" + log_memory_info_str = f"{indent}paddle::memory::LogDeviceMemoryStats(egr::Controller::Instance().GetExpectedPlace(), \"{forward_api_name}\");" forward_ad_function_name = GetDygraphForwardFunctionName( forward_api_name ) @@ -1828,6 +1853,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): forward_api_name, before_log_str, forward_call_str, + log_memory_info_str, check_nan_inf_str, get_outputs_str, forward_api_name, @@ -1854,6 +1880,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): node_creation_pre_contiguous_str, node_creation_before_call_str, forward_call_str, + log_memory_info_str, check_nan_inf_str, get_outputs_str, outputs_autograd_meta_str, @@ -2209,12 +2236,22 @@ def GenerateNodeDefinition( ) in backward_grad_inputs_map.items(): if name in self.optional_inputs: if IsPlainTensorType(ttype): - fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n" + fill_zero_str += FILL_ZERO_OPTIONAL_PLAIN_GRAD_TEMPLATE_BACKWARD.format( + fwd_position=fwd_position + ) else: if IsPlainTensorType(ttype): - fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n" + fill_zero_str += ( + FILL_ZERO_PLAIN_GRAD_TEMPLATE_BACKWARD.format( + fwd_position=fwd_position + ) + ) else: - fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n" + fill_zero_str += ( + FILL_ZERO_GRAD_TEMPLATE_BACKWARD.format( + fwd_position=fwd_position + ) + ) inplace_grad_input_str = "" inplace_check_str = "" diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 60e02a29d72b4..8aa2f64ccb2ec 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/eager/backward.h" #include "paddle/fluid/eager/general_grad.h" +#include "paddle/fluid/memory/stats.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" namespace egr { @@ -111,6 +112,8 @@ std::vector RunBackward( const std::vector& no_grad_vars = {}) { VLOG(3) << "Start Backward"; + auto place = egr::Controller::Instance().GetExpectedPlace(); + std::queue force_sequential_nodes_forward_queue = egr::Controller::Instance().GetForceSequentialNodes(); std::deque force_sequential_nodes_queue; @@ -405,6 +408,7 @@ std::vector RunBackward( } } } + paddle::memory::LogDeviceMemoryStats(place, std::string((*node).name())); } VLOG(7) << "Run Backward Final hook size: " diff --git a/paddle/fluid/eager/custom_operator/CMakeLists.txt b/paddle/fluid/eager/custom_operator/CMakeLists.txt index a2648d3e32556..a74ba2dc8c628 100644 --- a/paddle/fluid/eager/custom_operator/CMakeLists.txt +++ b/paddle/fluid/eager/custom_operator/CMakeLists.txt @@ -1,4 +1,9 @@ cc_library( custom_operator_node SRCS custom_operator_node.cc + DEPS phi grad_node_info custom_operator utils custom_operator_utils) + +cc_library( + custom_operator_utils + SRCS custom_operator_utils.cc DEPS phi grad_node_info custom_operator utils) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 5643c0e69391f..9b6318c7a43ed 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/eager/custom_operator/custom_operator_node.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" @@ -172,8 +173,6 @@ RunCustomOpNode::operator()(paddle::small_vector, paddle::OpMetaInfoHelper::GetInputs(vec_map[1]); const auto& grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[1]); - const auto& grad_inplace_map = - paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[1]); const auto& map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); @@ -251,11 +250,12 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad"; - // handle inplace map - ctx.UpdatePlainOutputs( - grad_inputs_name, grad_outputs_names, grad_inplace_map); - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[1]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_impl(vec_map[1], false, false, ctx); + + for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { + auto output_pair = ctx.OutputRangeAt(i); + outs[i] = ctx.OutputsBetween(output_pair.first, output_pair.second); + } // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { @@ -264,7 +264,9 @@ RunCustomOpNode::operator()(paddle::small_vector, ctx.MutableOutputAt(ctx.OutputRangeAt(i).first); if (!out_tensor->initialized()) { PADDLE_ENFORCE( - paddle::framework::detail::IsOptionalVar(grad_outputs_names.at(i)), + paddle::framework::detail::IsOptionalVar( + grad_outputs_names.at(i)) || + out_tensor->is_dist_tensor(), phi::errors::InvalidArgument( "Custom grad operator's %d-th output is not initialized. " "Please check your implementation again. If you are " @@ -386,8 +388,6 @@ RunCustomOpDoubleGradNode::operator()( paddle::OpMetaInfoHelper::GetInputs(vec_map[2]); const auto& grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[2]); - const auto& grad_inplace_map = - paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[2]); const auto& map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); @@ -451,11 +451,12 @@ RunCustomOpDoubleGradNode::operator()( } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad_grad"; - // handle inplace map - ctx.UpdatePlainOutputs( - grad_inputs_name, grad_outputs_names, grad_inplace_map); - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[2]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_impl(vec_map[2], false, true, ctx); + + for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { + auto output_pair = ctx.OutputRangeAt(i); + outs[i] = ctx.OutputsBetween(output_pair.first, output_pair.second); + } return outs; } diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc new file mode 100644 index 0000000000000..104087b55c6a7 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -0,0 +1,709 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/custom_operator_utils.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/api/lib/data_transform.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/api_gen_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + +namespace egr { + +using Tensor = paddle::Tensor; + +static std::vector> RunDefaultInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + + VLOG(3) << "Custom Operator: Default InferShape - share ddim."; + result.push_back({ctx.InputAt(0).dims()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. Please check `SetInplaceMap` again or set " + "the InferShapeFn of custom operator by " + "`.SetInferShapeFn(PD_INFER_SHAPE(...)`)", + outputs.size(), + inplace_map.size())); + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(inputs[i])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(i); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(i); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } + } + } + return result; +} + +static std::vector> RunDefaultGradInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& grad_op_inputs, + const std::vector& grad_op_outputs, + bool is_double_grad) { + std::vector> result; + // 1. if forward input exists, gradient's shape is same with forward + // input + // default + // [Suitable for most situations] + // 2. if forward input not exists, and only contains one grad input and + // output, + // use grad input shape as grad output shape + // [Suitable for the situation that forward input is not used as + // backward input] + for (auto& out_name : grad_op_outputs) { + auto fwd_name = paddle::framework::detail::NoGrad(out_name, is_double_grad); + if (paddle::framework::detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the forward " + "input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + std::vector tmp; + for (size_t i = pair.first; i < pair.second; ++i) { + tmp.emplace_back(ctx.InputAt(i).dims()); + } + result.emplace_back(std::move(tmp)); + } else { + if (grad_op_inputs.size() == grad_op_outputs.size()) { + result.push_back({ctx.InputAt(0).dims()}); + } else { + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the " + "forward input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + result.push_back({ctx.InputAt(pair.first).dims()}); + } + } + } + return result; +} + +static std::vector> RunInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferShapeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector> input_shapes; + std::vector>> vec_input_shapes; + + VLOG(3) << "Custom Operator: InferShape - get input ddim."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second - 1) { + input_shapes.emplace_back( + std::move(ctx.InputAt(input_pair.first).shape())); + } else { + std::vector> shapes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + shapes.push_back(std::move(ctx.InputAt(j).shape())); + } + vec_input_shapes.emplace_back(std::move(shapes)); + } + } + + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + auto output_shapes = func(input_shapes, vec_input_shapes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_shapes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferShapeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferShapeFn = %d. Please " + "check InferShapeFn again.", + outputs.size(), + output_shapes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_shapes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferShapeFn output size` = %d. Please check " + "InplaceMap and InferShapeFn again", + outputs.size(), + output_shapes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferShape - set output ddim: inplace_map.size() = " + << inplace_map.size() + << ", output_shapes.size() = " << output_shapes.size(); + size_t output_shape_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } else { + result.push_back({phi::make_ddim(output_shapes[output_shape_idx++])}); + } + } + } + return result; +} + +static std::vector> RunDefaultInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + + VLOG(3) << "Custom Operator: InferDtype - share dtype."; + result.push_back({ctx.InputAt(0).dtype()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. Please check `SetInplaceMap` again or set " + "the InferDtypeFn of custom operator by " + "`.SetInferDtypeFn(PD_INFER_DTYPE(...))`", + outputs.size(), + inplace_map.size())); + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(inputs[i])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(i); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(i); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } + } + } + return result; +} + +static std::vector> RunDefaultGradInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& grad_op_inputs, + const std::vector& grad_op_outputs, + bool is_double_grad) { + std::vector> result; + for (auto& out_name : grad_op_outputs) { + auto fwd_name = paddle::framework::detail::NoGrad(out_name, is_double_grad); + if (paddle::framework::detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the forward " + "input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + std::vector tmp; + for (size_t i = pair.first; i < pair.second; ++i) { + tmp.emplace_back(ctx.InputAt(i).dtype()); + } + result.emplace_back(std::move(tmp)); + } else { + if (grad_op_inputs.size() == grad_op_outputs.size()) { + result.push_back({ctx.InputAt(0).dtype()}); + } else { + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the " + "forward input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + result.push_back({ctx.InputAt(pair.first).dtype()}); + } + } + } + return result; +} + +static std::vector> RunInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferDtypeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector input_dtypes; + std::vector> vec_input_dtypes; + + VLOG(3) << "Custom Operator: InferDtype - get input dtype."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second - 1) { + input_dtypes.emplace_back( + std::move(ctx.InputAt(input_pair.first).dtype())); + } else { + std::vector dtypes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + dtypes.emplace_back(ctx.InputAt(j).dtype()); + } + vec_input_dtypes.emplace_back(std::move(dtypes)); + } + } + + VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; + auto output_dtypes = func(input_dtypes, vec_input_dtypes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_dtypes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferDtypeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferDtypeFn = %d. Please " + "check InferDtypeFn again.", + outputs.size(), + output_dtypes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_dtypes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferDtypeFn output size` = %d. Please check " + "InplaceMap and InferDtypeFn again", + outputs.size(), + output_dtypes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferDtype - set output dtype: inplace_map.size() = " + << inplace_map.size() + << ", output_dtypes.size() = " << output_dtypes.size(); + size_t output_dtype_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector dtypes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + dtypes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(dtypes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } else { + result.push_back({output_dtypes[output_dtype_idx++]}); + } + } + } + return result; +} + +#ifdef PADDLE_WITH_DISTRIBUTE +paddle::Tensor BuildEmptyDistPaddleTensor( + const phi::distributed::ProcessMesh& process_mesh, + const phi::DDim& dims, + phi::DataType dtype) { + paddle::Tensor empty_tensor; + phi::DenseTensorMeta meta; + meta.dims = dims; + meta.dtype = dtype; + + auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(dims)); + dist_attr.set_process_mesh(process_mesh); + + auto dist_t = std::make_shared( + std::make_shared( + std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + meta), + dist_attr); + empty_tensor.set_impl(dist_t); + empty_tensor.set_autograd_meta(std::make_shared()); + return empty_tensor; +} +#endif + +#ifdef PADDLE_WITH_DISTRIBUTE +std::tuple PrepareCtxForAutoParallel( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx) { // NOLINT + bool run_auto_parallel = false; + bool rank_is_in_current_mesh = true; + phi::distributed::ProcessMesh current_process_mesh; + + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + + std::vector* all_inputs = ctx.AllMutableInput(); + std::vector x = *all_inputs; + const phi::distributed::ProcessMesh* mesh = nullptr; + for (auto& input : x) { + if (input.is_dist_tensor()) { + mesh = &( + std::dynamic_pointer_cast(input.impl()) + ->dist_attr() + .process_mesh()); + break; + } + } + + if (mesh) { + for (auto& input : x) { + if (input.is_dist_tensor()) { + PADDLE_ENFORCE_EQ( + std::dynamic_pointer_cast( + input.impl()) + ->dist_attr() + .process_mesh(), + *mesh, + phi::errors::InvalidArgument( + "Input %s has different mesh. However all inputs should " + "have the same mesh.", + input.name())); + } else { + PADDLE_ENFORCE_EQ( + phi::DenseTensor::classof(input.impl().get()), + true, + phi::errors::InvalidArgument("Failed to convert input %s impl " + "to phi::distributed::DistTensor " + "as it's not phi::DenseTensor.", + input.name())); + phi::distributed::TensorDistAttr dist_attr( + phi::vectorize(input.impl()->dims())); + dist_attr.set_process_mesh(*mesh); + auto dense_t = std::static_pointer_cast(input.impl()); + input.set_impl( + std::make_shared(dense_t, dist_attr)); + } + } + } + + run_auto_parallel = paddle::experimental::AllInputsAreDistTensor(x); + rank_is_in_current_mesh = true; + if (run_auto_parallel) { + auto mesh = + std::static_pointer_cast(x.at(0).impl()) + ->dist_attr() + .process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + std::vector input_x(x.size()); + for (size_t i = 0; i < input_x.size(); ++i) { + input_x[i] = x.at(i).impl().get(); + } + + auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); + auto spmd_info = + phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_input_x); + current_process_mesh = + paddle::holds_alternative( + spmd_info.first[0]) + ? paddle::get<0>(spmd_info.first[0]).process_mesh() + : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); + + if (rank_is_in_current_mesh) { + auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); + auto dist_input_x = paddle::experimental::ReshardApiInputToKernelInput( + dev_ctx, x, spmd_info.first[0]); + for (size_t i = 0; i < x.size(); ++i) { + all_inputs->at(i).set_impl( + std::make_shared(dist_input_x[i]->value())); + } + } else { + auto& infer_shape_func = + paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); + auto& infer_dtype_func = + paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = RunInferShapeFunc( + ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dims = + RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dims = RunDefaultGradInferShapeFunc( + ctx, inputs, outputs, is_double_grad); + } + } + + std::vector> out_dtypes; + if (infer_dtype_func) { + out_dtypes = RunInferDtypeFunc( + ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dtypes = + RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dtypes = RunDefaultGradInferDtypeFunc( + ctx, inputs, outputs, is_double_grad); + } + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_shape return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dims.size())); + + PADDLE_ENFORCE_EQ( + out_dtypes.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_dtype return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dtypes.size())); + + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + const auto& pair = ctx.OutputRangeAt(i); + PADDLE_ENFORCE_EQ( + out_dim.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dim.size())); + PADDLE_ENFORCE_EQ( + out_dtype.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dtype.size())); + + if (out_dim.size() == 1) { + *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[0], out_dtype[0]); + } else { + for (size_t j = pair.first; j < pair.second; j++) { + *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[j], out_dtype[j]); + } + } + } + return std::tuple( + run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); + } + } + return std::tuple( + run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); +} +#endif + +#ifdef PADDLE_WITH_DISTRIBUTE +void TransCtxTensorsToDistTensors( + paddle::CustomOpKernelContext& ctx, // NOLINT + bool run_auto_parallel, + const phi::distributed::ProcessMesh& current_process_mesh) { + if (run_auto_parallel) { + std::vector* output_all = ctx.AllMutableOutput(); + for (size_t i = 0; i < output_all->size(); ++i) { + auto& tensor = output_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } + std::vector* input_all = ctx.AllMutableInput(); + for (size_t i = 0; i < input_all->size(); ++i) { + auto& tensor = input_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } + } +} +#endif + +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx) { // NOLINT + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + +#ifdef PADDLE_WITH_DISTRIBUTE + auto result = + PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); + bool run_auto_parallel = std::get<0>(result); + bool rank_is_in_current_mesh = std::get<1>(result); + phi::distributed::ProcessMesh current_process_mesh = std::get<2>(result); + if (!rank_is_in_current_mesh) { + return; + } +#endif + + std::vector* all_inputs = ctx.AllMutableInput(); + for (size_t i = 0; i < all_inputs->size(); ++i) { + auto& tensor = all_inputs->at(i); + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(tensor.impl())))))); + } + } + + // handle inplace map + ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); + VLOG(7) << "Begin run Kernel of Custom Op"; + (*paddle::OpMetaInfoHelper::GetKernelFn(op_info))(&ctx); + ctx.AssignInplaceOutputs(); + +#ifdef PADDLE_WITH_DISTRIBUTE + TransCtxTensorsToDistTensors(ctx, run_auto_parallel, current_process_mesh); +#endif +} + +} // namespace egr diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.h b/paddle/fluid/eager/custom_operator/custom_operator_utils.h new file mode 100644 index 0000000000000..ac2dec37f3d34 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.h @@ -0,0 +1,24 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/api/ext/op_meta_info.h" + +namespace egr { +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx); // NOLINT +} // namespace egr diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 2619e706cfa13..b1de532d3e6b8 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -105,9 +105,16 @@ void GradNodeBase::SetGradInMeta(const paddle::Tensor& fwd_out, } if (!fwd_out.initialized()) { - VLOG(7) - << "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor"; - return; + if (fwd_out.defined() && fwd_out.is_dist_tensor() && + phi::distributed::NeedComputationClipForPP(fwd_out.impl())) { + VLOG(3) << "Tensor " << fwd_out.name() << " is DistTensor," + << " and needs computation clip for pipeline parallel." + << " Still SetGradInMeta for it."; + } else { + VLOG(7) + << "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor"; + return; + } } const phi::DenseTensor* dense_tensor = nullptr; @@ -124,11 +131,16 @@ void GradNodeBase::SetGradInMeta(const paddle::Tensor& fwd_out, static_cast(fwd_out.impl().get()); dense_tensor = csr_tensor->mutable_non_zero_elements(); } else if (phi::distributed::DistTensor::classof(fwd_out.impl().get())) { - // TODO(chenweihang): DistTensor contains global and local meta, here - // only set the local meta now, we should set global meta later dense_tensor = // NOLINT &(static_cast(fwd_out.impl().get()) ->value()); + meta.SetDistAttr( + static_cast(fwd_out.impl().get()) + ->dist_attr()); + meta.SetDistTensorGlobalDims( + static_cast(fwd_out.impl().get()) + ->dims()); + SetIsRunAutoParallel(true); } else { VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with " "non-DenseTensor argument."; @@ -183,9 +195,16 @@ void GradNodeBase::SetGradInMeta(const std::vector& fwd_out, } if (!fwd_out_tensor.initialized()) { - VLOG(7) - << "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor"; - return; + if (fwd_out_tensor.defined() && fwd_out_tensor.is_dist_tensor() && + !phi::distributed::NeedComputationClipForPP(fwd_out_tensor.impl())) { + VLOG(3) << "Tensor " << fwd_out_tensor.name() << " is DistTensor," + << " and needs computation clip for pipeline parallel." + << " Still SetGradInMeta for it."; + } else { + VLOG(7) << "Skip Configuring GradSlotMeta for uninitialized GradInput " + "Tensor"; + return; + } } // Record TensorMeta @@ -207,6 +226,34 @@ void GradNodeBase::SetGradInMeta(const std::vector& fwd_out, dense_tensor->type() == phi::DataType::COMPLEX128) { need_complex_to_real_ = true; } + } else if (phi::distributed::DistTensor::classof( + fwd_out_tensor.impl().get())) { + // Only Copy Meta + meta.SetDistAttr(static_cast( + fwd_out_tensor.impl().get()) + ->dist_attr()); + meta.SetDistTensorGlobalDims(static_cast( + fwd_out_tensor.impl().get()) + ->dims()); + SetIsRunAutoParallel(true); + + auto dense_tensor = static_cast( + fwd_out_tensor.impl().get()) + ->value(); + + PADDLE_ENFORCE_NE( + dense_tensor.meta().dtype, + phi::DataType::UNDEFINED, + paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta " + "with phi::DataType::UNDEFINED," + "which is illegal.")); + meta.SetTensorMeta(dense_tensor.meta()); + meta.SetPlace(fwd_out_tensor.place()); + + if (dense_tensor.type() == phi::DataType::COMPLEX64 || + dense_tensor.type() == phi::DataType::COMPLEX128) { + need_complex_to_real_ = true; + } } else { VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta " "with non-DenseTensor argument."; @@ -277,16 +324,14 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in, meta.SetTensorMeta(dense_tensor.meta()); meta.SetPlace(fwd_in.place()); // Set DistAttr - PADDLE_ENFORCE_EQ(dist_tensor->defined(), - true, - phi::errors::InvalidArgument( - "The forward input DistTensor is not defined.")); + // Forward input DistTensor could be uninitialized. PADDLE_ENFORCE_NE( dist_tensor->dist_attr().empty(), true, phi::errors::InvalidArgument( "The forward input DistTensor's dist attr is empty.")); meta.SetDistAttr(dist_tensor->dist_attr()); + meta.SetDistTensorGlobalDims(dist_tensor->dims()); SetIsRunAutoParallel(true); } else { VLOG(7) @@ -355,6 +400,26 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in, "which is illegal.")); meta.SetTensorMeta(dense_tensor->meta()); meta.SetPlace(fwd_in.place()); + } else if (phi::distributed::DistTensor::classof(fwd_in.impl().get())) { + // Only Copy Meta + meta.SetDistAttr( + static_cast(fwd_in.impl().get()) + ->dist_attr()); + meta.SetDistTensorGlobalDims( + static_cast(fwd_in.impl().get()) + ->dims()); + SetIsRunAutoParallel(true); + auto dense_tensor = + static_cast(fwd_in.impl().get()) + ->value(); + PADDLE_ENFORCE_NE( + dense_tensor.meta().dtype, + phi::DataType::UNDEFINED, + paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta " + "with phi::DataType::UNDEFINED," + "which is illegal.")); + meta.SetTensorMeta(dense_tensor.meta()); + meta.SetPlace(fwd_in.place()); } } else { VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with " @@ -403,8 +468,6 @@ void GradNodeBase::SetGradOutMeta(const std::vector& fwd_in, // Record TensorMeta if (fwd_in_tensor.impl() && fwd_in_tensor.impl().get()) { if (phi::DenseTensor::classof(fwd_in_tensor.impl().get())) { - // TODO(chenweihang): DistTensor contains global and local meta, here - // only set the local meta now, we should set global meta later phi::DenseTensor* dense_tensor = static_cast(fwd_in_tensor.impl().get()); PADDLE_ENFORCE_NE(dense_tensor->dtype(), @@ -415,6 +478,26 @@ void GradNodeBase::SetGradOutMeta(const std::vector& fwd_in, "which is illegal.")); meta.SetTensorMeta(dense_tensor->meta()); meta.SetPlace(fwd_in_tensor.place()); + } else if (phi::distributed::DistTensor::classof( + fwd_in_tensor.impl().get())) { + meta.SetDistAttr(static_cast( + fwd_in_tensor.impl().get()) + ->dist_attr()); + meta.SetDistTensorGlobalDims(static_cast( + fwd_in_tensor.impl().get()) + ->dims()); + SetIsRunAutoParallel(true); + auto dense_tensor = static_cast( + fwd_in_tensor.impl().get()) + ->value(); + PADDLE_ENFORCE_NE(dense_tensor.dtype(), + phi::DataType::UNDEFINED, + paddle::platform::errors::Fatal( + "Attempting to copy DenseTensorMeta " + "with phi::DataType::UNDEFINED," + "which is illegal.")); + meta.SetTensorMeta(dense_tensor.meta()); + meta.SetPlace(fwd_in_tensor.place()); } } else { VLOG(7) @@ -476,6 +559,27 @@ void GradNodeBase::SetGradOutMeta( "which is illegal.")); meta.SetTensorMeta(dense_tensor->meta()); meta.SetPlace(fwd_in_tensor.place()); + } else if (phi::distributed::DistTensor::classof( + fwd_in_tensor.impl().get())) { + // Only Copy Meta + meta.SetDistAttr(static_cast( + fwd_in_tensor.impl().get()) + ->dist_attr()); + meta.SetDistTensorGlobalDims(static_cast( + fwd_in_tensor.impl().get()) + ->dims()); + SetIsRunAutoParallel(true); + auto dense_tensor = static_cast( + fwd_in_tensor.impl().get()) + ->value(); + PADDLE_ENFORCE_NE(dense_tensor.dtype(), + phi::DataType::UNDEFINED, + paddle::platform::errors::Fatal( + "Attempting to copy DenseTensorMeta " + "with phi::DataType::UNDEFINED," + "which is illegal.")); + meta.SetTensorMeta(dense_tensor.meta()); + meta.SetPlace(fwd_in_tensor.place()); } } else { VLOG(7) diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 15f8d35d2ab79..2318cf0789ed7 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -21,6 +21,7 @@ #include "paddle/fluid/eager/hooks.h" #include "paddle/phi/api/all.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/utils/test_macros.h" namespace egr { @@ -166,8 +167,20 @@ class GradSlotMeta { void SetDistAttr(const phi::distributed::TensorDistAttr& dist_attr) { dist_attr_ = dist_attr; + is_dist_meta_ = true; } + const phi::DDim& DistTensorGlobalDims() const { + return dist_tensor_global_dims_; + } + + void SetDistTensorGlobalDims(const phi::DDim& dims) { + dist_tensor_global_dims_ = dims; + is_dist_meta_ = true; + } + + bool IsDistMeta() const { return is_dist_meta_; } + private: bool stop_gradient_{false}; phi::Place place_; @@ -177,6 +190,8 @@ class GradSlotMeta { // Save the dist attr of the forward input Tensor for proper resharding // operation when compute the input Tensor's gradient phi::distributed::TensorDistAttr dist_attr_; + phi::DDim dist_tensor_global_dims_; + bool is_dist_meta_{false}; }; class GradNodeBase { diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index 34469f875198b..a28b741651333 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -111,8 +111,19 @@ void GradTensorHolder::add(size_t slot_id, const paddle::Tensor& t, bool create_graph) { if (!t.initialized()) { - VLOG(3) << "No need to do accumulate for uninitialized t."; - return; + if (t.defined() && t.is_dist_tensor() && + phi::distributed::NeedComputationClipForPP(t.impl())) { + // Pipeline parallel still needs to construct GradNode graph + // to make DistTensor's global shape and DistAttr information flow. + // Skip grad accumulation will cause GradTensor disconnect to next + // GradNode. + VLOG(3) << "Do accumulate for uninitialized Tensor " << t.name() + << " as it's DistTensor and it needs computation clip for " + "pipeline parallel."; + } else { + VLOG(3) << "No need to do accumulate for uninitialized t."; + return; + } } // TODO(jiabin): Remove this when we fix all kernel. PADDLE_ENFORCE(slot_id < buffer_.size(), diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc index 29922e37beb43..a1e62ea6ba519 100644 --- a/paddle/fluid/eager/nan_inf_utils.cc +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -19,6 +19,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/selected_rows.h" @@ -90,8 +91,12 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) { } else if (tensor.is_selected_rows()) { dense_tensor = &( static_cast(tensor.impl().get())->value()); + } else if (tensor.is_dist_tensor()) { + dense_tensor = &( + static_cast(tensor.impl().get()) + ->value()); } else { - VLOG(10) << "Only DenseTensor or SelectedRows need to check, " + VLOG(10) << "Only DenseTensor,SelectedRows,DistTensor need to check, " << tensor_name << " is no need."; return; } diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc index cec1d49f95a85..5ac0fc4640eb0 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.cc +++ b/paddle/fluid/eager/pylayer/py_layer_node.cc @@ -62,35 +62,80 @@ GradNodePyLayer::operator()( PyObject* pylist = PyList_New((Py_ssize_t)grads[i].size()); for (size_t j = 0; j < grads[i].size(); j++) { if (ctx->materialize_grads && !grads[i][j].initialized()) { - paddle::Tensor tensor_tmp; - auto dense_tensor = std::make_shared(); - dense_tensor->set_meta(forward_outputs_meta_[i][j]); - tensor_tmp.set_impl(dense_tensor); - PyList_SET_ITEM( - pylist, - static_cast(i), - paddle::pybind::ToPyObject(paddle::experimental::zeros_like( - tensor_tmp, - tensor_tmp.dtype(), - forward_outputs_place_[i][j]))); + if (forward_outputs_is_dist_meta_[i][j]) { + paddle::Tensor dist_tensor; + dist_tensor.set_impl(std::make_shared( + forward_outputs_global_dims_[i][j], + forward_outputs_dist_attr_[i][j])); + if (forward_outputs_meta_[i][j].dims.size() != -1) { + paddle::Tensor tensor_tmp; + auto dense_tensor = std::make_shared(); + dense_tensor->set_meta(forward_outputs_meta_[i][j]); + tensor_tmp.set_impl(dense_tensor); + auto zero_tensor = paddle::experimental::zeros_like( + tensor_tmp, tensor_tmp.dtype(), forward_outputs_place_[i][j]); + *(static_cast( + dist_tensor.impl().get()) + ->unsafe_mutable_value()) = + *(static_cast(zero_tensor.impl().get())); + } + PyTuple_SET_ITEM(pylist, + static_cast(j), + paddle::pybind::ToPyObject(dist_tensor)); + } else { + paddle::Tensor tensor_tmp; + auto dense_tensor = std::make_shared(); + dense_tensor->set_meta(forward_outputs_meta_[i][j]); + tensor_tmp.set_impl(dense_tensor); + PyTuple_SET_ITEM( + pylist, + static_cast(j), + paddle::pybind::ToPyObject(paddle::experimental::zeros_like( + tensor_tmp, + tensor_tmp.dtype(), + forward_outputs_place_[i][j]))); + } } else { PyList_SET_ITEM(pylist, - static_cast(i), + static_cast(0), paddle::pybind::ToPyObject(grads[i][0], true)); } } PyTuple_SET_ITEM(backward_args, i, pylist); } else { if (ctx->materialize_grads && !grads[i][0].initialized()) { - paddle::Tensor tensor_tmp; - auto dense_tensor = std::make_shared(); - dense_tensor->set_meta(forward_outputs_meta_[i][0]); - tensor_tmp.set_impl(dense_tensor); - PyTuple_SET_ITEM( - backward_args, - i, - paddle::pybind::ToPyObject(paddle::experimental::zeros_like( - tensor_tmp, tensor_tmp.dtype(), forward_outputs_place_[i][0]))); + if (forward_outputs_is_dist_meta_[i][0]) { + paddle::Tensor dist_tensor; + dist_tensor.set_impl(std::make_shared( + forward_outputs_global_dims_[i][0], + forward_outputs_dist_attr_[i][0])); + if (forward_outputs_meta_[i][0].dims.size() != -1) { + paddle::Tensor tensor_tmp; + auto dense_tensor = std::make_shared(); + dense_tensor->set_meta(forward_outputs_meta_[i][0]); + tensor_tmp.set_impl(dense_tensor); + auto zero_tensor = paddle::experimental::zeros_like( + tensor_tmp, tensor_tmp.dtype(), forward_outputs_place_[i][0]); + *(static_cast( + dist_tensor.impl().get()) + ->unsafe_mutable_value()) = + *(static_cast(zero_tensor.impl().get())); + } + PyTuple_SET_ITEM( + backward_args, i, paddle::pybind::ToPyObject(dist_tensor)); + } else { + paddle::Tensor tensor_tmp; + auto dense_tensor = std::make_shared(); + dense_tensor->set_meta(forward_outputs_meta_[i][0]); + tensor_tmp.set_impl(dense_tensor); + PyTuple_SET_ITEM( + backward_args, + i, + paddle::pybind::ToPyObject(paddle::experimental::zeros_like( + tensor_tmp, + tensor_tmp.dtype(), + forward_outputs_place_[i][0]))); + } } else { PyTuple_SET_ITEM( backward_args, i, paddle::pybind::ToPyObject(grads[i][0], true)); diff --git a/paddle/fluid/eager/pylayer/py_layer_node.h b/paddle/fluid/eager/pylayer/py_layer_node.h index ac6fc29140681..11becea2e6e9a 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.h +++ b/paddle/fluid/eager/pylayer/py_layer_node.h @@ -24,6 +24,7 @@ #include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/hooks.h" #include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/tensor_meta.h" namespace egr { @@ -63,15 +64,35 @@ class GradNodePyLayer : public GradNodeBase { const std::vector>& outputs_tensor) { forward_outputs_meta_.resize(outputs_tensor.size()); forward_outputs_place_.resize(outputs_tensor.size()); + forward_outputs_dist_attr_.resize(outputs_tensor.size()); + forward_outputs_global_dims_.resize(outputs_tensor.size()); + forward_outputs_is_dist_meta_.resize(outputs_tensor.size()); for (size_t i = 0; i < outputs_tensor.size(); i++) { forward_outputs_meta_[i].reserve(outputs_tensor[i].size()); forward_outputs_place_[i].reserve(outputs_tensor[i].size()); + forward_outputs_dist_attr_[i].reserve(outputs_tensor[i].size()); + forward_outputs_global_dims_[i].reserve(outputs_tensor[i].size()); + forward_outputs_is_dist_meta_[i].reserve(outputs_tensor[i].size()); for (auto tensor : outputs_tensor[i]) { if (tensor->is_dense_tensor()) { forward_outputs_meta_[i].push_back( static_cast(tensor->impl().get())->meta()); + forward_outputs_is_dist_meta_[i].push_back(false); + } else if (tensor->is_dist_tensor()) { + forward_outputs_meta_[i].push_back( + static_cast(tensor->impl().get()) + ->value() + .meta()); + forward_outputs_dist_attr_[i].push_back( + static_cast(tensor->impl().get()) + ->dist_attr()); + forward_outputs_global_dims_[i].push_back( + static_cast(tensor->impl().get()) + ->dims()); + forward_outputs_is_dist_meta_[i].push_back(true); } else { forward_outputs_meta_[i].emplace_back(); + forward_outputs_is_dist_meta_[i].push_back(false); } forward_outputs_place_[i].emplace_back(tensor->place()); } @@ -89,6 +110,10 @@ class GradNodePyLayer : public GradNodeBase { std::string name_{""}; std::vector> forward_outputs_meta_; std::vector> forward_outputs_place_; + std::vector> + forward_outputs_dist_attr_; + std::vector> forward_outputs_global_dims_; + std::vector> forward_outputs_is_dist_meta_; }; } // namespace egr diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index 7aa6ec8e6cddb..8030b1273e6db 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -43,11 +43,17 @@ class TensorWrapper { bool no_need_buffer = false) { // set inplace_version_snapshot_ according to tensor's current inplace // version. - if (tensor.impl() && phi::DenseTensor::classof(tensor.impl().get())) { + if (tensor.initialized() && tensor.is_dense_tensor()) { phi::DenseTensor* dense_tensor = static_cast(tensor.impl().get()); auto& inplace_version_counter = dense_tensor->InplaceVersionCounter(); inplace_version_snapshot_ = inplace_version_counter.CurrentVersion(); + } else if (tensor.initialized() && tensor.is_dist_tensor()) { + phi::DenseTensor* dense_tensor = + static_cast(tensor.impl().get()) + ->unsafe_mutable_value(); + auto& inplace_version_counter = dense_tensor->InplaceVersionCounter(); + inplace_version_snapshot_ = inplace_version_counter.CurrentVersion(); } /** @@ -200,10 +206,20 @@ class TensorWrapper { "no_need_buffer_ is true."; return; } - if (intermidiate_tensor_.impl() && - phi::DenseTensor::classof(intermidiate_tensor_.impl().get())) { - phi::DenseTensor* dense_tensor = - static_cast(intermidiate_tensor_.impl().get()); + if (intermidiate_tensor_.impl()) { + phi::DenseTensor* dense_tensor = nullptr; + if (phi::DenseTensor::classof(intermidiate_tensor_.impl().get())) { + dense_tensor = + static_cast(intermidiate_tensor_.impl().get()); + } else if (phi::distributed::DistTensor::classof( + intermidiate_tensor_.impl().get())) { + dense_tensor = static_cast( + intermidiate_tensor_.impl().get()) + ->unsafe_mutable_value(); + } else { + return; + } + auto& inplace_version_counter = dense_tensor->InplaceVersionCounter(); uint32_t wrapper_version_snapshot = inplace_version_snapshot_; diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index 1d02457bbe748..8e788bd94162e 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -92,7 +92,7 @@ static std::vector filter_unused_input_var_in_backward( return filter_x; } -static std::vector newir_filter_unused_input_var_in_backward( +static std::vector pir_filter_unused_input_var_in_backward( const std::vector& x, const std::string x_key_name, const paddle::framework::AttributeMap& attrs) { @@ -134,7 +134,6 @@ inline void run_program_ad_func( const std::vector& params, std::vector& out, // NOLINT std::vector& step_scope, // NOLINT - std::vector& dout, // NOLINT const paddle::framework::AttributeMap& attrs) { // Prepare Autograd Meta VLOG(2) << "start run run_program ad function."; @@ -156,8 +155,7 @@ inline void run_program_ad_func( auto params_tmp = Trans2ContiguousTensors(params); // Call forward function // if require_any_grad is False, don't save any middle vars. - RunProgramAPI( - x_tmp, params_tmp, out, step_scope, dout, require_any_grad, attrs); + RunProgramAPI(x_tmp, params_tmp, out, step_scope, require_any_grad, attrs); VLOG(2) << "start run run_program grad"; auto is_test = false; if (attrs.count("is_test")) { @@ -189,20 +187,7 @@ inline void run_program_ad_func( grad_node->SetFwdParams(params_tmp); grad_node->SetStepScope(step_scope); - // Set Grad out rank as same as fwd input and set stop gradient to bwd - // NOTE(@xiongkun): Not every tensor in x(list of tensor) is required - // gradient. for example: x[1] is not used for output, the x[1] is ignored. - - std::vector x_require_grad; - for (size_t i = 0; i < x.size(); ++i) { - auto& name = x_names[i]; - if (forward_global_block->HasVar(name) || - backward_global_block->HasVar(name)) { - x_require_grad.push_back(&x[i]); - } - } - - grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0); + grad_node->SetGradOutMeta(x, /*slot id*/ 0); grad_node->SetGradOutMeta(params, /*slot id*/ 1); VLOG(2) << "clear_no_grad_edges."; @@ -221,15 +206,14 @@ inline void run_program_ad_func( } } -inline void newir_run_program_ad_func( +inline void pir_run_program_ad_func( const std::vector& x, const std::vector& params, std::vector& out, // NOLINT std::vector& step_scope, // NOLINT - std::vector& dout, // NOLINT const paddle::framework::AttributeMap& attrs) { // Prepare Autograd Meta - VLOG(2) << "start run newir run_program ad function."; + VLOG(2) << "start run pir run_program ad function."; auto deref_out = details::DereferenceTensors(out); std::vector p_autograd_x = egr::EagerUtils::nullable_autograd_meta(x); @@ -248,13 +232,18 @@ inline void newir_run_program_ad_func( auto output_size = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size(); auto middles = std::vector(); - std::shared_ptr grad_node; + + auto is_test = false; + if (attrs.count("is_test")) { + is_test = PADDLE_GET_CONST(bool, attrs.at("is_test")); + } + std::shared_ptr grad_node; VLOG(2) << "start run run_program with require_any_grad = " - << require_any_grad; + << require_any_grad << ", is_test = " << is_test; - if (require_any_grad) { + if (!is_test && require_any_grad) { // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad]) - grad_node = std::make_shared(1, 2); + grad_node = std::make_shared(1, 2); grad_node->GetMiddle().resize(middle_size); grad_node->GetOutputs().resize(output_size); for (size_t i = 0; i < middle_size; ++i) { @@ -278,16 +267,16 @@ inline void newir_run_program_ad_func( // Call forward function // if require_any_grad is False, don't save any middle vars. - NewIRRunProgramAPI( - x, params, out, middles, step_scope, dout, require_any_grad, attrs); - if (require_any_grad) { + PirRunProgramAPI( + x, params, out, middles, step_scope, require_any_grad, attrs); + if (!is_test && require_any_grad) { egr::EagerUtils::PassStopGradient(false, &p_autograd_outs); // Set Attributes grad_node->SetAttrMap(attrs); // Clear unused x vars - auto filter_x = newir_filter_unused_input_var_in_backward(x, "bx", attrs); + auto filter_x = pir_filter_unused_input_var_in_backward(x, "bx", attrs); // Set TensorWrappers grad_node->SetFwdX(filter_x); diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 83e4424a21251..620a0132f2f87 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -31,7 +31,7 @@ #include "paddle/pir/core/program.h" #include "paddle/pir/core/value.h" -PHI_DECLARE_bool(enable_new_ir_in_executor); +PHI_DECLARE_bool(enable_pir_in_executor); PHI_DECLARE_bool(print_ir); namespace details { @@ -317,22 +317,12 @@ static void ShareTensorsFromScopeWithPartialBlock( paddle::framework::Scope *scope) { for (size_t i = 0; i < tensors.size(); ++i) { auto &name = tensors[i]->name(); - bool in_forward_block = forward_global_block.HasVar(name); - bool in_backward_block = - backward_global_block && backward_global_block->HasVar(name); + auto *var = scope->FindVar(name); if (name == paddle::framework::kEmptyVarName || - name == paddle::framework::kFakeVarName || - (!in_forward_block && !in_backward_block)) { + name == paddle::framework::kFakeVarName || var == nullptr) { VLOG(2) << "find tensor name is " << name << ", skip it!"; continue; } - auto *var = scope->FindVar(name); - PADDLE_ENFORCE_NOT_NULL( - var, - paddle::platform::errors::NotFound("The output tensor %s is not in " - "RunProgram(Grad)Op'" - "s internal scope.", - name)); CheckOutputVarStatus(*var, *tensors[i]); // share tensor if (var->IsType()) { @@ -416,13 +406,12 @@ void print_collection(const T &t) { } // namespace details -inline void NewIRRunProgramAPI( +inline void PirRunProgramAPI( const std::vector &x, const std::vector ¶ms, std::vector &out, // NOLINT std::vector &middles, // NOLINT std::vector &step_scope, // NOLINT - std::vector &dout, // NOLINT bool require_any_grad, const paddle::framework::AttributeMap &attrs) { VLOG(2) << "RunProgramOpKernel Compute"; @@ -467,15 +456,19 @@ inline void NewIRRunProgramAPI( auto *forward_program = forward_global_block->GetParentOp()->GetParentProgram(); - auto *backward_program = - backward_global_block->GetParentOp()->GetParentProgram(); if (FLAGS_print_ir) { std::ostringstream print_stream; print_stream << "ForwardProgram is :\n"; forward_program->Print(print_stream); - print_stream << "BackwardProgram is:\n"; - backward_program->Print(print_stream); + if (!is_test) { + auto *backward_program = + backward_global_block->GetParentOp()->GetParentProgram(); + print_stream << "BackwardProgram is:\n"; + backward_program->Print(print_stream); + } else { + print_stream << "BackwardProgram is empty in test mode.\n"; + } std::cout << "Program (fwd | bwd): \n" << print_stream.str() << std::endl; } @@ -502,7 +495,7 @@ inline void NewIRRunProgramAPI( // Step 2. create new interpretercore auto kernel_forward_program = paddle::dialect::PdOpLowerToKernelPass(forward_program, place); - interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( std::move(kernel_forward_program), place, /*is_grad=*/false, @@ -607,7 +600,6 @@ inline void RunProgramAPI( const std::vector ¶ms, std::vector &out, // NOLINT std::vector &step_scope, // NOLINT - std::vector &dout, // NOLINT bool require_any_grad, const paddle::framework::AttributeMap &attrs) { VLOG(2) << "RunProgramOpKernel Compute"; @@ -618,6 +610,7 @@ inline void RunProgramAPI( if (attrs.count("is_test")) { is_test = PADDLE_GET_CONST(bool, attrs.at("is_test")); } + auto need_grad = !is_test && require_any_grad; int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); auto place = egr::Controller::Instance().GetExpectedPlace(); @@ -640,7 +633,6 @@ inline void RunProgramAPI( PADDLE_GET_CONST(std::vector, attrs.at("x_names")); auto output_names = details::GetTensorsName(out); auto param_names = details::GetTensorsName(params); - auto dout_names = details::GetTensorsName(dout); if (VLOG_IS_ON(6)) { std::stringstream s; @@ -659,11 +651,6 @@ inline void RunProgramAPI( s << name << " "; } s << std::endl; - s << "dout_names: "; - for (auto name : dout_names) { - s << name << " "; - } - s << std::endl; VLOG(6) << s.str(); } @@ -674,7 +661,7 @@ inline void RunProgramAPI( paddle::framework::BlockDesc *backward_global_block = nullptr; paddle::framework::ProgramDesc *backward_program = nullptr; - if (!is_test) { + if (need_grad) { backward_global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *, attrs.at("backward_global_block")); backward_program = backward_global_block->Program(); @@ -698,7 +685,7 @@ inline void RunProgramAPI( details::ShareTensorsIntoScope(params, global_inner_scope); // Step 2. create new interpretercore - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { // build new ir program auto ir_program = paddle::framework::ConstructFowardIrProgram(forward_global_block, @@ -708,13 +695,12 @@ inline void RunProgramAPI( input_names, params, place); - interpreter_core = - paddle::framework::CreateNewIRInterpreterCoreInfoToCache( - std::move(ir_program), - place, - /*is_grad=*/false, - program_id, - global_inner_scope); + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( + std::move(ir_program), + place, + /*is_grad=*/false, + program_id, + global_inner_scope); } else { interpreter_core = paddle::framework::CreateProgramInterpreterCoreInfoToCache( @@ -726,7 +712,7 @@ inline void RunProgramAPI( } // Step 3. get all eager gc vars std::set skip_eager_delete_vars; - if (!is_test) { + if (need_grad) { skip_eager_delete_vars = paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet( *backward_program); @@ -734,7 +720,6 @@ inline void RunProgramAPI( // all out_vars are skip_eager_var skip_eager_delete_vars.insert(output_names.begin(), output_names.end()); - skip_eager_delete_vars.insert(dout_names.begin(), dout_names.end()); // update interpretercore skip_gc_var interpreter_core->SetSkipGcVars(skip_eager_delete_vars); @@ -790,10 +775,8 @@ inline void RunProgramAPI( // Get Output details::ShareTensorsFromScopeWithPartialBlock( out, *forward_global_block, backward_global_block, global_inner_scope); - details::ShareTensorsFromScopeWithPartialBlock( - dout, *forward_global_block, backward_global_block, global_inner_scope); - if (is_test || !require_any_grad) { + if (!need_grad) { VLOG(4) << "don't require any grad, set this scope can reused"; VLOG(4) << "is_test: " << is_test << ", require_any_grad: " << require_any_grad; @@ -856,7 +839,7 @@ inline void RunProgramGradAPI( VLOG(2) << "No interpretercore cahce, so create a new interpretercore"; details::ShareTensorsIntoScope(out_grad, global_inner_scope); - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { auto res = paddle::framework::ConstructBackwardIrProgram(backward_global_block, out_grad, @@ -865,13 +848,12 @@ inline void RunProgramGradAPI( global_inner_scope, place); - interpreter_core = - paddle::framework::CreateNewIRInterpreterCoreInfoToCache( - std::move(res), - place, - /*is_grad=*/true, - program_id, - global_inner_scope); + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( + std::move(res), + place, + /*is_grad=*/true, + program_id, + global_inner_scope); } else { interpreter_core = paddle::framework::CreateProgramInterpreterCoreInfoToCache( @@ -965,7 +947,7 @@ inline void RunProgramGradAPI( } } -inline void NewIRRunProgramGradAPI( +inline void PirRunProgramGradAPI( const std::vector &x, const std::vector ¶ms, const std::vector &out_grad, @@ -1041,7 +1023,7 @@ inline void NewIRRunProgramGradAPI( // Step 1. share input_vars & parameters into scope auto kernel_backward_program = paddle::dialect::PdOpLowerToKernelPass(backward_program, place); - interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( std::move(kernel_backward_program), place, /*is_grad=*/true, @@ -1193,6 +1175,10 @@ class GradNodeRunProgram : public egr::GradNodeBase { VLOG(3) << "End Eager Backward Node: GradNodeRunProgram"; executed_ = true; + egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad, + this->OutputMeta()[0]); + egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(¶ms_grad, + this->OutputMeta()[1]); return {x_grad, params_grad}; } @@ -1237,7 +1223,8 @@ class GradNodeRunProgram : public egr::GradNodeBase { if (x[i].is_dense_tensor()) { x_grad->emplace_back(std::make_shared()); } else if (x[i].is_selected_rows()) { - x_grad->emplace_back(std::make_shared()); + auto selected_row = std::make_shared(); + x_grad->emplace_back(selected_row); } x_grad->back().set_name(x_grad_names[i]); } @@ -1288,15 +1275,15 @@ class GradNodeRunProgram : public egr::GradNodeBase { bool executed_{false}; }; -class NewIRGradNodeRunProgram : public egr::GradNodeBase { +class PirGradNodeRunProgram : public egr::GradNodeBase { public: - NewIRGradNodeRunProgram(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + PirGradNodeRunProgram(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {} - ~NewIRGradNodeRunProgram() override { + ~PirGradNodeRunProgram() override { if (!executed_) { auto *out_scope_vec = &step_scope_; - VLOG(4) << "~NewIRGradNodeRunProgram"; + VLOG(4) << "~PirGradNodeRunProgram"; // Normally out_scope_vec.size() == 1. for safty, we add for-loop here. for (size_t i = 0; i < out_scope_vec->size(); ++i) { paddle::framework::Scope *global_inner_scope = out_scope_vec->at(i); @@ -1315,9 +1302,9 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { egr::kSlotSmallVectorSize> &grads, // NOLINT bool create_graph UNUSED, bool is_new_grad UNUSED) override { - VLOG(3) << "Running Eager Backward Node: NewIRGradNodeRunProgram"; + VLOG(3) << "Running Eager Backward Node: PirGradNodeRunProgram"; paddle::small_vector, egr::kSlotSmallVectorSize> - hooked_grads = NewIRGradNodeRunProgram::ApplyGradientHooks(grads); + hooked_grads = PirGradNodeRunProgram::ApplyGradientHooks(grads); PADDLE_ENFORCE_EQ(hooked_grads.size(), 1, paddle::platform::errors::InvalidArgument( @@ -1357,16 +1344,16 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { "The hooked_grads[0].size() and " "out_grad_values.size() should be equal.")); - NewIRRunProgramGradAPI(x_, - params_, - hooked_grads[0], - middles_, - outputs_, - step_scope_, - attrs_, - x_grad_ptr, - params_grad_ptr); - VLOG(3) << "End Eager Backward Node: NewIRGradNodeRunProgram"; + PirRunProgramGradAPI(x_, + params_, + hooked_grads[0], + middles_, + outputs_, + step_scope_, + attrs_, + x_grad_ptr, + params_grad_ptr); + VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram"; executed_ = true; return {x_grad, params_grad}; @@ -1451,8 +1438,8 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { } std::shared_ptr Copy() const override { - auto copied_node = std::shared_ptr( - new NewIRGradNodeRunProgram(*this)); + auto copied_node = std::shared_ptr( + new PirGradNodeRunProgram(*this)); return copied_node; } diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 28ca8636720dc..f8959f3b8d78a 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/hook_utils.h" +#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/phi/api/all.h" @@ -336,6 +337,10 @@ void EagerUtils::HandleViewBetweenInputAndOutput( std::dynamic_pointer_cast(input_tensor.impl()); if (view_output_tensor->impl() == nullptr) { view_output_tensor->set_impl(std::make_shared()); + } else { + PADDLE_ENFORCE(view_output_tensor->is_dense_tensor(), + phi::errors::Unavailable( + "DenseTensor can not be inplaced with other Tensor.")); } auto view_output_dense_tensor = std::dynamic_pointer_cast(view_output_tensor->impl()); @@ -343,6 +348,35 @@ void EagerUtils::HandleViewBetweenInputAndOutput( view_output_dense_tensor->ShareInplaceVersionCounterWith( *input_dense_tensor); + VLOG(4) << "Perform View between Output Tensor(" + << view_output_tensor->name() << ") and Input Tensor(" + << input_tensor.name() + << "), share allocation and inplace version."; + } else if (input_tensor.is_dist_tensor()) { + auto input_dense_tensor = + std::dynamic_pointer_cast( + input_tensor.impl()) + ->unsafe_mutable_value(); + if (view_output_tensor->impl() == nullptr) { + view_output_tensor->set_impl( + std::make_shared( + input_tensor.dims(), + std::dynamic_pointer_cast( + input_tensor.impl()) + ->dist_attr())); + } else { + PADDLE_ENFORCE(view_output_tensor->is_dist_tensor(), + phi::errors::Unavailable( + "DistTensor can not be inplaced with other Tensor.")); + } + auto view_output_dense_tensor = + std::dynamic_pointer_cast( + view_output_tensor->impl()) + ->unsafe_mutable_value(); + view_output_dense_tensor->ShareBufferWith(*input_dense_tensor); + view_output_dense_tensor->ShareInplaceVersionCounterWith( + *input_dense_tensor); + VLOG(4) << "Perform View between Output Tensor(" << view_output_tensor->name() << ") and Input Tensor(" << input_tensor.name() @@ -498,12 +532,64 @@ void EagerUtils::FillZeroForEmptyOptionalGradInput( for (size_t i = 0; i < in_grads->size(); i++) { paddle::Tensor& grad = (*in_grads)[i]; if (!grad.initialized() && grad_in_metas[i].HasTensorMeta()) { - auto tensor_with_zero = paddle::experimental::full( - phi::vectorize(grad_in_metas[i].GetTensorMeta().dims), - 0.0, - grad_in_metas[i].GetTensorMeta().dtype, - grad_in_metas[i].GetPlace()); - grad.set_impl(tensor_with_zero.impl()); + if (grad_in_metas[i].IsDistMeta()) { + grad.set_impl(std::make_shared( + grad_in_metas[i].DistTensorGlobalDims(), + grad_in_metas[i].DistAttr())); + if (grad_in_metas[i].GetTensorMeta().dims.size() != -1) { + auto tensor_with_zero = paddle::experimental::full( + phi::vectorize(grad_in_metas[i].GetTensorMeta().dims), + 0.0, + grad_in_metas[i].GetTensorMeta().dtype, + grad_in_metas[i].GetPlace()); + *(static_cast(grad.impl().get()) + ->unsafe_mutable_value()) = + *(static_cast(tensor_with_zero.impl().get())); + } + } else { + auto tensor_with_zero = paddle::experimental::full( + phi::vectorize(grad_in_metas[i].GetTensorMeta().dims), + 0.0, + grad_in_metas[i].GetTensorMeta().dtype, + grad_in_metas[i].GetPlace()); + grad.set_impl(tensor_with_zero.impl()); + } + } + } +} + +void EagerUtils::FillZeroForEmptyOptionalGradOutput( + std::vector* output_grads, + const std::vector& grad_output_metas) { + for (size_t i = 0; i < output_grads->size(); i++) { + paddle::Tensor& grad = (*output_grads)[i]; + if (!grad.initialized() && grad_output_metas[i].HasTensorMeta()) { + if (grad.defined() && grad.is_selected_rows()) { + continue; + } + if (grad_output_metas[i].IsDistMeta()) { + grad.set_impl(std::make_shared( + grad_output_metas[i].DistTensorGlobalDims(), + grad_output_metas[i].DistAttr())); + if (grad_output_metas[i].GetTensorMeta().dims.size() != -1) { + auto tensor_with_zero = paddle::experimental::full( + phi::vectorize(grad_output_metas[i].GetTensorMeta().dims), + 0.0, + grad_output_metas[i].GetTensorMeta().dtype, + grad_output_metas[i].GetPlace()); + *(static_cast(grad.impl().get()) + ->unsafe_mutable_value()) = + *(static_cast(tensor_with_zero.impl().get())); + } + } else { + auto tensor_with_zero = + paddle::experimental::full( // only create dense tensor. + phi::vectorize(grad_output_metas[i].GetTensorMeta().dims), + 0.0, + grad_output_metas[i].GetTensorMeta().dtype, + grad_output_metas[i].GetPlace()); + grad.set_impl(tensor_with_zero.impl()); + } } } } @@ -516,12 +602,27 @@ void EagerUtils::FillZeroForEmptyGradInput(paddle::Tensor* in_grad, paddle::platform::errors::Fatal( "Unable to fill empty grad inputs due to empty GradSlotMeta")); const auto& tensor_meta = grad_in_meta.GetTensorMeta(); - auto tensor_with_zero = - paddle::experimental::full(phi::vectorize(tensor_meta.dims), - 0.0, - tensor_meta.dtype, - grad_in_meta.GetPlace()); - in_grad->set_impl(tensor_with_zero.impl()); + if (grad_in_meta.IsDistMeta()) { + in_grad->set_impl(std::make_shared( + grad_in_meta.DistTensorGlobalDims(), grad_in_meta.DistAttr())); + if (tensor_meta.dims.size() != -1) { + auto tensor_with_zero = + paddle::experimental::full(phi::vectorize(tensor_meta.dims), + 0.0, + tensor_meta.dtype, + grad_in_meta.GetPlace()); + *(static_cast(in_grad->impl().get()) + ->unsafe_mutable_value()) = + *(static_cast(tensor_with_zero.impl().get())); + } + } else { + auto tensor_with_zero = + paddle::experimental::full(phi::vectorize(tensor_meta.dims), + 0.0, + tensor_meta.dtype, + grad_in_meta.GetPlace()); + in_grad->set_impl(tensor_with_zero.impl()); + } } } @@ -529,12 +630,27 @@ void EagerUtils::FillZeroForEmptyOptionalGradInput( paddle::Tensor* in_grad, const GradSlotMeta& grad_in_meta) { if (!in_grad->initialized() && grad_in_meta.HasTensorMeta()) { const auto& tensor_meta = grad_in_meta.GetTensorMeta(); - auto tensor_with_zero = - paddle::experimental::full(phi::vectorize(tensor_meta.dims), - 0.0, - tensor_meta.dtype, - grad_in_meta.GetPlace()); - in_grad->set_impl(tensor_with_zero.impl()); + if (grad_in_meta.IsDistMeta()) { + in_grad->set_impl(std::make_shared( + grad_in_meta.DistTensorGlobalDims(), grad_in_meta.DistAttr())); + if (tensor_meta.dims.size() != -1) { + auto tensor_with_zero = + paddle::experimental::full(phi::vectorize(tensor_meta.dims), + 0.0, + tensor_meta.dtype, + grad_in_meta.GetPlace()); + *(static_cast(in_grad->impl().get()) + ->unsafe_mutable_value()) = + *(static_cast(tensor_with_zero.impl().get())); + } + } else { + auto tensor_with_zero = + paddle::experimental::full(phi::vectorize(tensor_meta.dims), + 0.0, + tensor_meta.dtype, + grad_in_meta.GetPlace()); + in_grad->set_impl(tensor_with_zero.impl()); + } } } @@ -634,22 +750,29 @@ std::string EagerUtils::TensorStr(const paddle::Tensor& t) { std::string tensor_info_str = ""; if (t.defined()) { if (t.is_dist_tensor()) { + const char* DIST_TENSOR_INFO_TEMPLATE = + "Type: %s, Dtype: %s, Place: %s, Is_defined: %s, Is_initialized: %s, " + "Shape: %s, DistAttr: %s"; auto dist_t = std::static_pointer_cast(t.impl()); if (t.initialized()) { tensor_info_str += paddle::string::Sprintf( - TENSOR_INFO_TEMPLATE, + DIST_TENSOR_INFO_TEMPLATE, t.impl()->type_info().name(), t.dtype(), t.place().DebugString(), + dist_t->defined(), + dist_t->initialized(), paddle::string::Sprintf( "%s, Local Shape: %s", t.dims(), dist_t->local_dims()), dist_t->dist_attr()); } else { - tensor_info_str += paddle::string::Sprintf(TENSOR_INFO_TEMPLATE, + tensor_info_str += paddle::string::Sprintf(DIST_TENSOR_INFO_TEMPLATE, t.impl()->type_info().name(), "Unknown", "Unknown", + dist_t->defined(), + dist_t->initialized(), t.dims(), dist_t->dist_attr()); } diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 8dd950be0cbe2..6013820c76b67 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -242,6 +242,9 @@ class TEST_API EagerUtils { static void FillZeroForEmptyOptionalGradInput( std::vector* in_grads, const std::vector& grad_in_metas); + static void FillZeroForEmptyOptionalGradOutput( + std::vector* out_grads, + const std::vector& grad_out_metas); static void FillZeroForEmptyGradInput(paddle::Tensor* in_grad, const GradSlotMeta& grad_in_meta); static void FillZeroForEmptyOptionalGradInput( diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b83568cfdd69a..0a88038de5078 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -55,13 +55,13 @@ function(pass_library TARGET DEST) ${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry - ${pass_library_DEPS}) + quantize_helper ${pass_library_DEPS}) else() cc_library( ${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry - ${pass_library_DEPS}) + quantize_helper ${pass_library_DEPS}) endif() # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. @@ -122,69 +122,16 @@ cc_library( SRCS data_type.cc DEPS framework_proto) -cc_test( - data_type_test - SRCS data_type_test.cc - DEPS data_type place tensor) - cc_library( tensor SRCS tensor_util.cc DEPS place memory data_type device_context phi) -cc_test( - tensor_test - SRCS tensor_test.cc - DEPS tensor isfinite_op) -if(WITH_GPU) - nv_test( - tensor_util_test - SRCS tensor_util_test.cc tensor_util_test.cu - DEPS tensor dlpack_tensor isfinite_op) -elseif(WITH_ROCM) - hip_test( - tensor_util_test - SRCS tensor_util_test.cc tensor_util_test.cu - DEPS tensor dlpack_tensor isfinite_op) -else() - cc_test( - tensor_util_test - SRCS tensor_util_test.cc - DEPS tensor dlpack_tensor isfinite_op) -endif() - -cc_test( - copy_same_tensor_test - SRCS copy_same_tensor_test.cc - DEPS tensor) - -cc_test( - eigen_test - SRCS eigen_test.cc - DEPS tensor) - cc_library( lod_tensor SRCS lod_tensor.cc DEPS phi place tensor framework_proto version) -cc_test( - lod_tensor_test - SRCS lod_tensor_test.cc - DEPS phi lod_tensor memory) - -if(WITH_GPU) - nv_test( - lod_tensor_gpu_test - SRCS lod_tensor_test.cu - DEPS lod_tensor) -elseif(WITH_ROCM) - hip_test( - lod_tensor_gpu_test - SRCS lod_tensor_test.cu - DEPS lod_tensor) -endif() - cc_library( garbage_collector SRCS garbage_collector.cc @@ -194,15 +141,6 @@ cc_library( reader SRCS reader.cc DEPS lod_tensor phi) -cc_test( - reader_test - SRCS reader_test.cc - DEPS reader) - -cc_test( - threadpool_test - SRCS threadpool_test.cc - DEPS phi) cc_library( var_type_traits @@ -221,11 +159,6 @@ if(WITH_MKLDNN) add_dependencies(var_type_traits mkldnn) endif() -cc_test( - var_type_traits_test - SRCS var_type_traits_test.cc - DEPS var_type_traits) - set(BRPC_DEPS "") if(WITH_PSCORE) set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS}) @@ -249,39 +182,15 @@ cc_library( device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor scope ${BRPC_DEPS}) -cc_test( - device_worker_test - SRCS device_worker_test.cc - DEPS device_worker) - cc_library( scope_pool SRCS scope_pool.cc DEPS scope) -cc_test( - scope_test - SRCS scope_test.cc - DEPS scope) -cc_test( - variable_test - SRCS variable_test.cc - DEPS tensor var_type_traits) cc_library( data_device_transform SRCS data_device_transform.cc DEPS tensor) -if(WITH_GPU) - nv_test( - data_device_transform_test - SRCS data_device_transform_test.cu - DEPS operator op_registry device_context phi scope) -elseif(WITH_ROCM) - hip_test( - data_device_transform_test - SRCS data_device_transform_test.cu - DEPS operator op_registry device_context phi scope) -endif() if(WITH_GPU) if(WIN32) @@ -299,47 +208,27 @@ if(WITH_GPU) SRCS data_type_transform.cu DEPS tensor) endif() - nv_test( - data_type_transform_test - SRCS data_type_transform_test.cc data_type_transform_test.cu - DEPS data_type_transform) elseif(WITH_ROCM) hip_library( data_type_transform SRCS data_type_transform.cu DEPS tensor) - hip_test( - data_type_transform_test - SRCS data_type_transform_test.cc data_type_transform_test.cu - DEPS data_type_transform) elseif(WITH_XPU) cc_library( data_type_transform SRCS data_type_transform.cc DEPS tensor xpulib) - cc_test( - data_type_transform_test - SRCS data_type_transform_test.cc - DEPS data_type_transform) else() cc_library( data_type_transform SRCS data_type_transform.cc DEPS tensor) - cc_test( - data_type_transform_test - SRCS data_type_transform_test.cc - DEPS data_type_transform) endif() cc_library( data_layout_transform SRCS data_layout_transform.cc DEPS tensor phi) -cc_test( - data_layout_transform_test - SRCS data_layout_transform_test.cc - DEPS data_layout_transform) cc_library( data_transform @@ -357,18 +246,6 @@ cc_library( attribute SRCS attribute.cc DEPS framework_proto enforce) -cc_test( - attribute_test - SRCS attribute_test.cc - DEPS attribute framework_proto proto_desc) -cc_test( - program_desc_test - SRCS program_desc_test.cc - DEPS proto_desc device_context) -cc_test( - op_desc_test - SRCS op_desc_test.cc - DEPS proto_desc) cc_library( op_version_proto SRCS op_version_proto.cc @@ -378,19 +255,11 @@ cc_library( op_version_registry SRCS op_version_registry.cc DEPS op_version_proto framework_proto) -cc_test( - op_version_registry_test - SRCS op_version_registry_test.cc - DEPS op_version_registry) cc_library( op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute ops_extra_info glog auto_parallel_proto) -cc_test( - op_proto_maker_test - SRCS op_proto_maker_test.cc - DEPS op_proto_maker) cc_library( no_need_buffer_vars_inference SRCS no_need_buffer_vars_inference.cc @@ -410,11 +279,6 @@ if(WITH_MKLDNN) add_dependencies(shape_inference mkldnn) endif() -cc_test( - no_need_buffer_vars_inference_test - SRCS no_need_buffer_vars_inference_test.cc - DEPS no_need_buffer_vars_inference layer) - cc_library( transfer_scope_cache SRCS transfer_scope_cache.cc @@ -503,20 +367,7 @@ else() type_info) endif() -cc_test( - operator_test - SRCS operator_test.cc - DEPS operator op_registry device_context) -cc_test( - operator_exception_test - SRCS operator_exception_test.cc - DEPS operator op_registry device_context) - cc_library(version SRCS version.cc) -cc_test( - version_test - SRCS version_test.cc - DEPS version) add_library(proto_desc_base OBJECT var_desc.cc op_desc.cc block_desc.cc program_desc.cc) @@ -556,31 +407,11 @@ cc_library( op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce) -cc_test( - op_call_stack_test - SRCS op_call_stack_test.cc - DEPS op_call_stack) cc_library( program_utils SRCS program_utils.cc DEPS proto_desc) -cc_test( - program_utils_test - SRCS program_utils_test.cc - DEPS proto_desc program_utils) - -if(WITH_GPU) - nv_test( - op_registry_test - SRCS op_registry_test.cc - DEPS op_registry) -elseif(WITH_ROCM) - hip_test( - op_registry_test - SRCS op_registry_test.cc - DEPS op_registry) -endif() if(WITH_PYTHON) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) @@ -1049,104 +880,25 @@ cc_library( cc_library( executor_cache SRCS executor_cache.cc - DEPS parallel_executor standalone_executor pir_adaptor pd_inplace_pass - pd_op_to_kernel_pass pir) -if(WITH_PSCORE) - get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) - if(WITH_HETERPS) - cc_test( - dist_multi_trainer_test - SRCS dist_multi_trainer_test.cc - DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS} - graph_gpu_wrapper) - cc_test( - heter_pipeline_trainer_test - SRCS heter_pipeline_trainer_test.cc - DEPS conditional_block_op - generated_op - heter_listen_and_serv_op - executor - heter_server - gloo_wrapper - phi - ${RPC_DEPS} - graph_gpu_wrapper) - else() - cc_test( - dist_multi_trainer_test - SRCS dist_multi_trainer_test.cc - DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS}) - cc_test( - heter_pipeline_trainer_test - SRCS heter_pipeline_trainer_test.cc - DEPS conditional_block_op - generated_op - heter_listen_and_serv_op - executor - heter_server - gloo_wrapper - phi - ${RPC_DEPS}) - endif() -else() - cc_test( - dist_multi_trainer_test - SRCS dist_multi_trainer_test.cc - DEPS conditional_block_op executor gloo_wrapper) -endif() + DEPS parallel_executor standalone_executor pir_adaptor transform pir) cc_library( prune SRCS prune.cc DEPS framework_proto auto_parallel_proto proto_desc) -cc_test( - prune_test - SRCS prune_test.cc - DEPS op_info prune recurrent_op device_context) -cc_test( - var_type_inference_test - SRCS var_type_inference_test.cc - DEPS op_registry proto_desc) cc_library( selected_rows_utils SRCS selected_rows_utils.cc DEPS phi device_context) -cc_test( - selected_rows_utils_test - SRCS selected_rows_utils_test.cc - DEPS selected_rows_utils) - -cc_test( - op_kernel_type_test - SRCS op_kernel_type_test.cc - DEPS place device_context framework_proto op_kernel_type) -cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) - -cc_test(tuple_test SRCS tuple_test.cc) - -cc_test(inlined_vector_test SRCS inlined_vector_test.cc) cc_library( dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack) -cc_test( - dlpack_tensor_test - SRCS dlpack_tensor_test.cc - DEPS dlpack_tensor glog) cc_library( op_compatible_info SRCS op_compatible_info.cc DEPS string_helper proto_desc) -cc_test_old( - op_compatible_info_test - SRCS - op_compatible_info_test.cc - DEPS - op_compatible_info - proto_desc - string_helper - glog) cc_library( infershape_utils @@ -1160,10 +912,6 @@ cc_library( phi_utils op_info shape_inference) -cc_test( - infershape_utils_test - SRCS infershape_utils_test.cc - DEPS infershape_utils phi) # Get the current working branch execute_process( @@ -1215,15 +963,3 @@ set(FLUID_FRAMEWORK_MODULES custom_operator) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) - -if(WITH_TESTING AND TEST selected_rows_utils_test) - set_tests_properties(selected_rows_utils_test PROPERTIES TIMEOUT 120) -endif() - -cc_test(scope_guard_test SRCS scope_guard_test.cc) -cc_test( - phi_utils_test - SRCS phi_utils_test.cc - DEPS phi_utils) - -cc_test(convert_utils_test SRCS convert_utils_test.cc) diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index def0d1742ba95..ce085688dd582 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -48,6 +48,10 @@ class TEST_API BlockDesc { int32_t Parent() const { return desc_->parent_idx(); } + void SetParent(int32_t parent_id) const { + return desc_->set_parent_idx(parent_id); + } + int32_t ForwardBlockID() const { return desc_->forward_block_idx(); } VarDesc *Var(const std::string &name_bytes); diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index ded28eaf5cc12..1ab08b46b57ff 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -308,6 +308,7 @@ cc_test( memory device_context broadcast_op_handle) + cc_test_old( gather_op_test SRCS diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 19b3a281afee6..31782e0d7bc9e 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -318,15 +318,38 @@ struct OpInfoFiller { } }; +template +struct InferMetaTrait { + static void call(const char* op_type UNUSED, OpInfo* info) { + info->infer_shape_ = [](InferShapeContext* ctx) { + T inference; + inference(ctx); + }; + } +}; + template -struct OpInfoFiller { - void operator()(const char* op_type UNUSED, OpInfo* info) const { - // Note: if fill InferShapeFN by this Filler, the infershape here - // will overwrite the op->InferShape func registered in kOperator Filler +struct InferMetaTrait().infer_meta_( + std::declval()))> { + static void call(const char* op_type UNUSED, OpInfo* info) { info->infer_shape_ = [](InferShapeContext* ctx) { T inference; inference(ctx); }; + info->infer_meta_ = [](phi::InferMetaContext* ctx) { + T inference; + inference.infer_meta_(ctx); + }; + } +}; + +template +struct OpInfoFiller { + void operator()(const char* op_type UNUSED, OpInfo* info) const { + // Note: if fill InferShapeFN by this Filler, the infershape here + // will overwrite the op->InferShape func registered in kOperator Filler + InferMetaTrait::call(op_type, info); } }; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 50a16d8f686e7..6da191f2124b6 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -23,11 +23,19 @@ enum Mode { HETER = 4; // support XPU and GPU computing server } +message RefinedOpsPattern { + repeated string main_ops = 1; + optional int32 num = 2 [default = 0]; + repeated string pre_ops = 3; + repeated string suf_ops = 4; +} + message RecomputeConfig { repeated string checkpoints = 1; optional bool enable_offload = 2 [ default = false ]; repeated int32 checkpoint_shape = 3; optional bool enable_tuning = 4 [ default = false ]; // incubate for auto parallel + repeated RefinedOpsPattern refined_ops_patterns = 5; } message ShardingConfig { @@ -78,6 +86,7 @@ message DygraphShardingConfig { optional bool tensor_fusion = 1 [ default = false ]; optional int32 accumulate_steps = 2 [ default = 1 ]; optional bool comm_overlap = 3 [ default = false ]; + optional bool split_param = 4 [ default = false ]; } message HybridConfig { @@ -105,6 +114,7 @@ message AMPConfig { optional bool use_fp16_guard = 11 [ default = true ]; optional bool use_optimizer_fp16 = 12 [ default = false ]; // auto parallel effective only + optional bool use_pure_bf16 = 13 [ default = false ]; } message LocalSGDConfig { diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 2e1eb0a58fe5a..9e34f4c4c6739 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -25,7 +25,7 @@ #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" -PHI_DECLARE_bool(new_ir_apply_inplace_pass); +PHI_DECLARE_bool(pir_apply_inplace_pass); namespace paddle { namespace framework { @@ -324,7 +324,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( return core; } -std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( +std::shared_ptr CreatePirInterpreterCoreInfoToCache( std::unique_ptr<::pir::Program> ir_program, const platform::Place &place, bool is_grad, @@ -458,7 +458,7 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram( auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - if (FLAGS_new_ir_apply_inplace_pass) { + if (FLAGS_pir_apply_inplace_pass) { ::pir::PassManager pm(::pir::IrContext::Instance(), 3); pm.AddPass(::pir::CreateInplacePass()); pm.Run(ir_res.get()); @@ -540,9 +540,12 @@ std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( auto res = paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - if (FLAGS_new_ir_apply_inplace_pass) { + if (FLAGS_pir_apply_inplace_pass) { ::pir::PassManager pm(::pir::IrContext::Instance(), 3); pm.AddPass(::pir::CreateInplacePass()); + if (VLOG_IS_ON(6)) { + pm.EnableIRPrinting(); + } pm.Run(res.get()); } diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index d30ed6396e65e..8a1878787fcd1 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -34,7 +34,7 @@ #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" -PHI_DECLARE_bool(enable_new_ir_in_executor); +PHI_DECLARE_bool(enable_pir_in_executor); namespace paddle { namespace framework { namespace ir { @@ -190,7 +190,7 @@ class InterpreterCoreInfoCache { static InterpreterCoreInfoCache& Instance(); bool Has(int64_t program_id, const framework::Scope* scope, bool is_grad) { - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { int64_t scope_i = reinterpret_cast(scope); program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2); } @@ -201,7 +201,7 @@ class InterpreterCoreInfoCache { InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id, const framework::Scope* scope, bool is_grad) { - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { int64_t scope_i = reinterpret_cast(scope); program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2); } @@ -243,7 +243,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( int64_t program_id, framework::Scope* scope); -std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( +std::shared_ptr CreatePirInterpreterCoreInfoToCache( std::unique_ptr<::pir::Program> ir_prog, const platform::Place& place, bool is_grad, diff --git a/paddle/fluid/framework/feed_fetch_method.cc b/paddle/fluid/framework/feed_fetch_method.cc index 7a62b5563f30a..a7c76322422e8 100644 --- a/paddle/fluid/framework/feed_fetch_method.cc +++ b/paddle/fluid/framework/feed_fetch_method.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "glog/logging.h" -PHI_DECLARE_bool(enable_new_ir_in_executor); +PHI_DECLARE_bool(enable_pir_in_executor); PHI_DECLARE_bool(enable_pir_api); namespace phi { @@ -37,7 +37,7 @@ void SetFeedVariable(Scope* scope, // If var_name Variable is not found in GlobalScope, a new variable will // be created. VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { // shared data with input tensor auto feed_ele = scope->Var(var_name); if (!feed_ele->IsType()) { diff --git a/paddle/fluid/framework/infershape_utils.h b/paddle/fluid/framework/infershape_utils.h index 17587ba3aea34..6e1170c6ee0fe 100644 --- a/paddle/fluid/framework/infershape_utils.h +++ b/paddle/fluid/framework/infershape_utils.h @@ -151,6 +151,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, paddle::framework::BuildInferMetaContext(ctx, #op_type); \ fn(&infer_meta_context); \ } \ + void infer_meta_(phi::InferMetaContext* ctx) const { fn(ctx); } \ } } // namespace framework diff --git a/paddle/fluid/operators/fused/fusion_gru_op.h b/paddle/fluid/framework/init_default_kernel_signature_map.h similarity index 50% rename from paddle/fluid/operators/fused/fusion_gru_op.h rename to paddle/fluid/framework/init_default_kernel_signature_map.h index e811df655099d..a6b6400dd19f5 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.h +++ b/paddle/fluid/framework/init_default_kernel_signature_map.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,26 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/fluid/framework/op_registry.h" -namespace paddle { -namespace operators { - -class FusionGRUOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; +#include "paddle/utils/test_macros.h" - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators +// The implementation of InitDefaultKernelSignatureMap is in phi_utils.cc +namespace paddle { +namespace framework { +TEST_API void InitDefaultKernelSignatureMap(); +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 92d316fdea0a3..50af3d1779ecc 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -46,7 +46,7 @@ cc_library( cc_library( op_compat_sensible_pass SRCS op_compat_sensible_pass.cc - DEPS graph_pattern_detector op_def_api pass pir_core) + DEPS graph_pattern_detector op_def_api pass pir) cc_library( subgraph_detector SRCS subgraph_detector.cc @@ -59,6 +59,10 @@ cc_library( placement_pass_base SRCS placement_pass_base.cc DEPS pass) +cc_library( + quantize_helper + SRCS quantize_helper.cc + DEPS graph graph_helper) cc_library( coalesce_grad_tensor_pass @@ -237,7 +241,11 @@ if(WITH_XPU) xpu_pass_utils SRCS xpu/pass_utils.cc DEPS pass xpu_quant_utils) - set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) + cc_library( + xpu_graph_pattern_detector + SRCS xpu/xpu_graph_pattern_detector.cc + DEPS graph_pattern_detector) + set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils xpu_graph_pattern_detector) pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) @@ -247,6 +255,8 @@ if(WITH_XPU) # pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_bias_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(xpu_quantize_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(xpu_quantize_squash_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu @@ -506,14 +516,14 @@ if(WITH_MKLDNN) test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) - cc_test_old( - test_int8_scale_calculation_mkldnn_pass SRCS - mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc DEPS - int8_scale_calculation_mkldnn_pass pass_test_util) - cc_test_old( - test_params_quantization_mkldnn_pass SRCS - mkldnn/params_quantization_mkldnn_pass_tester.cc DEPS - params_quantization_mkldnn_pass) + cc_test( + test_int8_scale_calculation_mkldnn_pass + SRCS mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc + DEPS int8_scale_calculation_mkldnn_pass pass_test_util) + cc_test( + test_params_quantization_mkldnn_pass + SRCS mkldnn/params_quantization_mkldnn_pass_tester.cc + DEPS params_quantization_mkldnn_pass) set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index 14f42b129effa..d29ef0f9ad1fa 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -523,7 +523,6 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { vars_should_not_low_precision.insert(in_var_node->Var()->Name()); } } - // when op_1 only support cpu kernel. if op_2's intput var is op_1's // output var, then op_2 should not run at low precision. if (GetOpOriginalType(op_type) != "feed" && @@ -687,6 +686,16 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert( if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { return true; } + } else if (GetOpOriginalType(op_desc->Type()) == "quantize_linear" || + GetOpOriginalType(op_desc->Type()) == "dequantize_linear") { + auto vecs = op_desc->Input("Scale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("ZeroPoint"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } } } @@ -733,6 +742,11 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert( } void AutoMixedPrecisionPass::SetVarPrecision() const { + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL(scope, + platform::errors::PreconditionNotMet( + "During the auto_mixed_precision_pass, the scope " + "should not be null.")); for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) { @@ -749,7 +763,21 @@ void AutoMixedPrecisionPass::SetVarPrecision() const { if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue; if (!VarNodeHasDtype(real_in_var_node)) continue; if (InputVarsNotConvert(op_node, in_var_name)) continue; - + // Judge the real tensor is same to variable, Paddle-Slim weight use + // fp32 variable to save int8 tensor. + if (real_in_var_node->Var()->Persistable()) { + auto* tensor = scope->Var(real_in_var_node->Name()) + ->GetMutable(); + if (framework::TransToProtoVarType(tensor->type()) != + real_in_var_node->Var()->GetDataType()) { + VLOG(3) << "[AutoMixedPrecisionPass] variable " + << real_in_var_node->Name() << "'s proto data type " + << real_in_var_node->Var()->GetDataType() + << " is different from real dense tensor " + << framework::TransToProtoVarType(tensor->type()); + continue; + } + } if (real_in_var_node->Var()->Persistable()) { real_in_var_node->Var()->SetDataType( framework::TransToProtoVarType(low_precision_)); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index 286f7f08cdfc9..916d577d23d60 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -19,6 +19,7 @@ #include #include #include +#include "paddle/fluid/framework/ir/quantize_helper.h" namespace paddle { namespace framework { @@ -94,6 +95,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { scope, platform::errors::InvalidArgument( "Scope in DeleteQuantDequantLinearOpPass should not be null.")); + std::unordered_map> var_quant_scales{}; + // Create pattern patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(), pattern_name); @@ -141,7 +144,11 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op(); any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), input_scale); - + if (!var_quant_scales.count(quantize_linear_op_x->Var()->Name())) { + var_quant_scales.insert( + std::make_pair(quantize_linear_op_x->Var()->Name(), + std::vector({input_scale}))); + } // link x to any_op2 any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), quantize_linear_op_x->Var()->Name()); @@ -161,6 +168,9 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { }; gpd(graph, handler); AddStatis(found_count); + + SaveQuantInfoInTheGraph( + graph, "has_quant_info", "var_quant_scales", var_quant_scales); } } // namespace ir diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index cf5c9a2c94cf9..87f2de2a59e0d 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" #include "glog/logging.h" @@ -35,18 +36,20 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { true, platform::errors::InvalidArgument( "Graph must have kParamScopeAttr attribute.")); - + VLOG(3) << "Handle delete weight dequant linear op pass ..."; auto& scope = graph->Get(kParamScopeAttr); bool is_int8 = false; std::unordered_set nodes2rm; + std::unordered_map> var_quant_scales{}; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); if (op->Type() == "dequantize_linear") { - Node *weight_var_node = nullptr, *calcu_op_node = nullptr, - *while_op_node = nullptr; + Node* weight_var_node = nullptr; + Node* calcu_op_node = nullptr; + Node* while_op_node = nullptr; Node *dequantized_weight_var_node = nullptr, *scale_var_node = nullptr; // 1. Judge whether for dequant weight and find // weight_var_node/scale_var_node @@ -59,9 +62,12 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { scale_var_node = input_node; } } else { - return; + break; } } + if (weight_var_node == nullptr || scale_var_node == nullptr) { + continue; + } // 2. Find next_op_node // For while op: delete its input which is related to dequantized // For calculation op: set weight scale as their attributes @@ -106,7 +112,7 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { } } else { PADDLE_THROW(platform::errors::Unimplemented( - "The dtype of quantization scale must be FP32/16, " + "The dtype of quantization scale must be FP32/FP16, " "but received %d, which is not supported.", weight_scale_tensor->dtype())); } @@ -125,14 +131,34 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { calcu_op_desc->SetAttr("weight_scale", weight_scale[0]); } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Delete Weight Dequant Linear Op Pass is not supported " - "for " - "per-channel quantization")); + std::vector weights_shape = + weight_var_node->Var()->GetShape(); + quant_axis = quant_axis >= 0 + ? quant_axis + : quant_axis + weights_shape.size(); + PADDLE_ENFORCE_EQ( + weight_scale_nums, + weights_shape[quant_axis], + platform::errors::InvalidArgument( + "When quant_axis != -1, it means using per_channel " + "dequantization. In this situation, the number of " + "weight_scale should be equal with " + "weights_shape[quant_axis=%d]=%ld , but received " + "%d.", + quant_axis, + weights_shape[quant_axis], + weight_scale_nums)); + calcu_op_desc->SetAttr("weight_scale", weight_scale); } + if (!var_quant_scales.count(weight_var_node->Var()->Name())) { + var_quant_scales.insert(std::make_pair( + weight_var_node->Var()->Name(), weight_scale)); + } + calcu_op_desc->RenameInput( dequantized_weight_var_node->Var()->Name(), weight_var_node->Var()->Name()); + calcu_op_desc->Flush(); } } } @@ -153,6 +179,8 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { } GraphSafeRemoveNodes(graph, nodes2rm); + SaveQuantInfoInTheGraph( + graph, "has_quant_info", "var_quant_scales", var_quant_scales); graph->Set("enable_int8", new bool(is_int8)); } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 3596f4e0f0e29..e42334aac0593 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -412,6 +412,20 @@ class Graph { return sub_graphs_.size(); } + std::vector AttrNames() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->AttrNames(); + } + } + std::vector res; + res.reserve(attrs_.size()); + for (auto &attr : attrs_) { + res.push_back(attr.first); + } + return res; + } + private: // TODO(levi): delete this interface after when we can convert all // blocks into sub_graphs. diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc index 770a3a7a1d117..d0fb6d58443ae 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc @@ -15,7 +15,6 @@ #include #include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h" -#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/quantize_helper.cc b/paddle/fluid/framework/ir/quantize_helper.cc new file mode 100644 index 0000000000000..08f2cc457ef2c --- /dev/null +++ b/paddle/fluid/framework/ir/quantize_helper.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/quantize_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SaveQuantInfoInTheGraph( + ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const std::unordered_map>& info_map) { + const std::string suffix = "_" + key_suffix + "_" + flag; + if (!graph->Has(flag)) { + graph->Set(flag, new bool(true)); + } + for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { + graph->Set(iter->first + suffix, new std::vector(iter->second)); + } +} + +std::unordered_map> GetQuantInfoFromTheGraph( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix) { + std::unordered_map> info_map; + const std::string suffix = "_" + key_suffix + "_" + flag; + if (graph->Has(flag)) { + std::vector attr_names = graph->AttrNames(); + for (auto fake_name : attr_names) { + size_t pos = fake_name.find(suffix); + if (pos != std::string::npos) { + std::string name = fake_name.substr(0, pos); + auto scales_vector = graph->Get>(fake_name); + info_map.insert(std::make_pair(name, scales_vector)); + } + } + } + return info_map; +} + +bool AreScalesPresentForNodes( + std::unordered_map>* var_quant_scales, + std::initializer_list nodes) { + bool present = true; + for (auto node : nodes) { + if (var_quant_scales->count(node->Name()) == 0) { + present = false; + } + } + return present; +} + +float GetScaleValueForNode( + std::unordered_map>* var_quant_scales, + Node* node) { + return var_quant_scales->at(node->Name())[0]; +} + +std::vector GetScaleVecValueForNode( + std::unordered_map>* var_quant_scales, + Node* node) { + return var_quant_scales->at(node->Name()); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/quantize_helper.h b/paddle/fluid/framework/ir/quantize_helper.h new file mode 100644 index 0000000000000..4876cd35a1cf3 --- /dev/null +++ b/paddle/fluid/framework/ir/quantize_helper.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SaveQuantInfoInTheGraph( + ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const std::unordered_map>& info_map); + +std::unordered_map> GetQuantInfoFromTheGraph( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix); + +bool AreScalesPresentForNodes( + std::unordered_map>* var_quant_scales, + std::initializer_list nodes); + +float GetScaleValueForNode( + std::unordered_map>* var_quant_scales, + Node* node); + +std::vector GetScaleVecValueForNode( + std::unordered_map>* var_quant_scales, + Node* node); + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc index 0403330f77cd1..5a086acd7cac2 100644 --- a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc +++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc @@ -126,6 +126,26 @@ bool ModelLayoutIsNHWC(const std::vector &op_nodes) { return false; } +// Do additional check if OP's weight is not persistable +typedef std::string OP_NAME; +typedef std::string WEIGHT_NAME; +typedef std::unordered_map OP_WEIGHT_NAME; +bool IsWeight(ir::Node *op_node, + ir::Node *var_node, + const OP_WEIGHT_NAME &op_weight_pair) { + if (var_node->Var()->Persistable()) return true; + auto *op_desc = op_node->Op(); + std::string op_type = op_desc->Type(); + std::string var_name = var_node->Var()->Name(); + if (op_weight_pair.count(op_type)) { + if (var_name == + op_desc->Input(op_weight_pair.find(op_type)->second).front()) { + return true; + } + } + return false; +} + } // namespace void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { @@ -155,6 +175,9 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { "bilinear_interp_v2", "nearest_interp", "nearest_interp_v2"}; + // Op's weight could be temporary variable, so we save the name of OP's weight + // input + OP_WEIGHT_NAME op_weight_pair{{"conv2d", "Filter"}}; // Ops must run under the original layout even though it has // data_format/data_layout attribute, otherwise it will be very troublesome! std::unordered_set must_original_layout_ops{ @@ -193,7 +216,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { auto op_inputs = op_node->inputs; for (auto *in_var_node : op_inputs) { CHECK_EQ(in_var_node->IsVar(), true); - if (in_var_node->Var()->Persistable()) continue; + if (IsWeight(op_node, in_var_node, op_weight_pair)) continue; auto input_shape = in_var_node->Var()->GetShape(); input_shape_4 &= (input_shape.size() == 4); @@ -326,7 +349,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { for (auto *in_var_node : op_inputs) { CHECK_EQ(in_var_node->IsVar(), true); - if (in_var_node->Var()->Persistable()) continue; + if (IsWeight(op_node, in_var_node, op_weight_pair)) continue; if (vars_to_nchw.count(in_var_node)) continue; DoInsertTransposeOp(graph, diff --git a/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc index ef8759153b0cc..1a56e4d660431 100644 --- a/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc @@ -127,6 +127,7 @@ int CastMixedPrecisionOpFusePass::ApplyCastBeforePass( GraphPatternDetector gpd; patterns::CastBeforePattern pattern( gpd.mutable_pattern(), name_scope_, mixed_precision_op_type); + auto* scope = param_scope(); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -136,7 +137,22 @@ int CastMixedPrecisionOpFusePass::ApplyCastBeforePass( GET_IR_NODE(cast); GET_IR_NODE(cast_out); GET_IR_NODE(mixed_precision_op); - + // Note: conv2d_xpu/fc_xpu not support float32/int8/float16, can not fuse. + if (mixed_precision_op_type == "conv2d_xpu") { + auto filter_name = mixed_precision_op->Op()->Input("filter")[0]; + auto filter_data_type = + scope->FindVar(filter_name)->GetMutable()->dtype(); + if (filter_data_type == phi::DataType::INT8) { + return; + } + } else if (mixed_precision_op_type == "fc_xpu") { + auto w_name = mixed_precision_op->Op()->Input("w")[0]; + auto w_data_type = + scope->FindVar(w_name)->GetMutable()->dtype(); + if (w_data_type == phi::DataType::INT8) { + return; + } + } mixed_precision_op->Op()->RenameInput(cast_out->Name(), cast_in->Name()); IR_NODE_LINK_TO(cast_in, mixed_precision_op); @@ -155,6 +171,7 @@ int CastMixedPrecisionOpFusePass::ApplyCastAfterPass( GraphPatternDetector gpd; patterns::CastAfterPattern pattern( gpd.mutable_pattern(), name_scope_, mixed_precision_op_type); + auto* scope = param_scope(); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -164,7 +181,30 @@ int CastMixedPrecisionOpFusePass::ApplyCastAfterPass( GET_IR_NODE(cast_in); GET_IR_NODE(cast); GET_IR_NODE(cast_out); - + // Note: conv2d_xpu/fc_xpu not support float16/int8/float32, can not fuse. + if (mixed_precision_op_type == "conv2d_xpu") { + auto filter_name = mixed_precision_op->Op()->Input("filter")[0]; + auto filter_data_type = + scope->FindVar(filter_name)->GetMutable()->dtype(); + auto x_name = mixed_precision_op->Op()->Input("x")[0]; + auto* x_node = FindNodeWithName(graph, x_name); + if (filter_data_type == phi::DataType::INT8 && + x_node->Var()->GetDataType() == + proto::VarType::Type::VarType_Type_FP16) { + return; + } + } else if (mixed_precision_op_type == "fc_xpu") { + auto w_name = mixed_precision_op->Op()->Input("w")[0]; + auto w_data_type = + scope->FindVar(w_name)->GetMutable()->dtype(); + auto x_name = mixed_precision_op->Op()->Input("x")[0]; + auto* x_node = FindNodeWithName(graph, x_name); + if (w_data_type == phi::DataType::INT8 && + x_node->Var()->GetDataType() == + proto::VarType::Type::VarType_Type_FP16) { + return; + } + } mixed_precision_op->Op()->RenameOutput(cast_in->Name(), cast_out->Name()); int out_dtype = proto::VarType::Type::VarType_Type_FP32; mixed_precision_op->Op()->SetAttr("out_dtype", out_dtype); diff --git a/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc index 784d5d4ec029f..51ebb63c563dc 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc @@ -377,8 +377,14 @@ int Conv2dTransposeXPUFusePass::ApplyImpl(ir::Graph* graph, // filter max Node* filter_int16 = nullptr; Node* filter_max = nullptr; - PrepareWeight( - graph, scope, block, conv_filter, &filter_int16, &filter_max, false); + PrepareWeight(graph, + scope, + block, + conv_filter, + &filter_int16, + &filter_max, + false, + std::vector({})); // output && output max std::string conv2d_xpu_out_name; if (!act_type.empty()) { diff --git a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc index 502c275a419d3..89a558c6601f1 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "glog/logging.h" @@ -19,6 +20,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -355,6 +357,57 @@ class Conv2dXPUFusePass : public FusePassBase { bool with_branch_x, bool with_branch_y) const; + Node* GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const; + + void CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_conv_bias, + bool with_bn, + bool with_scale, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionBranch( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::string act_type, + std::unordered_map>* var_quant_scales) + const; + + const std::unordered_set support_quant_op_type_{"conv2d", + "conv2d_xpu"}; const std::string name_scope_{"conv2d_xpu_fuse_pass"}; }; @@ -401,6 +454,532 @@ void Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_subgraph_count); } +Node* Conv2dXPUFusePass::GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const { + auto iter = nodes_map.find(pattern_node_name); + PADDLE_ENFORCE_EQ( + iter != nodes_map.end(), + true, + platform::errors::InvalidArgument("nodes_map[%s] not found in nodes_map", + pattern_node_name.c_str())); + auto node_map = iter->second; + auto node_iter = node_map.find(node_name); + PADDLE_ENFORCE_EQ(node_iter != node_map.end(), + true, + platform::errors::InvalidArgument( + "nodes_map[%s][%s] not found in nodes_map", + pattern_node_name.c_str(), + node_name.c_str())); + return node_iter->second; +} + +void Conv2dXPUFusePass::CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_conv_bias, + bool with_bn, + bool with_scale, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + auto* conv_filter = GetNodeFromNodesMap(nodes_map, "conv", "conv_filter"); + PADDLE_ENFORCE_EQ(conv_filter != nullptr, + true, + platform::errors::InvalidArgument( + "conv_filter node ptr can not be null")); + + // transfilter fp16 --> fp32 + auto* filter_t = + scope->FindVar(conv_filter->Name())->GetMutable(); + auto filter_len = filter_t->numel(); + auto filter_dtype = filter_t->dtype(); + if (filter_dtype == phi::DataType::FLOAT16) { + CastToFp32(filter_t, nullptr); + } + + // Get Weight scale in int8 scene + std::vector weight_scale{}; + if (AreScalesPresentForNodes(var_quant_scales, {conv_filter})) { + weight_scale = GetScaleVecValueForNode(var_quant_scales, conv_filter); + } + // Create fusion_bias_node + auto filter_dims = filter_t->dims(); + Node* fusion_bias_node = nullptr; + if (with_conv_bias) { + auto* ew_bias_add_y = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_y"); + PADDLE_ENFORCE_EQ(ew_bias_add_y != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_y node ptr can not be null")); + auto* ew_bias_add_y_t = + scope->FindVar(ew_bias_add_y->Name())->GetMutable(); + auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); + PADDLE_ENFORCE_EQ(filter_dims[0], + ew_bias_add_y_dims[0], + platform::errors::InvalidArgument( + "the shape[%d] of elewise bias tensor " + "must equal out_channel[%d] of conv", + ew_bias_add_y_dims[0], + filter_dims[0])); + PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); + } + + if (with_bn) { + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + PADDLE_ENFORCE_EQ( + bn != nullptr, + true, + platform::errors::InvalidArgument("bn node ptr can not be null")); + auto* bn_bias = GetNodeFromNodesMap(nodes_map, "bn", "bn_bias"); + PADDLE_ENFORCE_EQ( + bn_bias != nullptr, + true, + platform::errors::InvalidArgument("bn_bias node ptr can not be null")); + auto* bn_scale = GetNodeFromNodesMap(nodes_map, "bn", "bn_scale"); + PADDLE_ENFORCE_EQ( + bn_scale != nullptr, + true, + platform::errors::InvalidArgument("bn_scale node ptr can not be null")); + auto* bn_var = GetNodeFromNodesMap(nodes_map, "bn", "bn_var"); + PADDLE_ENFORCE_EQ( + bn_var != nullptr, + true, + platform::errors::InvalidArgument("bn_var node ptr can not be null")); + auto* bn_mean = GetNodeFromNodesMap(nodes_map, "bn", "bn_mean"); + PADDLE_ENFORCE_EQ( + bn_mean != nullptr, + true, + platform::errors::InvalidArgument("bn_mean node ptr can not be null")); + + auto bn_bias_t = + scope->Var(bn_bias->Name())->GetMutable(); + PADDLE_ENFORCE_EQ( + filter_dims[0], + bn_bias_t->dims()[0], + platform::errors::InvalidArgument("the shape[%d] of bn bias tensor " + "must equal out_channel[%d] of conv", + bn_bias_t->dims()[0], + filter_dims[0])); + auto bn_scale_t = + scope->Var(bn_scale->Name())->GetMutable(); + auto bn_mean_t = + scope->Var(bn_mean->Name())->GetMutable(); + auto bn_var_t = scope->Var(bn_var->Name())->GetMutable(); + float* bn_scale_ptr = bn_scale_t->data(); + float* bn_bias_ptr = bn_bias_t->data(); + float* bn_mean_ptr = bn_mean_t->data(); + float* bn_var_ptr = bn_var_t->data(); + auto mean_len = bn_mean_t->numel(); + auto filter_stride = filter_len / mean_len; + float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); + if (!with_conv_bias) { // prev node is conv + PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); + } + + auto fusion_bias_t = + scope->Var(fusion_bias_node->Name())->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + // recompute bias and weights + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + } + // recompute the weights + if (op_weights_precision != "int8") { + float* filter_ptr = filter_t->data(); + for (int i = 0; i < mean_len; ++i) { + for (int j = 0; j < filter_stride; j++) { + filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; + } + } + } else { + int8_t* filter_ptr = filter_t->data(); + PADDLE_ENFORCE_EQ( + weight_scale.size(), + mean_len, + platform::errors::InvalidArgument( + "Weight max_scale size must equal batch_norm sacle/mean size.")); + for (int i = 0; i < mean_len; i++) { + weight_scale[i] *= fabs(bn_scale_ptr[i]); + } + for (int i = 0; i < mean_len; i++) { + if (bn_scale_ptr[i] < 0) { + for (int j = 0; j < filter_stride; ++j) { + filter_ptr[i * filter_stride + j] *= -1; + } + } + } + } + // recompute bias + if (!with_conv_bias) { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } else { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] = + bn_bias_ptr[i] + + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } + } + + // deal with scale op + if (with_scale) { + auto* scale = GetNodeFromNodesMap(nodes_map, "scale", "scale"); + PADDLE_ENFORCE_EQ( + scale != nullptr, + true, + platform::errors::InvalidArgument("scale node ptr can not be null")); + auto bias_len = filter_dims[0]; + float scale_val_ = 1.f; + float bias_val_ = 0.f; + scale_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + bias_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("bias")); + bool bias_after_scale_ = + PADDLE_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); + // recompute bias as scale op + auto fusion_bias_t = + scope->GetVar(fusion_bias_node->Name())->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + for (int i = 0; i < bias_len; ++i) { + if (bias_after_scale_) { + fusion_bias_ptr[i] = fusion_bias_ptr[i] * scale_val_ + bias_val_; + } else { + fusion_bias_ptr[i] = (fusion_bias_ptr[i] + bias_val_) * scale_val_; + } + } + // recompute weight as scale op + if (op_weights_precision != "int8") { + float* filter_ptr = filter_t->data(); + for (int i = 0; i < filter_len; ++i) { + filter_ptr[i] *= scale_val_; + } + } else { + for (size_t i = 0; i < weight_scale.size(); i++) { + weight_scale[i] *= scale_val_; + } + } + } + + (*fusion_nodes_map)["bias"] = fusion_bias_node; + + Node* filter_intx = nullptr; + Node* filter_max = nullptr; + Node* scale_max = nullptr; + if (op_weights_precision != "int8") { + PrepareWeight(graph, + scope, + block, + conv_filter, + &filter_intx, + &filter_max, + false, + weight_scale); + } else { + PrepareWeight(graph, + scope, + block, + conv_filter, + &filter_intx, + &filter_max, + false, + weight_scale); + } + + bool is_per_channel_need_create_scale_max_node = + !weight_scale.empty() && !IsPerTensorQuant(weight_scale); + if (is_per_channel_need_create_scale_max_node) { + phi::DenseTensor ones_weight_max_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + int max_ptr_size = weight_scale.empty() + ? phi::backends::xpu::get_xpu_max_ptr_size(-1) + : weight_scale.size(); + ones_weight_max_tensor.set_type(phi::DataType::FLOAT32); + ones_weight_max_tensor.Resize({max_ptr_size}); + std::vector ones_weight(max_ptr_size, 1.0); + memcpy(cpu_ctx->Alloc(&ones_weight_max_tensor), + ones_weight.data(), + max_ptr_size * sizeof(float)); + + std::string scale_max_name = conv_filter->Name() + "_scale_max"; + VarDesc scale_max_desc(scale_max_name); + scale_max_desc.SetPersistable(true); + scale_max_desc.SetShape(vectorize(ones_weight_max_tensor.dims())); + scale_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + scale_max = graph->CreateVarNode(&scale_max_desc); + auto* block_scale_max_desc = block->Var(scale_max_name); + block_scale_max_desc->SetPersistable(scale_max_desc.Persistable()); + block_scale_max_desc->SetShape(scale_max_desc.GetShape()); + block_scale_max_desc->SetDataType(scale_max_desc.GetDataType()); + Assign(ones_weight_max_tensor, + scope->Var(scale_max_name)->GetMutable()); + } + + (*fusion_nodes_map)["filter"] = filter_intx; + if (is_per_channel_need_create_scale_max_node) { + (*fusion_nodes_map)["filter_max"] = scale_max; + (*fusion_nodes_map)["scale_max"] = filter_max; + } else { + (*fusion_nodes_map)["filter_max"] = filter_max; + (*fusion_nodes_map)["scale_max"] = scale_max; + } +} + +void Conv2dXPUFusePass::CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + auto* input = GetNodeFromNodesMap(nodes_map, "conv", "input"); + PADDLE_ENFORCE_EQ( + input != nullptr, + true, + platform::errors::InvalidArgument("conv input node ptr can not be null")); + // input max + std::string conv_input_max_name = input->Name() + "_input_max"; + Node* conv2d_xpu_input_max = nullptr; + if (op_weights_precision == "int8") { + PADDLE_ENFORCE_EQ(AreScalesPresentForNodes(var_quant_scales, {input}), + true, + platform::errors::InvalidArgument( + "When conv op is running in int8 precision, the " + "scales of input var should be present in!")); + float input_scale = GetScaleValueForNode(var_quant_scales, input); + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc conv_input_max_desc(conv_input_max_name); + conv_input_max_desc.SetPersistable(true); + conv_input_max_desc.SetShape({static_cast(max_ptr_size)}); + conv_input_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + conv2d_xpu_input_max = graph->CreateVarNode(&conv_input_max_desc); + auto input_max_tensor = + scope->Var(conv_input_max_name)->GetMutable(); + input_max_tensor->set_type(phi::DataType::FLOAT32); + input_max_tensor->Resize({max_ptr_size}); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + std::vector input_scales(max_ptr_size, input_scale); + memcpy(cpu_ctx->Alloc(input_max_tensor), + input_scales.data(), + max_ptr_size * sizeof(float)); + } + (*fusion_nodes_map)["x"] = input; + (*fusion_nodes_map)["x_max"] = conv2d_xpu_input_max; +} + +void Conv2dXPUFusePass::CreateFusionBranch( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* ew_branch_add = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add"); + if (ew_branch_add) { + auto* ew_branch_add_in = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add_in"); + PADDLE_ENFORCE_EQ(ew_branch_add_in != nullptr, + true, + platform::errors::InvalidArgument( + "ew_branch_add_in node ptr can not be null")); + (*fusion_nodes_map)["branch"] = ew_branch_add_in; + // ew_branch_add_max + std::string ew_branch_add_max_name = + ew_branch_add_in->Name() + "branch_max"; + Node* ew_branch_add_max = FindNodeWithName(graph, ew_branch_add_max_name); + if (op_weights_precision == "int8" && !ew_branch_add_max) { + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc ew_branch_add_in_max_desc(ew_branch_add_max_name); + ew_branch_add_in_max_desc.SetPersistable(true); + ew_branch_add_in_max_desc.SetShape({static_cast(max_ptr_size)}); + ew_branch_add_in_max_desc.SetDataType( + proto::VarType::Type::VarType_Type_FP32); + ew_branch_add_max = graph->CreateVarNode(&ew_branch_add_in_max_desc); + PADDLE_ENFORCE_EQ( + AreScalesPresentForNodes(var_quant_scales, {ew_branch_add_in}), + true, + platform::errors::InvalidArgument( + "When conv op is running in int8 precision with branch add, the " + "scales of branch var should be present in!")); + float ew_branch_add_scale = + GetScaleValueForNode(var_quant_scales, ew_branch_add_in); + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + auto ew_branch_add_max_tensor = + scope->Var(ew_branch_add_max_name)->GetMutable(); + ew_branch_add_max_tensor->set_type(phi::DataType::FLOAT32); + ew_branch_add_max_tensor->Resize({max_ptr_size}); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + std::vector ew_branch_add_scales(max_ptr_size, + ew_branch_add_scale); + memcpy(cpu_ctx->Alloc(ew_branch_add_max_tensor), + ew_branch_add_scales.data(), + max_ptr_size * sizeof(float)); + } + (*fusion_nodes_map)["branch_max"] = ew_branch_add_max; + } +} + +void Conv2dXPUFusePass::CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::string act_type, + std::unordered_map>* var_quant_scales) + const { + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + // output && output max + std::string conv2d_xpu_out_name; + Node* conv2d_out_var_node = nullptr; + + auto* ew_branch_add = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add"); + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + auto* scale = GetNodeFromNodesMap(nodes_map, "scale", "scale"); + auto* ew_bias_add = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add"); + if (!act_type.empty()) { + auto* act_out = GetNodeFromNodesMap(nodes_map, "act", "act_out"); + PADDLE_ENFORCE_EQ( + act_out != nullptr, + true, + platform::errors::InvalidArgument("act_out node ptr can not be null")); + conv2d_xpu_out_name = act_out->Name(); + conv2d_out_var_node = act_out; + auto* act = GetNodeFromNodesMap(nodes_map, "act", "act"); + PADDLE_ENFORCE_EQ( + act != nullptr, + true, + platform::errors::InvalidArgument("act node ptr can not be null")); + } else if (ew_branch_add) { + auto* ew_branch_add_out = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add_out"); + PADDLE_ENFORCE_EQ(ew_branch_add_out != nullptr, + true, + platform::errors::InvalidArgument( + "ew_branch_add_out node ptr can not be null")); + conv2d_xpu_out_name = ew_branch_add_out->Name(); + conv2d_out_var_node = ew_branch_add_out; + PADDLE_ENFORCE_EQ(ew_branch_add != nullptr, + true, + platform::errors::InvalidArgument( + "ew_branch_add node ptr can not be null")); + } else if (scale) { + auto* scale_out = GetNodeFromNodesMap(nodes_map, "scale", "scale_out"); + PADDLE_ENFORCE_EQ(scale_out != nullptr, + true, + platform::errors::InvalidArgument( + "scale_out node ptr can not be null")); + conv2d_xpu_out_name = scale_out->Name(); + conv2d_out_var_node = scale_out; + } else if (bn) { + auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out"); + PADDLE_ENFORCE_EQ( + bn_out != nullptr, + true, + platform::errors::InvalidArgument("bn_out node ptr can not be null")); + conv2d_xpu_out_name = bn_out->Name(); + conv2d_out_var_node = bn_out; + } else if (ew_bias_add) { + auto* ew_bias_add_out = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out"); + PADDLE_ENFORCE_EQ(ew_bias_add_out != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_out node ptr can not be null")); + conv2d_xpu_out_name = ew_bias_add_out->Name(); + conv2d_out_var_node = ew_bias_add_out; + } else { + auto* conv_out = GetNodeFromNodesMap(nodes_map, "conv", "conv_out"); + PADDLE_ENFORCE_EQ( + conv_out != nullptr, + true, + platform::errors::InvalidArgument("conv_out node ptr can not be null")); + conv2d_xpu_out_name = conv_out->Name(); + conv2d_out_var_node = conv_out; + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + } + (*fusion_nodes_map)["out"] = conv2d_out_var_node; + + // Create out max in + if (op_weights_precision == "int8" && + AreScalesPresentForNodes(var_quant_scales, {conv2d_out_var_node})) { + std::string conv_out_max_in_name = conv2d_xpu_out_name + "_max_in"; + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc conv_out_max_in_desc(conv_out_max_in_name); + conv_out_max_in_desc.SetPersistable(true); + conv_out_max_in_desc.SetShape({static_cast(max_ptr_size)}); + conv_out_max_in_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + Node* conv2d_xpu_out_max_in = graph->CreateVarNode(&conv_out_max_in_desc); + auto* block_out_max_in_desc = block->Var(conv_out_max_in_name); + block_out_max_in_desc->SetPersistable(conv_out_max_in_desc.Persistable()); + block_out_max_in_desc->SetShape(conv_out_max_in_desc.GetShape()); + block_out_max_in_desc->SetDataType(conv_out_max_in_desc.GetDataType()); + + float output_scale = + GetScaleValueForNode(var_quant_scales, conv2d_out_var_node); + phi::DenseTensor out_max_in_cpu_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + out_max_in_cpu_tensor.set_type(phi::DataType::FLOAT32); + out_max_in_cpu_tensor.Resize({max_ptr_size}); + std::vector output_scales(max_ptr_size, output_scale); + memcpy(cpu_ctx->Alloc(&out_max_in_cpu_tensor), + output_scales.data(), + max_ptr_size * sizeof(float)); + Assign(out_max_in_cpu_tensor, + scope->Var(conv_out_max_in_name)->GetMutable()); + (*fusion_nodes_map)["out_max_in"] = conv2d_xpu_out_max_in; + } + + // Create out max + std::string conv_out_max_name = conv2d_xpu_out_name + "_max"; + VarDesc conv_out_max_desc(conv_out_max_name); + Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv_out_max_desc); + (*fusion_nodes_map)["out_max"] = conv2d_xpu_out_max; +} + int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, const std::string& conv_type, const std::string& act_type, @@ -419,18 +998,23 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, with_scale, with_branch_x, with_branch_y); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + std::unordered_map> var_quant_scales = + GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle Conv2dXPUFusePass fuse"; - /* declare operator node's name */ + std::map> nodes_map; GET_IR_NODE(conv); GET_IR_NODE(ew_bias_add); GET_IR_NODE(bn); GET_IR_NODE(scale); GET_IR_NODE(ew_branch_add); GET_IR_NODE(act); - /* declare variable node's name*/ + /* Get variable node's name*/ GET_IR_NODE(input); GET_IR_NODE(conv_filter); GET_IR_NODE(conv_out); @@ -449,166 +1033,132 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(ew_branch_add_in); GET_IR_NODE(ew_branch_add_out); GET_IR_NODE(act_out); + + nodes_map.insert({"conv", + {{"conv", conv}, + {"conv_filter", conv_filter}, + {"input", input}, + {"conv_out", conv_out}}}); + nodes_map.insert({"ew_bias_add", + {{"ew_bias_add", ew_bias_add}, + {"ew_bias_add_y", ew_bias_add_y}, + {"ew_bias_add_out", ew_bias_add_out}}}); + nodes_map.insert({"bn", + {{"bn", bn}, + {"bn_bias", bn_bias}, + {"bn_mean", bn_mean}, + {"bn_scale", bn_scale}, + {"bn_var", bn_var}, + {"bn_out", bn_out}, + {"bn_var_out", bn_var_out}, + {"bn_mean_out", bn_mean_out}, + {"bn_saved_var", bn_saved_var}, + {"bn_saved_mean", bn_saved_mean}}}); + nodes_map.insert({"scale", {{"scale", scale}, {"scale_out", scale_out}}}); + nodes_map.insert({"ew_branch_add", + {{"ew_branch_add", ew_branch_add}, + {"ew_branch_add_in", ew_branch_add_in}, + {"ew_branch_add_out", ew_branch_add_out}}}); + nodes_map.insert({"act", {{"act", act}, {"act_out", act_out}}}); + + std::map fusion_nodes_map{{"x", nullptr}, + {"x_max", nullptr}, + {"filter", nullptr}, + {"filter_max", nullptr}, + {"bias", nullptr}, + {"branch", nullptr}, + {"branch_max", nullptr}, + {"scale_max", nullptr}, + {"out_max_in", nullptr}, + {"out", nullptr}, + {"out_max", nullptr}}; + + auto filter_data_type = scope->FindVar(conv_filter->Name()) + ->GetMutable() + ->dtype(); + std::string op_weights_precision = "float32"; + if (filter_data_type == phi::DataType::INT8) { + op_weights_precision = "int8"; + } else if (filter_data_type == phi::DataType::FLOAT16) { + op_weights_precision = "float16"; + } + VLOG(4) << "Conv2d fusion fuse pass is running on " << op_weights_precision + << " precision!"; auto* block = conv->Op()->Block(); - auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL( - scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); - - // recompute bias and weight for conv2d_xpu op - auto* filter_t = - scope->FindVar(conv_filter->Name())->GetMutable(); - // conv_filter fp16 --> fp32 - auto filter_len = filter_t->numel(); - auto filter_dtype = filter_t->dtype(); - int out_dtype = proto::VarType::Type::VarType_Type_FP32; - if (filter_dtype == phi::DataType::FLOAT16) { - out_dtype = proto::VarType::Type::VarType_Type_FP16; - CastToFp32(filter_t, nullptr); - } - - auto filter_dims = filter_t->dims(); - bool has_bias = with_bn || with_conv_bias; - // Create conv_fusion_bias (conv bias) variable - Node* fusion_bias_node = nullptr; - if (has_bias) { - if (with_conv_bias) { - auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name()) - ->GetMutable(); - auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); - PADDLE_ENFORCE_EQ(filter_dims[0], - ew_bias_add_y_dims[0], - platform::errors::InvalidArgument( - "the shape[%d] of elewise bias tensor " - "must equal out_channel[%d] of conv", - ew_bias_add_y_dims[0], - filter_dims[0])); - PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); - } - if (with_bn) { - auto bn_bias_t = - scope->Var(bn_bias->Name())->GetMutable(); - PADDLE_ENFORCE_EQ(filter_dims[0], - bn_bias_t->dims()[0], - platform::errors::InvalidArgument( - "the shape[%d] of bn bias tensor " - "must equal out_channel[%d] of conv", - bn_bias_t->dims()[0], - filter_dims[0])); - auto bn_scale_t = - scope->Var(bn_scale->Name())->GetMutable(); - auto bn_mean_t = - scope->Var(bn_mean->Name())->GetMutable(); - auto bn_var_t = - scope->Var(bn_var->Name())->GetMutable(); - float* filter_ptr = - filter_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_scale_ptr = - bn_scale_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_bias_ptr = - bn_bias_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_mean_ptr = - bn_mean_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_var_ptr = - bn_var_t->mutable_data(paddle::platform::CPUPlace()); - auto mean_len = bn_mean_t->numel(); - auto filter_stride = filter_len / mean_len; - float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); - if (!with_conv_bias) { // prev node is conv - PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); - } - auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) - ->GetMutable(); - float* fusion_bias_ptr = - fusion_bias_t->mutable_data(paddle::platform::CPUPlace()); - // recompute bias and weights - if (!with_conv_bias) { // prev node is conv - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_stride; j++) { - filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; - } - } - } else { - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] = - bn_bias_ptr[i] + - (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_stride; j++) { - filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; - } - } - } - } + CreateFusionWeightsAndBias(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + with_conv_bias, + with_bn, + with_scale, + op_weights_precision, + &var_quant_scales); + CreateFusionInputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); + CreateFusionBranch(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); + CreateFusionOutputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + act_type, + &var_quant_scales); + + framework::OpDesc conv2d_xpu_op_desc(block); + conv2d_xpu_op_desc.SetType("conv2d_xpu"); + conv2d_xpu_op_desc.SetInput("x", {fusion_nodes_map["x"]->Name()}); + if (fusion_nodes_map["x_max"]) { + conv2d_xpu_op_desc.SetInput("x_max", {fusion_nodes_map["x_max"]->Name()}); } - // deal with scale op - if (with_scale) { - auto bias_len = filter_dims[0]; - float scale_val_ = 1.f; - float bias_val_ = 0.f; - scale_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); - bias_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("bias")); - bool bias_after_scale_ = - PADDLE_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); - // recompute bias as scale op - auto fusion_bias_t = scope->GetVar(fusion_bias_node->Name()) - ->GetMutable(); - float* fusion_bias_ptr = - fusion_bias_t->mutable_data(paddle::platform::CPUPlace()); - for (int i = 0; i < bias_len; ++i) { - if (bias_after_scale_) { - fusion_bias_ptr[i] = fusion_bias_ptr[i] * scale_val_ + bias_val_; - } else { - fusion_bias_ptr[i] = (fusion_bias_ptr[i] + bias_val_) * scale_val_; - } - } - // recompute weight as scale op - float* filter_ptr = - filter_t->mutable_data(paddle::platform::CPUPlace()); - for (int i = 0; i < filter_len; ++i) { - filter_ptr[i] *= scale_val_; - } + conv2d_xpu_op_desc.SetInput("filter", {fusion_nodes_map["filter"]->Name()}); + conv2d_xpu_op_desc.SetInput("filter_max", + {fusion_nodes_map["filter_max"]->Name()}); + if (fusion_nodes_map["scale_max"]) { + conv2d_xpu_op_desc.SetInput("scale_max", + {fusion_nodes_map["scale_max"]->Name()}); } - // filter max - Node* filter_int16 = nullptr; - Node* filter_max = nullptr; - PrepareWeight( - graph, scope, block, conv_filter, &filter_int16, &filter_max, false); - // output && output max - std::string conv2d_xpu_out_name; - if (!act_type.empty()) { - conv2d_xpu_out_name = act_out->Name(); - } else if (ew_branch_add) { - conv2d_xpu_out_name = ew_branch_add_out->Name(); - } else if (scale) { - conv2d_xpu_out_name = scale_out->Name(); - } else if (bn) { - conv2d_xpu_out_name = bn_out->Name(); - } else if (ew_bias_add) { - conv2d_xpu_out_name = ew_bias_add_out->Name(); - } else { - conv2d_xpu_out_name = conv_out->Name(); + if (fusion_nodes_map["out_max_in"]) { + conv2d_xpu_op_desc.SetInput("out_max_in", + {fusion_nodes_map["out_max_in"]->Name()}); } - std::string conv2d_xpu_out_max_name = conv2d_xpu_out_name + "_max"; - VarDesc conv2d_xpu_out_max_desc(conv2d_xpu_out_max_name); - Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv2d_xpu_out_max_desc); - // Generate conv2d_xpu op - framework::OpDesc conv2d_xpu_op_desc(block); - // set input&output var - conv2d_xpu_op_desc.SetType("conv2d_xpu"); - conv2d_xpu_op_desc.SetInput("x", {input->Name()}); - conv2d_xpu_op_desc.SetInput("filter", {filter_int16->Name()}); - conv2d_xpu_op_desc.SetInput("filter_max", {filter_max->Name()}); - conv2d_xpu_op_desc.SetOutput("out", {conv2d_xpu_out_name}); - conv2d_xpu_op_desc.SetOutput("out_max", {conv2d_xpu_out_max_name}); - // set fusion_bias input node - if (has_bias) { - conv2d_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); + conv2d_xpu_op_desc.SetOutput("out", {fusion_nodes_map["out"]->Name()}); + conv2d_xpu_op_desc.SetOutput("out_max", + {fusion_nodes_map["out_max"]->Name()}); + if (with_conv_bias || with_bn) { + PADDLE_ENFORCE_EQ( + fusion_nodes_map["bias"] != nullptr, + true, + platform::errors::InvalidArgument( + "fusion_nodes_map['bias'] node ptr can not be null")); + conv2d_xpu_op_desc.SetInput("bias", {fusion_nodes_map["bias"]->Name()}); } // set ew_branch_add input node if (ew_branch_add != nullptr) { - conv2d_xpu_op_desc.SetInput("branch", {ew_branch_add_in->Name()}); + PADDLE_ENFORCE_EQ( + fusion_nodes_map["branch"] != nullptr, + true, + platform::errors::InvalidArgument( + "fusion_nodes_map['branch'] node ptr can not be null")); + conv2d_xpu_op_desc.SetInput("branch", + {fusion_nodes_map["branch"]->Name()}); + if (fusion_nodes_map["branch_max"]) { + conv2d_xpu_op_desc.SetInput("branch_max", + {fusion_nodes_map["branch_max"]->Name()}); + } } // set attrs of conv2d_xpu float act_param_ = 0.0f; @@ -646,57 +1196,54 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, "strides", PADDLE_GET_CONST(std::vector, conv->Op()->GetAttr("strides"))); conv2d_xpu_op_desc.SetAttr("paddings", conv_paddings); - conv2d_xpu_op_desc.SetAttr("out_dtype", out_dtype); + // out_dtype is same to input precision + conv2d_xpu_op_desc.SetAttr("out_dtype", + fusion_nodes_map["x"]->Var()->GetDataType()); + // Link node auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc); - IR_NODE_LINK_TO(input, conv2d_xpu); - IR_NODE_LINK_TO(filter_int16, conv2d_xpu); - IR_NODE_LINK_TO(filter_max, conv2d_xpu); - if (ew_bias_add || bn) { - SAFE_IR_NODE_LINK_TO(fusion_bias_node, conv2d_xpu); - } - if (ew_branch_add_in) { - IR_NODE_LINK_TO(ew_branch_add_in, conv2d_xpu); - } - if (act_out) { - IR_NODE_LINK_TO(conv2d_xpu, act_out); - } else if (ew_branch_add_out) { - IR_NODE_LINK_TO(conv2d_xpu, ew_branch_add_out); - } else if (scale_out) { - IR_NODE_LINK_TO(conv2d_xpu, scale_out); - } else if (bn_out) { - IR_NODE_LINK_TO(conv2d_xpu, bn_out); - } else if (ew_bias_add_out) { - IR_NODE_LINK_TO(conv2d_xpu, ew_bias_add_out); - } else { - IR_NODE_LINK_TO(conv2d_xpu, conv_out); + IR_NODE_LINK_TO(fusion_nodes_map["x"], conv2d_xpu); + if (fusion_nodes_map["x_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["x_max"], conv2d_xpu); } - IR_NODE_LINK_TO(conv2d_xpu, conv2d_xpu_out_max); - // delete useless node - std::unordered_set delete_nodes = {conv}; - if (act != nullptr) { - delete_nodes.insert(act); + IR_NODE_LINK_TO(fusion_nodes_map["filter"], conv2d_xpu); + IR_NODE_LINK_TO(fusion_nodes_map["filter_max"], conv2d_xpu); + if (fusion_nodes_map["scale_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["scale_max"], conv2d_xpu); } - if (ew_branch_add != nullptr) { - delete_nodes.insert(ew_branch_add); + if (fusion_nodes_map["bias"]) { + SAFE_IR_NODE_LINK_TO(fusion_nodes_map["bias"], conv2d_xpu); + } + if (fusion_nodes_map["branch"]) { + IR_NODE_LINK_TO(fusion_nodes_map["branch"], conv2d_xpu); + } + if (fusion_nodes_map["branch_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["branch_max"], conv2d_xpu); + } + if (fusion_nodes_map["out_max_in"]) { + IR_NODE_LINK_TO(fusion_nodes_map["out_max_in"], conv2d_xpu); + } + IR_NODE_LINK_TO(conv2d_xpu, fusion_nodes_map["out"]); + IR_NODE_LINK_TO(conv2d_xpu, fusion_nodes_map["out_max"]); + // delete useless node + std::unordered_set delete_nodes; + if (conv != nullptr) { + delete_nodes.insert(conv); } if (scale != nullptr) { delete_nodes.insert(scale); } if (bn != nullptr) { delete_nodes.insert(bn); - delete_nodes.insert(bn_bias); - delete_nodes.insert(bn_var); - delete_nodes.insert(bn_mean); - delete_nodes.insert(bn_scale); - delete_nodes.insert(bn_var_out); - delete_nodes.insert(bn_mean_out); - delete_nodes.insert(bn_saved_var); - delete_nodes.insert(bn_saved_mean); } if (ew_bias_add != nullptr) { delete_nodes.insert(ew_bias_add); - delete_nodes.insert(ew_bias_add_y); + } + if (ew_branch_add != nullptr) { + delete_nodes.insert(ew_branch_add); + } + if (act != nullptr) { + delete_nodes.insert(act); } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index 4c8424b7df08f..373275706700f 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -244,9 +245,68 @@ class FcXPUFusePass : public FusePassBase { bool with_bn, const std::string& act_type) const; + void CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + std::string mul_type, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_bias, + bool with_bn, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + Node* GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const; + const std::string name_scope_{"fc_xpu_fuse_pass"}; }; +Node* FcXPUFusePass::GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const { + auto iter = nodes_map.find(pattern_node_name); + PADDLE_ENFORCE_EQ( + iter != nodes_map.end(), + true, + platform::errors::InvalidArgument("nodes_map[%s] not found in nodes_map", + pattern_node_name.c_str())); + auto node_map = iter->second; + auto node_iter = node_map.find(node_name); + PADDLE_ENFORCE_EQ(node_iter != node_map.end(), + true, + platform::errors::InvalidArgument( + "nodes_map[%s][%s] not found in nodes_map", + pattern_node_name.c_str(), + node_name.c_str())); + return node_iter->second; +} + void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); @@ -275,6 +335,368 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_subgraph_count); } +void FcXPUFusePass::CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + std::string mul_type, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_bias, + bool with_bn, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* mul = GetNodeFromNodesMap(nodes_map, "mul", "mul"); + PADDLE_ENFORCE_EQ( + mul != nullptr, + true, + platform::errors::InvalidArgument("mul node ptr can not be null")); + auto* mul_w = GetNodeFromNodesMap(nodes_map, "mul", "mul_w"); + PADDLE_ENFORCE_EQ( + mul_w != nullptr, + true, + platform::errors::InvalidArgument("mul_w node ptr can not be null")); + + // transfilter fp16 --> fp32 + auto* filter_t = + scope->FindVar(mul_w->Name())->GetMutable(); + auto filter_len = filter_t->numel(); + auto filter_dtype = filter_t->dtype(); + if (filter_dtype == phi::DataType::FLOAT16) { + CastToFp32(filter_t, nullptr); + } + + bool transpose_w = false; + if (mul_type == "matmul") { + transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); + } else if (mul_type == "matmul_v2") { + transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); + } + // Get Weight scale in int8 scene + std::vector weight_scale{}; + if (AreScalesPresentForNodes(var_quant_scales, {mul_w})) { + weight_scale = GetScaleVecValueForNode(var_quant_scales, mul_w); + } + // Create fusion_bias_node + Node* fusion_bias_node = nullptr; + if (with_bias) { + auto* ew_bias_add_bias = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_bias"); + PADDLE_ENFORCE_EQ(ew_bias_add_bias != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_bias node ptr can not be null")); + PrepareBias(graph, scope, block, ew_bias_add_bias, &fusion_bias_node); + } + + if (with_bn) { + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + PADDLE_ENFORCE_EQ( + bn != nullptr, + true, + platform::errors::InvalidArgument("bn node ptr can not be null")); + auto* bn_bias = GetNodeFromNodesMap(nodes_map, "bn", "bn_bias"); + PADDLE_ENFORCE_EQ( + bn_bias != nullptr, + true, + platform::errors::InvalidArgument("bn_bias node ptr can not be null")); + auto* bn_scale = GetNodeFromNodesMap(nodes_map, "bn", "bn_scale"); + PADDLE_ENFORCE_EQ( + bn_scale != nullptr, + true, + platform::errors::InvalidArgument("bn_scale node ptr can not be null")); + auto* bn_var = GetNodeFromNodesMap(nodes_map, "bn", "bn_var"); + PADDLE_ENFORCE_EQ( + bn_var != nullptr, + true, + platform::errors::InvalidArgument("bn_var node ptr can not be null")); + auto* bn_mean = GetNodeFromNodesMap(nodes_map, "bn", "bn_mean"); + PADDLE_ENFORCE_EQ( + bn_mean != nullptr, + true, + platform::errors::InvalidArgument("bn_mean node ptr can not be null")); + + auto bn_bias_t = + scope->Var(bn_bias->Name())->GetMutable(); + auto bn_scale_t = + scope->Var(bn_scale->Name())->GetMutable(); + auto bn_mean_t = + scope->Var(bn_mean->Name())->GetMutable(); + auto bn_var_t = scope->Var(bn_var->Name())->GetMutable(); + float* bn_scale_ptr = bn_scale_t->data(); + float* bn_bias_ptr = bn_bias_t->data(); + float* bn_mean_ptr = bn_mean_t->data(); + float* bn_var_ptr = bn_var_t->data(); + auto mean_len = bn_mean_t->numel(); + auto filter_stride = filter_len / mean_len; + float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); + if (!with_bias) { // prev node is conv + PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); + } + + auto fusion_bias_t = + scope->Var(fusion_bias_node->Name())->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + // recompute bias and weights + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + } + // recompute the weights + if (op_weights_precision != "int8") { + float* filter_ptr = filter_t->data(); + for (int i = 0; i < mean_len; ++i) { + for (int j = 0; j < filter_stride; j++) { + filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; + } + } + } else { + int8_t* filter_ptr = filter_t->data(); + PADDLE_ENFORCE_EQ( + weight_scale.size(), + mean_len, + platform::errors::InvalidArgument( + "Weight max_scale size must equal batch_norm sacle/mean size.")); + for (int i = 0; i < mean_len; i++) { + weight_scale[i] *= fabs(bn_scale_ptr[i]); + } + for (int i = 0; i < mean_len; i++) { + if (bn_scale_ptr[i] < 0) { + for (int j = 0; j < filter_stride; ++j) { + filter_ptr[i * filter_stride + j] *= -1; + } + } + } + } + // recompute bias + if (!with_bias) { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } else { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] = + bn_bias_ptr[i] + + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } + } + + (*fusion_nodes_map)["bias"] = fusion_bias_node; + + Node* filter_intx = nullptr; + Node* filter_max = nullptr; + Node* scale_max = nullptr; + if (op_weights_precision != "int8") { + PrepareWeight(graph, + scope, + block, + mul_w, + &filter_intx, + &filter_max, + !transpose_w, + weight_scale); + } else { + PrepareWeight(graph, + scope, + block, + mul_w, + &filter_intx, + &filter_max, + !transpose_w, + weight_scale); + } + + bool is_per_channel_need_create_scale_max_node = + !weight_scale.empty() && !IsPerTensorQuant(weight_scale); + if (is_per_channel_need_create_scale_max_node) { + phi::DenseTensor ones_weight_max_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + int max_ptr_size = weight_scale.empty() + ? phi::backends::xpu::get_xpu_max_ptr_size(-1) + : weight_scale.size(); + ones_weight_max_tensor.set_type(phi::DataType::FLOAT32); + ones_weight_max_tensor.Resize({max_ptr_size}); + std::vector ones_weight(max_ptr_size, 1.0); + memcpy(cpu_ctx->Alloc(&ones_weight_max_tensor), + ones_weight.data(), + max_ptr_size * sizeof(float)); + + std::string scale_max_name = mul_w->Name() + "_scale_max"; + VarDesc scale_max_desc(scale_max_name); + scale_max_desc.SetPersistable(true); + scale_max_desc.SetShape(vectorize(ones_weight_max_tensor.dims())); + scale_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + scale_max = graph->CreateVarNode(&scale_max_desc); + auto* block_scale_max_desc = block->Var(scale_max_name); + block_scale_max_desc->SetPersistable(scale_max_desc.Persistable()); + block_scale_max_desc->SetShape(scale_max_desc.GetShape()); + block_scale_max_desc->SetDataType(scale_max_desc.GetDataType()); + Assign(ones_weight_max_tensor, + scope->Var(scale_max_name)->GetMutable()); + } + + (*fusion_nodes_map)["w"] = filter_intx; + if (is_per_channel_need_create_scale_max_node) { + (*fusion_nodes_map)["w_max"] = scale_max; + (*fusion_nodes_map)["scale_max"] = filter_max; + } else { + (*fusion_nodes_map)["w_max"] = filter_max; + (*fusion_nodes_map)["scale_max"] = scale_max; + } +} + +void FcXPUFusePass::CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + auto* mul = GetNodeFromNodesMap(nodes_map, "mul", "mul"); + PADDLE_ENFORCE_EQ( + mul != nullptr, + true, + platform::errors::InvalidArgument("mul node ptr can not be null")); + // output && output max + std::string fc_xpu_out_name; + Node* fc_out_var_node = nullptr; + + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + auto* ew_bias_add = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add"); + auto* act = GetNodeFromNodesMap(nodes_map, "act", "act"); + if (act) { + auto* act_out = GetNodeFromNodesMap(nodes_map, "act", "act_out"); + PADDLE_ENFORCE_EQ( + act_out != nullptr, + true, + platform::errors::InvalidArgument("act_out node ptr can not be null")); + fc_xpu_out_name = act_out->Name(); + fc_out_var_node = act_out; + } else if (bn) { + auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out"); + PADDLE_ENFORCE_EQ( + bn_out != nullptr, + true, + platform::errors::InvalidArgument("bn_out node ptr can not be null")); + fc_xpu_out_name = bn_out->Name(); + fc_out_var_node = bn_out; + } else if (ew_bias_add) { + auto* ew_bias_add_out = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out"); + PADDLE_ENFORCE_EQ(ew_bias_add_out != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_out node ptr can not be null")); + fc_xpu_out_name = ew_bias_add_out->Name(); + fc_out_var_node = ew_bias_add_out; + } else { + auto* mul_out = GetNodeFromNodesMap(nodes_map, "mul", "mul_out"); + PADDLE_ENFORCE_EQ( + mul_out != nullptr, + true, + platform::errors::InvalidArgument("mul_out node ptr can not be null")); + fc_xpu_out_name = mul_out->Name(); + fc_out_var_node = mul_out; + } + (*fusion_nodes_map)["out"] = fc_out_var_node; + + // Create out max in + if (op_weights_precision == "int8" && + AreScalesPresentForNodes(var_quant_scales, {fc_out_var_node})) { + std::string fc_out_max_in_name = fc_xpu_out_name + "_max_in"; + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc fc_out_max_in_desc(fc_out_max_in_name); + fc_out_max_in_desc.SetPersistable(true); + fc_out_max_in_desc.SetShape({static_cast(max_ptr_size)}); + fc_out_max_in_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + Node* fc_xpu_out_max_in = graph->CreateVarNode(&fc_out_max_in_desc); + auto* block_out_max_in_desc = block->Var(fc_out_max_in_name); + block_out_max_in_desc->SetPersistable(fc_out_max_in_desc.Persistable()); + block_out_max_in_desc->SetShape(fc_out_max_in_desc.GetShape()); + block_out_max_in_desc->SetDataType(fc_out_max_in_desc.GetDataType()); + + float output_scale = + GetScaleValueForNode(var_quant_scales, fc_out_var_node); + phi::DenseTensor out_max_in_cpu_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + out_max_in_cpu_tensor.set_type(phi::DataType::FLOAT32); + out_max_in_cpu_tensor.Resize({max_ptr_size}); + std::vector output_scales(max_ptr_size, output_scale); + memcpy(cpu_ctx->Alloc(&out_max_in_cpu_tensor), + output_scales.data(), + max_ptr_size * sizeof(float)); + Assign(out_max_in_cpu_tensor, + scope->Var(fc_out_max_in_name)->GetMutable()); + (*fusion_nodes_map)["out_max_in"] = fc_xpu_out_max_in; + } + + // Create out max + std::string fc_out_max_name = fc_xpu_out_name + "_max"; + VarDesc fc_out_max_desc(fc_out_max_name); + Node* fc_xpu_out_max = graph->CreateVarNode(&fc_out_max_desc); + (*fusion_nodes_map)["out_max"] = fc_xpu_out_max; +} + +void FcXPUFusePass::CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* mul = GetNodeFromNodesMap(nodes_map, "mul", "mul"); + PADDLE_ENFORCE_EQ( + mul != nullptr, + true, + platform::errors::InvalidArgument("mul node ptr can not be null")); + auto* mul_x = GetNodeFromNodesMap(nodes_map, "mul", "mul_x"); + PADDLE_ENFORCE_EQ( + mul_x != nullptr, + true, + platform::errors::InvalidArgument("mul_x node ptr can not be null")); + // x max + std::string mul_x_max_name = mul_x->Name() + "_max"; + Node* mul_x_max = nullptr; + if (op_weights_precision == "int8") { + PADDLE_ENFORCE_EQ(AreScalesPresentForNodes(var_quant_scales, {mul_x}), + true, + platform::errors::InvalidArgument( + "When fc op is running in int8 precision, the scales " + "of input var should be present in!")); + float input_scale = GetScaleValueForNode(var_quant_scales, mul_x); + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc x_max_desc(mul_x_max_name); + x_max_desc.SetPersistable( + true); // Need depends on ir_params_sync_among_devices_pass copy to xpu + // device + x_max_desc.SetShape({static_cast(max_ptr_size)}); + x_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + mul_x_max = graph->CreateVarNode(&x_max_desc); + auto input_max_tensor = + scope->Var(mul_x_max_name)->GetMutable(); + input_max_tensor->set_type(phi::DataType::FLOAT32); + input_max_tensor->Resize({max_ptr_size}); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + std::vector input_scales(max_ptr_size, input_scale); + memcpy(cpu_ctx->Alloc(input_max_tensor), + input_scales.data(), + max_ptr_size * sizeof(float)); + } + (*fusion_nodes_map)["x"] = mul_x; + (*fusion_nodes_map)["x_max"] = mul_x_max; +} + int FcXPUFusePass::ApplyImpl(ir::Graph* graph, const std::string& mul_type, bool with_bias, @@ -287,7 +709,9 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, with_bias, with_bn, act_type); - + auto* scope = param_scope(); + std::unordered_map> var_quant_scales = + GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { @@ -311,108 +735,96 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(bn_saved_mean); GET_IR_NODE(act); GET_IR_NODE(act_out); - auto* block = mul->Op()->Block(); - auto* scope = param_scope(); - - auto* filter_t = - scope->FindVar(mul_w->Name())->GetMutable(); - // weight fp16 --> fp32 - auto filter_dtype = filter_t->dtype(); - int out_dtype = proto::VarType::Type::VarType_Type_FP32; - if (filter_dtype == phi::DataType::FLOAT16) { - out_dtype = proto::VarType::Type::VarType_Type_FP16; - CastToFp32(filter_t, nullptr); - } - auto filter_dims = filter_t->dims(); - - bool transpose_w = false; - if (mul_type == "matmul") { - transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); - } else if (mul_type == "matmul_v2") { - transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); - } - - bool has_bias = with_bn || with_bias; - Node* fusion_bias_node = nullptr; - if (has_bias) { - if (bias != nullptr) { - PrepareBias(graph, scope, block, bias, &fusion_bias_node); - } - if (bn != nullptr) { - auto bn_bias_t = - scope->Var(bn_bias->Name())->GetMutable(); - auto bn_scale_t = - scope->Var(bn_scale->Name())->GetMutable(); - auto bn_mean_t = - scope->Var(bn_mean->Name())->GetMutable(); - auto bn_var_t = - scope->Var(bn_var->Name())->GetMutable(); - float* mul_w_ptr = filter_t->data(); - float* bn_scale_ptr = bn_scale_t->data(); - float* bn_bias_ptr = bn_bias_t->data(); - float* bn_mean_ptr = bn_mean_t->data(); - float* bn_var_ptr = bn_var_t->data(); - auto mean_len = bn_mean_t->numel(); - auto filter_h = filter_dims[0]; - auto filter_w = filter_dims[1]; - float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); - if (fusion_bias_node == nullptr) { // prev node is conv - PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); - } - auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) - ->GetMutable(); - float* fusion_bias_ptr = fusion_bias_t->data(); - // recompute bias and weights - if (bias == nullptr) { - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_h; j++) { - mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i]; - } - } - } else { - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - bn_bias_ptr[i] += - (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_h; j++) { - mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i]; - } - } - memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float)); - } - } - } + std::map> nodes_map; + nodes_map.insert({"mul", + {{"mul", mul}, + {"mul_x", mul_x}, + {"mul_w", mul_w}, + {"mul_out", mul_out}}}); + nodes_map.insert({"ew_bias_add", + {{"ew_bias_add", add}, + {"ew_bias_add_bias", bias}, + {"ew_bias_add_out", add_out}}}); + nodes_map.insert({"bn", + {{"bn", bn}, + {"bn_bias", bn_bias}, + {"bn_mean", bn_mean}, + {"bn_scale", bn_scale}, + {"bn_var", bn_var}, + {"bn_out", bn_out}, + {"bn_var_out", bn_var_out}, + {"bn_mean_out", bn_mean_out}, + {"bn_saved_var", bn_saved_var}, + {"bn_saved_mean", bn_saved_mean}}}); + nodes_map.insert({"act", {{"act", act}, {"act_out", act_out}}}); - Node* mul_w_int16 = nullptr; - Node* mul_w_max = nullptr; - PrepareWeight( - graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w); - - std::string fc_out_name; - if (act_out) { - fc_out_name = act_out->Name(); - } else if (bn) { - fc_out_name = bn_out->Name(); - } else if (add_out) { - fc_out_name = add_out->Name(); - } else { - fc_out_name = mul_out->Name(); + std::map fusion_nodes_map{{"x", nullptr}, + {"x_max", nullptr}, + {"w", nullptr}, + {"w_max", nullptr}, + {"bias", nullptr}, + {"scale_max", nullptr}, + {"out_max_in", nullptr}, + {"out", nullptr}, + {"out_max", nullptr}}; + auto filter_data_type = + scope->FindVar(mul_w->Name())->GetMutable()->dtype(); + std::string op_weights_precision = "float32"; + if (filter_data_type == phi::DataType::INT8) { + op_weights_precision = "int8"; + } else if (filter_data_type == phi::DataType::FLOAT16) { + op_weights_precision = "float16"; } - std::string fc_out_max_name = fc_out_name + "_max"; - VarDesc fc_out_max_desc(fc_out_max_name); - Node* fc_out_max = graph->CreateVarNode(&fc_out_max_desc); + VLOG(4) << "FC fusion fuse pass is running on " << op_weights_precision + << " precision!"; + auto* block = mul->Op()->Block(); + CreateFusionWeightsAndBias(graph, + scope, + block, + mul_type, + nodes_map, + &fusion_nodes_map, + with_bias, + with_bn, + op_weights_precision, + &var_quant_scales); + CreateFusionInputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); + CreateFusionOutputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); // Generate fc_xpu op framework::OpDesc fc_xpu_op_desc(block); fc_xpu_op_desc.SetType("fc_xpu"); - fc_xpu_op_desc.SetInput("x", {mul_x->Name()}); - fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()}); - fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()}); - if (has_bias) { - fc_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); + fc_xpu_op_desc.SetInput("x", {fusion_nodes_map["x"]->Name()}); + if (fusion_nodes_map["x_max"]) { + fc_xpu_op_desc.SetInput("x_max", {fusion_nodes_map["x_max"]->Name()}); + } + fc_xpu_op_desc.SetInput("w", {fusion_nodes_map["w"]->Name()}); + fc_xpu_op_desc.SetInput("w_max", {fusion_nodes_map["w_max"]->Name()}); + if (fusion_nodes_map["bias"]) { + fc_xpu_op_desc.SetInput("bias", {fusion_nodes_map["bias"]->Name()}); + } + if (fusion_nodes_map["scale_max"]) { + fc_xpu_op_desc.SetInput("scale_max", + {fusion_nodes_map["scale_max"]->Name()}); } + if (fusion_nodes_map["out_max_in"]) { + fc_xpu_op_desc.SetInput("out_max_in", + {fusion_nodes_map["out_max_in"]->Name()}); + } + fc_xpu_op_desc.SetOutput("out", {fusion_nodes_map["out"]->Name()}); + fc_xpu_op_desc.SetOutput("out_max", {fusion_nodes_map["out_max"]->Name()}); fc_xpu_op_desc.SetAttr( "in_num_col_dims", static_cast(mul_x->Var()->GetShape().size() - 1)); @@ -440,48 +852,41 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, "act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("slope"))); } } - fc_xpu_op_desc.SetAttr("out_dtype", out_dtype); - fc_xpu_op_desc.SetOutput("out", {fc_out_name}); - fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name}); + // out_dtype is same to input precision + fc_xpu_op_desc.SetAttr("out_dtype", + fusion_nodes_map["x"]->Var()->GetDataType()); auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc); - IR_NODE_LINK_TO(mul_x, fc_xpu); - IR_NODE_LINK_TO(mul_w_int16, fc_xpu); - IR_NODE_LINK_TO(mul_w_max, fc_xpu); - if (bias || bn) { - SAFE_IR_NODE_LINK_TO(fusion_bias_node, fc_xpu); + IR_NODE_LINK_TO(fusion_nodes_map["x"], fc_xpu); + if (fusion_nodes_map["x_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["x_max"], fc_xpu); } - if (act_out) { - IR_NODE_LINK_TO(fc_xpu, act_out); - } else if (bn_out) { - IR_NODE_LINK_TO(fc_xpu, bn_out); - } else if (add_out) { - IR_NODE_LINK_TO(fc_xpu, add_out); - } else { - IR_NODE_LINK_TO(fc_xpu, mul_out); + IR_NODE_LINK_TO(fusion_nodes_map["w"], fc_xpu); + IR_NODE_LINK_TO(fusion_nodes_map["w_max"], fc_xpu); + if (fusion_nodes_map["scale_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["scale_max"], fc_xpu); } - IR_NODE_LINK_TO(fc_xpu, fc_out_max); + if (fusion_nodes_map["bias"]) { + IR_NODE_LINK_TO(fusion_nodes_map["bias"], fc_xpu); + } + if (fusion_nodes_map["out_max_in"]) { + IR_NODE_LINK_TO(fusion_nodes_map["out_max_in"], fc_xpu); + } + IR_NODE_LINK_TO(fc_xpu, fusion_nodes_map["out"]); + IR_NODE_LINK_TO(fc_xpu, fusion_nodes_map["out_max"]); // delete useless node std::unordered_set delete_nodes; - if (act != nullptr && add != nullptr) { - delete_nodes = {mul, mul_out, add, add_out, act}; - } else if (act) { - delete_nodes = {mul, mul_out, act}; - } else if (add) { - delete_nodes = {mul, mul_out, add}; - } else { - delete_nodes = {mul}; + if (mul != nullptr) { + delete_nodes.insert(mul); } if (bn != nullptr) { delete_nodes.insert(bn); - delete_nodes.insert(bn_bias); - delete_nodes.insert(bn_var); - delete_nodes.insert(bn_mean); - delete_nodes.insert(bn_scale); - delete_nodes.insert(bn_var_out); - delete_nodes.insert(bn_mean_out); - delete_nodes.insert(bn_saved_var); - delete_nodes.insert(bn_saved_mean); + } + if (add != nullptr) { + delete_nodes.insert(add); + } + if (act != nullptr) { + delete_nodes.insert(act); } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc index 725f4e6a86a49..47bf2b06be9d9 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc @@ -424,11 +424,23 @@ int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant( nullptr, platform::errors::Fatal("w node should not be nullptr")); if (quant_post_dynamic_weight_precision == 0) { - PrepareWeight( - graph, scope, block, w_node, &w_intx, &w_max, need_transpose); + PrepareWeight(graph, + scope, + block, + w_node, + &w_intx, + &w_max, + need_transpose, + std::vector({})); } else { - PrepareWeight( - graph, scope, block, w_node, &w_intx, &w_max, need_transpose); + PrepareWeight(graph, + scope, + block, + w_node, + &w_intx, + &w_max, + need_transpose, + std::vector({})); } w_nodes->push_back(w_node); w_intx_nodes->push_back(w_intx); diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc index 1a9db472bc2cc..9b552bac36f2d 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc @@ -67,6 +67,7 @@ struct LinkConv2dPattern : public PatternBase { PATTERN_DECL_NODE(fusion_op); // declare variable node's name PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(filter); PATTERN_DECL_NODE(branch); private: @@ -79,14 +80,19 @@ LinkConv2dPattern::LinkConv2dPattern(PDPattern* pattern, : PatternBase(pattern, name_scope, name_scope), with_branch_(with_branch) { auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("conv2d_xpu"); + auto* x = pattern->NewNode(x_repr())->assert_is_op_input("conv2d_xpu", "x"); + auto* filter = pattern->NewNode(filter_repr()) + ->assert_is_op_input("conv2d_xpu", "filter") + ->assert_is_persistable_var(); PDNode* branch = nullptr; if (with_branch_) { branch = pattern->NewNode(branch_repr()) ->assert_is_op_input("conv2d_xpu", "branch"); - fusion_op->LinksFrom({branch}); + fusion_op->LinksFrom({x, branch, filter}); + } else { + fusion_op->LinksFrom({x, filter}); } - fusion_op->LinksFrom({x}); } struct LinkFcPattern : public PatternBase { @@ -96,18 +102,30 @@ struct LinkFcPattern : public PatternBase { PATTERN_DECL_NODE(fusion_op); // declare variable node's name PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(w); }; LinkFcPattern::LinkFcPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, name_scope) { auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("fc_xpu"); - auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x"); - fusion_op->LinksFrom({x}); + auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x"); + auto* w = pattern->NewNode(w_repr()) + ->assert_is_op_input("fc_xpu", "w") + ->assert_is_persistable_var(); + fusion_op->LinksFrom({x, w}); } } // namespace patterns +bool LinkXPUOpMaxPass::IsQuant(Node* weight_node) const { + auto w_dtype = param_scope() + ->FindVar(weight_node->Name()) + ->GetMutable() + ->dtype(); + return w_dtype == phi::DataType::INT8; +} + void LinkXPUOpMaxPass::LinkAddActMax(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::LinkAddActPattern pattern(gpd.mutable_pattern(), name_scope_); @@ -155,15 +173,18 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { patterns::LinkConv2dPattern pattern( gpd.mutable_pattern(), name_scope_, with_branch); int found_subgraph_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle LinkConv2dMax"; - /* declare operator node's name */ + /* get operator node's name */ GET_IR_NODE(fusion_op); - /* declare variable node's name*/ + /* get variable node's name*/ GET_IR_NODE(x); + GET_IR_NODE(filter); GET_IR_NODE(branch); + if (IsQuant(filter)) { + return; + } auto* fusion_op_desc = fusion_op->Op(); bool fusion_op_has_branch = fusion_op_desc->HasInput("branch"); if (fusion_op_has_branch) { @@ -177,7 +198,12 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { auto preop_max_var_name = x_pre_op->Output("out_max"); for (auto max_node : x->inputs[0]->outputs) { if (preop_max_var_name[0] == max_node->Name()) { - fusion_op_desc->SetInput("x_max", {max_node->Name()}); + if (fusion_op_desc->HasInput("x_max")) { + auto x_max_old_name = fusion_op_desc->Input("x_max")[0]; + fusion_op_desc->RenameInput(x_max_old_name, max_node->Name()); + } else { + fusion_op_desc->SetInput("x_max", {max_node->Name()}); + } IR_NODE_LINK_TO(max_node, fusion_op); } } @@ -205,14 +231,16 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::LinkFcPattern pattern(gpd.mutable_pattern(), name_scope_); int found_subgraph_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle LinkFcMax"; - /* declare operator node's name */ + /* get operator node's name */ GET_IR_NODE(fusion_op); - /* declare variable node's name*/ + /* get variable node's name*/ GET_IR_NODE(x); + GET_IR_NODE(w); + + if (IsQuant(w)) return; auto* fusion_op_desc = fusion_op->Op(); auto* x_pre_op = x->inputs[0]->Op(); if (x->inputs.size() > 0 && x->inputs[0]->IsOp() && @@ -220,7 +248,12 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const { auto preop_max_var_name = x_pre_op->Output("out_max"); for (auto max_node : x->inputs[0]->outputs) { if (preop_max_var_name[0] == max_node->Name()) { - fusion_op_desc->SetInput("x_max", {max_node->Name()}); + if (fusion_op_desc->HasInput("x_max")) { + auto x_max_old_name = fusion_op_desc->Input("x_max")[0]; + fusion_op_desc->RenameInput(x_max_old_name, max_node->Name()); + } else { + fusion_op_desc->SetInput("x_max", {max_node->Name()}); + } IR_NODE_LINK_TO(max_node, fusion_op); } } diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h index cad199ce573bb..a71a2e19cf430 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h @@ -102,6 +102,7 @@ Fused subgraph: */ void LinkAddActMax(ir::Graph* graph) const; + bool IsQuant(Node* weight_node) const; const std::string name_scope_{"link_xpu_op_max_pass"}; }; diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index 255c1f5d47a4c..04439608aaa23 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -561,7 +561,8 @@ void MultiEncoderXPUFusePass::PrepareQKVWeight(Graph* graph, &q_w_fp32_t, &k_w_fp32_t, &v_w_fp32_t}; phi::ConcatKernel(*cpu_ctx, in_tensors, 0, &qkv_w_int16_t); - PrepareWeight(&qkv_w_int16_t, &qkv_w_max_t, false); + ConvertWithQuant( + &qkv_w_int16_t, &qkv_w_max_t, false, std::vector({})); size_t qkv_w_int16_hash = HashTensor(qkv_w_int16_t); size_t qkv_w_max_hash = HashTensor(qkv_w_max_t); std::string qkv_w_int16_name = std::to_string(qkv_w_int16_hash); @@ -813,16 +814,17 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( &qkv_w_int16, &qkv_w_max); -#define PREPARE_QKV_MATMUL_W(idx_) \ - Node* qkv_matmul_##idx_##_w_int16 = nullptr; \ - Node* qkv_matmul_##idx_##_w_max = nullptr; \ - PrepareWeight(graph, \ - scope, \ - block, \ - qkv_matmul_##idx_##_w, \ - &qkv_matmul_##idx_##_w_int16, \ - &qkv_matmul_##idx_##_w_max, \ - true); +#define PREPARE_QKV_MATMUL_W(idx_) \ + Node* qkv_matmul_##idx_##_w_int16 = nullptr; \ + Node* qkv_matmul_##idx_##_w_max = nullptr; \ + PrepareWeight(graph, \ + scope, \ + block, \ + qkv_matmul_##idx_##_w, \ + &qkv_matmul_##idx_##_w_int16, \ + &qkv_matmul_##idx_##_w_max, \ + true, \ + std::vector({})); PREPARE_QKV_MATMUL_W(1); PREPARE_QKV_MATMUL_W(2); PREPARE_QKV_MATMUL_W(3); diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index eeb0e23e19ecd..c6dc291315399 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -121,102 +121,123 @@ size_t HashTensor(const phi::DenseTensor& in) { template size_t HashTensor(const phi::DenseTensor& in); template size_t HashTensor(const phi::DenseTensor& in); +template size_t HashTensor(const phi::DenseTensor& in); std::string GetPrefixWithoutHash(const std::string& name) { std::size_t found = name.find("_#"); return found == std::string::npos ? name : name.substr(0, found); } -template +template void PrepareWeight(Graph* graph, Scope* scope, BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose) { - auto src_name = src->Name(); - auto* src_tensor = scope->Var(src_name)->GetMutable(); - phi::DenseTensor dst_tensor; - Assign(*src_tensor, &dst_tensor); - phi::DenseTensor dst_max_tensor; - PrepareWeight(&dst_tensor, &dst_max_tensor, transpose); - - size_t dst_hash = HashTensor(dst_tensor); - size_t dst_max_hash = HashTensor(dst_max_tensor); - std::string pre_name = GetPrefixWithoutHash(src_name); - std::string dst_name = pre_name + "_#" + std::to_string(dst_hash); - std::string dst_max_name = pre_name + "_max_#" + std::to_string(dst_max_hash); - *dst = FindNodeWithName(graph, dst_name); - if (*dst == nullptr) { - // Create dst node - // Update dst var_desc in block - VarDesc dst_desc(dst_name); - dst_desc.SetPersistable(true); - dst_desc.SetShape(vectorize(dst_tensor.dims())); - dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype())); - *dst = graph->CreateVarNode(&dst_desc); - auto* block_dst_desc = block->Var(dst_name); - block_dst_desc->SetPersistable(dst_desc.Persistable()); - block_dst_desc->SetShape(dst_desc.GetShape()); - block_dst_desc->SetDataType(dst_desc.GetDataType()); - // Create dst_max node - // Update dst_max var_desc in block - VarDesc dst_max_desc(dst_max_name); - dst_max_desc.SetPersistable(true); - dst_max_desc.SetShape(vectorize(dst_max_tensor.dims())); - dst_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); - *dst_max = graph->CreateVarNode(&dst_max_desc); - auto* block_dst_max_desc = block->Var(dst_max_name); - block_dst_max_desc->SetPersistable(dst_max_desc.Persistable()); - block_dst_max_desc->SetShape(dst_max_desc.GetShape()); - block_dst_max_desc->SetDataType(dst_max_desc.GetDataType()); - + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales) { + auto weight_name = weight->Name(); + auto* weight_tensor = scope->Var(weight_name)->GetMutable(); + phi::DenseTensor dst_weight_tensor; + Assign(*weight_tensor, &dst_weight_tensor); + phi::DenseTensor dst_weight_max_tensor; + ConvertWeightWrapper( + &dst_weight_tensor, &dst_weight_max_tensor, transpose, weight_scales); + size_t dst_weight_hash = HashTensor(dst_weight_tensor); + size_t dst_weight_max_hash = HashTensor(dst_weight_max_tensor); + std::string pre_name = GetPrefixWithoutHash(weight_name); + std::string dst_weight_name = + pre_name + "_#" + std::to_string(dst_weight_hash); + std::string dst_weight_max_name = + pre_name + "_max_#" + std::to_string(dst_weight_max_hash); + *dst_weight = FindNodeWithName(graph, dst_weight_name); + if (*dst_weight == nullptr) { + // Create dst_weight node + // Update dst_weight var_desc in block + VarDesc dst_weight_desc(dst_weight_name); + dst_weight_desc.SetPersistable(true); + dst_weight_desc.SetShape(vectorize(dst_weight_tensor.dims())); + dst_weight_desc.SetDataType( + framework::TransToProtoVarType(dst_weight_tensor.dtype())); + *dst_weight = graph->CreateVarNode(&dst_weight_desc); + auto* block_dst_weight_desc = block->Var(dst_weight_name); + block_dst_weight_desc->SetPersistable(dst_weight_desc.Persistable()); + block_dst_weight_desc->SetShape(dst_weight_desc.GetShape()); + block_dst_weight_desc->SetDataType(dst_weight_desc.GetDataType()); + // Create dst_weight_max node + // Update dst_weight_max var_desc in block + VarDesc dst_weight_max_desc(dst_weight_max_name); + dst_weight_max_desc.SetPersistable(true); + dst_weight_max_desc.SetShape(vectorize(dst_weight_max_tensor.dims())); + dst_weight_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + *dst_weight_max = graph->CreateVarNode(&dst_weight_max_desc); + auto* block_dst_weight_max_desc = block->Var(dst_weight_max_name); + block_dst_weight_max_desc->SetPersistable( + dst_weight_max_desc.Persistable()); + block_dst_weight_max_desc->SetShape(dst_weight_max_desc.GetShape()); + block_dst_weight_max_desc->SetDataType(dst_weight_max_desc.GetDataType()); // Find dst/dst_max variable in scope - auto* dst_var = scope->FindVar(dst_name); - if (dst_var == nullptr) { - // Create dst/dst_max variable/tensor - Assign(dst_tensor, scope->Var(dst_name)->GetMutable()); - Assign(dst_max_tensor, - scope->Var(dst_max_name)->GetMutable()); + auto* dst_weight_var = scope->FindVar(dst_weight_name); + if (dst_weight_var == nullptr) { + // Create dst_weight/dst_weight_max variable/tensor + Assign(dst_weight_tensor, + scope->Var(dst_weight_name)->GetMutable()); + Assign(dst_weight_max_tensor, + scope->Var(dst_weight_max_name)->GetMutable()); } else { // Share the same variable PADDLE_ENFORCE_NOT_NULL( - scope->FindVar(dst_max_name), - platform::errors::Fatal( - "dst_max(%s) variable should not be nullptr if dst(%s) " - "variable is exist. (src_name is %s)", - dst_max_name, - dst_name, - src_name)); + scope->FindVar(dst_weight_max_name), + platform::errors::Fatal("dst_weight_max(%s) variable should not be " + "nullptr if dst_weight(%s) " + "variable is exist. (weight_name is %s)", + dst_weight_max_name, + dst_weight_name, + weight_name)); } } else { - *dst_max = FindNodeWithName(graph, dst_max_name); + *dst_weight_max = FindNodeWithName(graph, dst_weight_max_name); PADDLE_ENFORCE_NOT_NULL( - *dst_max, - platform::errors::Fatal( - "dst_max(%s) variable should not be nullptr if dst(%s) " - "variable is exist. (src_name is %s)", - dst_max_name, - dst_name, - src_name)); + *dst_weight_max, + platform::errors::Fatal("dst_weight_max(%s) variable should not be " + "nullptr if dst_weight(%s) " + "variable is exist. (weight_name is %s)", + dst_weight_max_name, + dst_weight_name, + weight_name)); } } -template void PrepareWeight(Graph* graph, - Scope* scope, - BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose); -template void PrepareWeight(Graph* graph, - Scope* scope, - BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose); +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); + +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); + +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); void PrepareBias( Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) { diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.h b/paddle/fluid/framework/ir/xpu/pass_utils.h index d1e7b218a0b46..668519c8eb406 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.h +++ b/paddle/fluid/framework/ir/xpu/pass_utils.h @@ -57,18 +57,62 @@ std::vector FindOpNodeByInputName(Graph* graph, template size_t HashTensor(const phi::DenseTensor& in); -template +template ::value, Tcpu>::type* + ptr = nullptr> +void ConvertWeightWrapper(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + ConvertWithQuant(weight, weight_max, transpose, weight_scales); +} + +template ::value, Tcpu>::type* + ptr = nullptr> +void ConvertWeightWrapper(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + ConvertWithoutQuant(weight, weight_max, transpose, weight_scales); +} + +// 1. Quant weight from fp32 to int16/int31/int8 +// 2. Weight data is in-place update. +// 3. Generate weight max tensor +template void PrepareWeight(Graph* graph, Scope* scope, BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose); + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); void PrepareBias( Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst); +inline std::string FindOutputNameByVarName(framework::OpDesc* op, + const std::string& searched_name) { + std::string ret; + for (const auto& name : op->OutputNames()) + for (const auto& output_name : op->Output(name)) + if (output_name == searched_name) ret = name; + return ret; +} + +inline std::string FindInputNameByVarName(framework::OpDesc* op, + const std::string& searched_name) { + std::string ret; + for (const auto& name : op->InputNames()) + for (const auto& input_name : op->Input(name)) + if (input_name == searched_name) ret = name; + return ret; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc index fcda50051a362..a137a006e9f70 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.cc +++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc @@ -64,9 +64,12 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) { case phi::DataType::FLOAT32: phi::TransposeKernel(*cpu_ctx, *in, axis, out_ptr); break; + case phi::DataType::INT8: + phi::TransposeKernel(*cpu_ctx, *in, axis, out_ptr); + break; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Only support fp16 and fp32, but received dtype is %s.", + "Only support fp16/fp32/int8, but received dtype is %s.", phi::DataTypeToString(in->dtype()))); break; } @@ -258,15 +261,30 @@ void QuantFP32ToIntX(const float* src_ptr, } } -template -void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose) { +template < + typename Tcpu, + typename Txpu, + typename std::enable_if::value, Tcpu>::type* ptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + LOG(FATAL) << "Not support for Tcpu is " + << phi::CppTypeToDataType::Type(); +} + +template < + typename Tcpu, + typename Txpu, + typename std::enable_if::value, Tcpu>::type* ptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { // Convert fp16 to fp32 phi::DenseTensor weight_fp32; CastToFp32(weight, &weight_fp32); - // Transpose if (transpose) { Transpose2D(&weight_fp32); } @@ -286,17 +304,74 @@ void PrepareWeight(phi::DenseTensor* weight, max_ptr_size * sizeof(float)); // Quant - weight->set_type(phi::CppTypeToDataType::Type()); + weight->set_type(phi::CppTypeToDataType::Type()); weight->Resize(weight_fp32.dims()); - QuantFP32ToIntX(weight_data, cpu_ctx->Alloc(weight), max_val, size); + QuantFP32ToIntX( + weight_data, cpu_ctx->Alloc(weight), max_val, size); } -template void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); -template void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); +template +void ConvertWithoutQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + if (transpose) { + Transpose2D(weight); + } + if (std::is_same::value || std::is_same::value) { + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + int max_ptr_size = weight_scales.empty() + ? phi::backends::xpu::get_xpu_max_ptr_size(-1) + : weight_scales.size(); + weight_max->set_type(phi::DataType::FLOAT32); + weight_max->Resize({max_ptr_size}); + if (!weight_scales.empty()) { + memcpy(cpu_ctx->Alloc(weight_max), + weight_scales.data(), + max_ptr_size * sizeof(float)); + } else { + LOG(FATAL) << "weight scales cannot be empty!"; + } + } else { + LOG(FATAL) << "Only support int8<->int8 and int16<->int16 convert."; + } +} + +template void ConvertWithQuant( + phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template void ConvertWithQuant( + phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template void ConvertWithoutQuant( + phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +bool IsPerTensorQuant(const std::vector& weight_max) { + bool per_tensor = true; + PADDLE_ENFORCE_GT( + weight_max.size(), + 0, + platform::errors::InvalidArgument( + "Op's channel size: [%d] should great than zero", weight_max.size())); + auto first = weight_max[0]; + for (size_t i = 1; i < weight_max.size(); ++i) { + if (std::abs(first - weight_max[i]) > 1e-6) { + per_tensor = false; + break; + } + } + return per_tensor; +} } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.h b/paddle/fluid/framework/ir/xpu/quant_utils.h index b417fa03323db..1a2952c614542 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.h +++ b/paddle/fluid/framework/ir/xpu/quant_utils.h @@ -27,13 +27,31 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); -// 1. Quant weight from fp32 to int16/int31 -// 2. Weight data is in-place update. -// 3. Generate weight max tensor template -void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); +void ConvertWithoutQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template ::value, Tcpu>::type* + ptr = nullptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template ::value, + Tcpu>::type* ptr = nullptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +bool IsPerTensorQuant(const std::vector& weight_max); } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc index 8383501c30b8f..8fa4a377175a7 100644 --- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc @@ -286,9 +286,6 @@ void MapMatmulV2ToMatmulXPUPass::MapMatmulV2ToMatmul(ir::Graph* graph) const { desc.SetAttr("transpose_X", matmul_v2->Op()->GetAttr("trans_x")); desc.SetAttr("transpose_Y", matmul_v2->Op()->GetAttr("trans_y")); desc.SetAttr("alpha", 1.0f); - if (matmul_v2->Op()->HasAttr("use_mkldnn")) { - desc.SetAttr("use_mkldnn", matmul_v2->Op()->GetAttr("use_mkldnn")); - } auto matmul_node = graph->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_x, matmul_node); IR_NODE_LINK_TO(matmul_y, matmul_node); diff --git a/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.cc b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.cc new file mode 100644 index 0000000000000..f1d2752321aad --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +PDNode *patterns::DequantXPUAny::operator()() { + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize_xpu"); + + auto *dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize_xpu", "y"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + dequant_op->LinksTo({dequant_out}); + next_op->LinksFrom({dequant_out}); + + return dequant_out; +} + +PDNode *patterns::QuantXPUAny::operator()() { + auto *quant_in = pattern->NewNode(quant_in_repr()) + ->AsInput() + ->assert_is_op_input("quantize_xpu", "x"); + auto *quant_op = + pattern->NewNode(quant_op_repr())->assert_is_op("quantize_xpu"); + + auto *quant_out = pattern->NewNode(quant_out_repr()) + ->AsOutput() + ->assert_is_op_output("quantize_xpu", "y"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + quant_op->LinksFrom({quant_in}).LinksTo({quant_out}); + next_op->LinksFrom({quant_out}); + + return quant_out; +} + +PDNode *patterns::DequantQuantXPUAny::operator()() { + auto *dequant_in = pattern->NewNode(dequant_in_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_xpu", "x"); + + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize_xpu"); + + auto *dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize_xpu", "y"); + + auto *quant_op = pattern->NewNode(quant_op_repr()) + ->assert_is_op("quantize_xpu") + ->AsIntermediate(); + + auto *quant_out = pattern->NewNode(quant_out_repr()) + ->AsOutput() + ->assert_is_op_output("quantize_xpu"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); + quant_op->LinksFrom({dequant_out}).LinksTo({quant_out}); + next_op->LinksFrom({quant_out}); + + return quant_out; +} + +PDNode *patterns::OpDequantXPU::operator()() { + auto any_op = pattern->NewNode(any_op_repr())->assert_is_op(); + auto *dequant_in = pattern->NewNode(dequant_in_repr()) + ->assert_is_op_input("dequantize_xpu", "x"); + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize_xpu"); + auto dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize_xpu", "y"); + + any_op->LinksTo({dequant_in}); + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); + return dequant_out; +} + +PDNode *patterns::MultipleQuantizeXPU::operator()() { + auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput(); + + // find nodes that are inputs to quantize operators + prev_out->assert_more([&](Node *node) { + int counter = static_cast(std::count_if( + node->outputs.begin(), node->outputs.end(), [&](Node const *iter) { + return iter && iter->IsOp() && iter->Op()->Type() == "quantize_xpu"; + })); + return (counter > 1); + }); + + return prev_out; +} + +} // namespace patterns +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h new file mode 100644 index 0000000000000..c849b2a24bb48 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h @@ -0,0 +1,96 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// Dequantize + anyOP +// This quantize is used for getting number of ops the Dequantize's +// output is an input to. +struct DequantXPUAny : public PatternBase { + DequantXPUAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_xpu_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Quantize + anyOP +struct QuantXPUAny : public PatternBase { + QuantXPUAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "quant_xpu_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(quant_in); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(quant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Dequantize + Quantize + anyOP +// This pattern is used for squashing the dequantize-quantize pairs. +struct DequantQuantXPUAny : public PatternBase { + DequantQuantXPUAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_quant_xpu_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(dequant_in); + PATTERN_DECL_NODE(dequant_max_in); + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); + PATTERN_DECL_NODE(quant_max_in); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(quant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Op + Dequant +// named nodes: +// any_op, dequant_in +// dequant_op, dequant_out +struct OpDequantXPU : public PatternBase { + OpDequantXPU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "op_dequant_xpu") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(any_op); + PATTERN_DECL_NODE(dequant_in); + PATTERN_DECL_NODE(dequant_max_in); + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); +}; + +// anyOp + more then one quantize op +// This pattern is used for squashing multiple quantize with the same scale. +struct MultipleQuantizeXPU : public PatternBase { + MultipleQuantizeXPU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multiple_quantize_xpu") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(prev_out); +}; + +} // namespace patterns +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc new file mode 100644 index 0000000000000..761f17a92e299 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc @@ -0,0 +1,280 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/quantize_helper.h" +#include "paddle/utils/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +static void UnlinkNodes(ir::Node* a, ir::Node* b) { + a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b), + a->outputs.end()); + b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a), + b->inputs.end()); +} + +static void MarkAndLogCannotQuantizeOp(Node* op, + const char* details = nullptr) { + std::stringstream msg_ss; + msg_ss << "Cannot quantize operator " << op->Name() + << " (type: " << op->Op()->Type() << ", id: " << op->id() << ")."; + if (details) msg_ss << " " << details; + VLOG(2) << msg_ss.str().c_str(); +} +void XPUQuantizeOpPass::GetQuantInfo(Graph* graph) const { + var_quant_scales_ = + GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); +} + +void XPUQuantizeOpPass::QuantizeInput(Graph* g, + Node* op, + Node* input, + std::string input_arg_name) const { + auto inputs = op->Op()->InputNames(); + bool name_found = + std::find(inputs.begin(), inputs.end(), input_arg_name) != inputs.end(); + PADDLE_ENFORCE_EQ(name_found, + true, + platform::errors::InvalidArgument( + "Var(%s) isn't the input of the %s operator.", + input_arg_name, + op->Op()->Type())); + + // Create quantize output variable + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); + auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc); + quantize_out_node->Var()->SetDataType( + proto::VarType::Type::VarType_Type_INT8); + + // Create a quantize op node + float scale = GetScaleValueForNode(&var_quant_scales_, input); + OpDesc q_desc; + q_desc.SetType("quantize_xpu"); + q_desc.SetInput("x", std::vector({input->Name()})); + q_desc.SetOutput("y", std::vector({quantize_out_node->Name()})); + q_desc.SetAttr("out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + q_desc.SetAttr("scale", static_cast(scale)); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. + + // Update op's input + op->Op()->SetInput(input_arg_name, + std::vector({quantize_out_node->Name()})); + + // Link quantize op + UnlinkNodes(input, op); + IR_NODE_LINK_TO(input, quantize_op); + IR_NODE_LINK_TO(quantize_op, quantize_out_node); + IR_NODE_LINK_TO(quantize_out_node, op); +} + +void XPUQuantizeOpPass::DequantizeOutput(Graph* g, + Node* op, + Node* output, + std::string output_arg_name) const { + auto outputs = op->Op()->OutputNames(); + bool name_found = + std::find(outputs.begin(), outputs.end(), output_arg_name) != + outputs.end(); + PADDLE_ENFORCE_EQ(name_found, + true, + platform::errors::InvalidArgument( + "Var(%s) isn't the output of the %s operator.", + output_arg_name, + op->Op()->Type())); + + // Create dequantize input variable + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); + dequantize_in_node->Var()->SetDataType( + proto::VarType::Type::VarType_Type_INT8); + + float scale = GetScaleValueForNode(&var_quant_scales_, output); + // Create a quantize op node + OpDesc deq_desc; + deq_desc.SetType("dequantize_xpu"); + deq_desc.SetInput("x", + std::vector({dequantize_in_node->Name()})); + deq_desc.SetOutput("y", std::vector({output->Name()})); + deq_desc.SetAttr("out_dtype", static_cast(output->Var()->GetDataType())); + deq_desc.SetAttr("scale", static_cast(scale)); + auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. + + // Update op's input + op->Op()->SetOutput(output_arg_name, + std::vector({dequantize_in_node->Name()})); + + // Link dequantize op + UnlinkNodes(op, output); + IR_NODE_LINK_TO(op, dequantize_in_node); + IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); + IR_NODE_LINK_TO(dequantize_op, output); +} + +void XPUQuantizeOpPass::QuantizeConv(ir::Graph* graph) const { + for (auto* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->Type() != "conv2d_xpu") { + continue; + } + Node* w_var_node = nullptr; + Node* x_var_node = nullptr; + Node* out_var_node = nullptr; + Node* branch_var_node = nullptr; + + for (auto* input_node : n->inputs) { + if (!input_node->IsVar()) { + continue; + } + if (input_node->Var()->Name() == op->Input("x")[0]) { + x_var_node = input_node; + } else if (input_node->Var()->Name() == op->Input("filter")[0]) { + w_var_node = input_node; + } else if (op->HasInput("branch") && + input_node->Var()->Name() == op->Input("branch")[0]) { + branch_var_node = input_node; + } + } + + for (auto* output_node : n->outputs) { + if (!output_node->IsVar()) { + continue; + } + if (output_node->Var()->Name() == op->Output("out")[0]) { + out_var_node = output_node; + } + } + if (!AreScalesPresentForNodes(&var_quant_scales_, + {x_var_node, w_var_node})) { + MarkAndLogCannotQuantizeOp(n, "No scale available for the operator"); + return; + } + + QuantizeInput(graph, n, x_var_node, "x"); + auto has_output_scale = + AreScalesPresentForNodes(&var_quant_scales_, {out_var_node}); + bool has_branch = branch_var_node != nullptr; + + // Note: Conv2d fusion requires branch datatype is same as output + // datatype, so we should consider branch/output together. + if (has_branch) { + bool has_branch_scale = + AreScalesPresentForNodes(&var_quant_scales_, {branch_var_node}); + if (has_output_scale && has_branch_scale) { + QuantizeInput(graph, n, branch_var_node, "branch"); + DequantizeOutput(graph, n, out_var_node, "out"); + // Note: out_dtype attr must be set, because if dequantize_output, we + // consider the kernel out_dtype as int8. + n->Op()->SetAttr( + "out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + } else { + n->Op()->SetAttr("out_dtype", x_var_node->Var()->GetDataType()); + } + } else { + if (has_output_scale) { + DequantizeOutput(graph, n, out_var_node, "out"); + // Note: out_dtype attr must be set, because if dequantize_output, we + // consider the kernel out_dtype as int8. + n->Op()->SetAttr( + "out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + } else { + n->Op()->SetAttr("out_dtype", x_var_node->Var()->GetDataType()); + } + } + } + } +} + +void XPUQuantizeOpPass::QuantizeFC(ir::Graph* graph) const { + for (auto* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->Type() != "fc_xpu") { + continue; + } + Node* w_var_node = nullptr; + Node* x_var_node = nullptr; + Node* out_var_node = nullptr; + + for (auto* input_node : n->inputs) { + if (!input_node->IsVar()) { + continue; + } + if (input_node->Var()->Name() == op->Input("x")[0]) { + x_var_node = input_node; + } else if (input_node->Var()->Name() == op->Input("w")[0]) { + w_var_node = input_node; + } + } + + for (auto* output_node : n->outputs) { + if (!output_node->IsVar()) { + continue; + } + if (output_node->Var()->Name() == op->Output("out")[0]) { + out_var_node = output_node; + } + } + if (!AreScalesPresentForNodes(&var_quant_scales_, + {x_var_node, w_var_node})) { + MarkAndLogCannotQuantizeOp(n, "No scale available for the operator"); + return; + } + + QuantizeInput(graph, n, x_var_node, "x"); + + auto has_output_scale = + AreScalesPresentForNodes(&var_quant_scales_, {out_var_node}); + if (has_output_scale) { + DequantizeOutput(graph, n, out_var_node, "out"); + n->Op()->SetAttr( + "out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + } else { + n->Op()->SetAttr("out_dtype", x_var_node->Var()->GetDataType()); + } + } + } +} + +void XPUQuantizeOpPass::ApplyImpl(ir::Graph* graph) const { + VLOG(3) << "Insert quantize/dequantize op to the graph."; + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(name_scope_, graph); + PADDLE_ENFORCE_NOT_NULL( + param_scope(), + platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + GetQuantInfo(graph); + QuantizeConv(graph); + QuantizeFC(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(xpu_quantize_op_pass, paddle::framework::ir::XPUQuantizeOpPass); diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h new file mode 100644 index 0000000000000..28d0f42e76bde --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h @@ -0,0 +1,62 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Quantize all supported operators. + */ +class XPUQuantizeOpPass : public FusePassBase { + public: + virtual ~XPUQuantizeOpPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; + void QuantizeConv(Graph* graph) const; + void QuantizeFC(Graph* graph) const; + + private: + void QuantizeInput(Graph* g, + Node* op, + Node* input, + std::string input_arg_name) const; + + void DequantizeOutput(Graph* g, + Node* op, + Node* output, + std::string output_arg_name) const; + + void GetQuantInfo(Graph* graph) const; + + mutable std::unordered_map> var_quant_scales_; + const std::string name_scope_{"xpu_quantize_op_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc new file mode 100644 index 0000000000000..6161293bf7fb7 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc @@ -0,0 +1,281 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file eint8_outcept in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either eint8_outpress or +// implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h" + +#include +#include + +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/utils/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +XPUQuantizeSquashPass::XPUQuantizeSquashPass() {} + +void XPUQuantizeSquashPass::FindNodesToKeep( + Graph* graph, + std::unordered_map* nodes_keep_counter) const { + GraphPatternDetector gpd; + patterns::DequantXPUAny deq_any_pattern{gpd.mutable_pattern(), + "dequant_xpu_any"}; + deq_any_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern); + + if (nodes_keep_counter->find(dequant_out) == nodes_keep_counter->end()) + (*nodes_keep_counter)[dequant_out] = 1; + else + (*nodes_keep_counter)[dequant_out] += 1; + + found_count++; + }; + gpd(graph, handler); + AddStatis(found_count); +} + +void XPUQuantizeSquashPass::DequantQuantSquash( + Graph* graph, + std::unordered_map* nodes_keep_counter) const { + GraphPatternDetector gpd; + patterns::DequantQuantXPUAny squash_pattern{gpd.mutable_pattern(), + "dequant_quant_xpu_any"}; + squash_pattern(); + + int found_dequant_quant_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern); + + auto* next_op_desc = next_op->Op(); + float dequant_scale = + PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("scale")); + float quant_scale = + PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("scale")); + + PADDLE_ENFORCE_NE( + nodes_keep_counter->find(dequant_out), + nodes_keep_counter->end(), + platform::errors::NotFound("The dequant output node is not found.")); + + // check if dequantize op should be kept or removed, decrease the counter + bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; + + if (dequant_scale == quant_scale) { + // squash dequantize-quantize to nothing + auto quant_out_var_name = quant_out->Name(); + for (auto input_name : next_op_desc->InputNames()) { + auto& input_names = next_op_desc->MutableInputs()->at(input_name); + std::replace(input_names.begin(), + input_names.end(), + quant_out_var_name, + dequant_in->Name()); + next_op_desc->SetInput(input_name, input_names); + } + if (keep_dequant) + GraphSafeRemoveNodes(graph, {quant_op, quant_out}); + else + GraphSafeRemoveNodes(graph, + {dequant_op, quant_op, dequant_out, quant_out}); + + IR_NODE_LINK_TO(dequant_in, next_op); + + found_dequant_quant_count++; + } + }; + gpd(graph, handler); + AddStatis(found_dequant_quant_count); + PrettyLogDetail("--- squashed %d dequantize-quantize pairs", + found_dequant_quant_count); +} + +void XPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::OpDequantXPU op_dequant_pattern{gpd.mutable_pattern(), + "op_dequant_xpu"}; + op_dequant_pattern(); + + int found_op_dequant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash op-dequant ops pair"; + GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern); + + if (dequant_in->outputs.size() == 1) { + // Find the name of the output linking any_op to dequant_in + std::string output_name = + FindOutputNameByVarName(any_op->Op(), dequant_in->Name()); + + if (output_name.empty()) return; + any_op->Op()->SetAttr("out_dtype", dequant_out->Var()->GetDataType()); + any_op->Op()->SetOutput(output_name, + std::vector({dequant_out->Name()})); + IR_NODE_LINK_TO(any_op, dequant_out); + GraphSafeRemoveNodes(graph, {dequant_in, dequant_op}); + found_op_dequant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_op_dequant_squash_count); + PrettyLogDetail("--- squashed %d dequant with ops", + found_op_dequant_squash_count); +} + +// conv2d_xpu, fc_xpu +void XPUQuantizeSquashPass::QuantOpSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::QuantXPUAny quant_any_pattern{gpd.mutable_pattern(), + "quant_xpu_any"}; + quant_any_pattern(); + + int found_quant_op_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash op-dequant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, quant_any_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_any_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_any_pattern); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, quant_any_pattern); + + if (quant_out->outputs.size() == 1) { + std::string input_name = + FindInputNameByVarName(next_op->Op(), quant_out->Name()); + + if (input_name.empty()) return; + // Only support quant + conv2d_xpu/fc_xpu fusion + if (!(next_op->Op()->Type() == "conv2d_xpu" || + next_op->Op()->Type() == "fc_xpu")) { + return; + } + next_op->Op()->SetInput(input_name, + std::vector({quant_in->Name()})); + IR_NODE_LINK_TO(quant_in, next_op); + GraphSafeRemoveNodes(graph, {quant_out, quant_op}); + found_quant_op_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_quant_op_squash_count); + PrettyLogDetail("--- squashed %d quantize with ops", + found_quant_op_squash_count); +} + +void XPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::MultipleQuantizeXPU multiple_quantize_pattern{ + gpd.mutable_pattern(), "multiple_quantize_xpu"}; + multiple_quantize_pattern(); + + int found_multiple_quantize_squash_count = 0; + int removed_quantize = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "fuse multiple quantize ops"; + + GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, multiple_quantize_pattern); + + auto* first_quant_op = *(std::find_if( + prev_out->outputs.begin(), prev_out->outputs.end(), [&](Node* node) { + return (node->IsOp() && node->Op()->Type() == "quantize_xpu"); + })); + auto* first_quant_out = first_quant_op->outputs[0]; + float scale = first_quant_op->Op()->GetAttrIfExists("scale"); + + PADDLE_ENFORCE_NE(scale, + 0, + platform::errors::InvalidArgument( + "Quantize scale(%f) should not be equal 0.", scale)); + + for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) { + auto quant_op = prev_out->outputs[iter]; + if (quant_op->IsOp() && quant_op->Op()->Type() == "quantize_xpu" && + quant_op->id() != first_quant_op->id() && + quant_op->Op()->GetAttrIfExists("scale") == scale) { + auto quant_out = quant_op->outputs[0]; + auto last_op = quant_out->outputs[0]; + auto last_op_op = last_op->Op(); + + std::string last_op_input_name = + FindInputNameByVarName(last_op_op, quant_out->Name()); + + PADDLE_ENFORCE_NE( + last_op_input_name.empty(), + true, + platform::errors::NotFound("Operator after quantize operator(%s) " + "should have quantize output as input.", + quant_out->Name())); + + // update the next operator input, + // by replacing quant_out with first_quant_out + auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name); + std::replace(last_op_names.begin(), + last_op_names.end(), + quant_out->Name(), + first_quant_out->Name()); + last_op_op->SetInput(last_op_input_name, + std::vector(last_op_names)); + + IR_NODE_LINK_TO(first_quant_out, last_op); + GraphSafeRemoveNodes(graph, {quant_op, quant_out}); + removed_quantize++; + } + } + found_multiple_quantize_squash_count++; + }; + gpd(graph, handler); + AddStatis(found_multiple_quantize_squash_count); + PrettyLogDetail("--- squashed %d quantize op", removed_quantize); +} + +void XPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, + platform::errors::InvalidArgument( + "The graph in function XPUQuantizeSquashPass::ApplyImpl is null.")); + FusePassBase::Init("xpu_quantize_squash_pass", graph); + + std::unordered_map nodes_keep_counter; + FindNodesToKeep(graph, &nodes_keep_counter); + DequantQuantSquash(graph, &nodes_keep_counter); + OpDequantSquash(graph); + // QuantOpSquash(graph); // If the quant op is fused into conv2d_xpu, the + // performance will become worse. + MultipleQuantizeSquash(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(xpu_quantize_squash_pass, + paddle::framework::ir::XPUQuantizeSquashPass); diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h new file mode 100644 index 0000000000000..d3f37dd42010d --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Squash dequantize->quantize pair pattern into requantize op + */ + +class XPUQuantizeSquashPass : public FusePassBase { + public: + XPUQuantizeSquashPass(); + virtual ~XPUQuantizeSquashPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + /* + * For each dequantize's output find the number of operators it is an input to + */ + void FindNodesToKeep( + Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + /* + * Squash dequantize-quantize ops pairs into nothing + */ + void DequantQuantSquash( + Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + /* + * Squash dequant if the previous operator support fp32 out + */ + void OpDequantSquash(Graph* graph) const; + + /* + * Squash quantize if several quatize ops have the same scale + */ + void MultipleQuantizeSquash(Graph* graph) const; + + /* + * Squash quantize if is before conv2d_xpu/fc_xpuy + */ + void QuantOpSquash(Graph* graph) const; + + const std::string name_scope_{"squash"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 9f8e9ed80ca46..96ead2e8b032e 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/naive_executor.h" +#include #include #include #include @@ -51,6 +52,38 @@ void NaiveExecutor::Prepare(Scope *scope, CreateOps(program_desc, block_id, with_feed_fetch_ops); } +void NaiveExecutor::PrepareInterpreterCore( + Scope *scope, + const ProgramDesc &program_desc, + const framework::interpreter::ExecutionConfig &execution_config) { + interpreter_core_ = std::make_unique( + place_, program_desc.Block(0), scope, execution_config); +} + +void NaiveExecutor::PrepareInterpreterCore( + Scope *scope, + const ::pir::Program &pir_program, + const framework::interpreter::ExecutionConfig &execution_config) { + interpreter_core_ = + std::make_unique(place_, + std::vector{}, + pir_program.block(), + scope, + execution_config); +} + +void NaiveExecutor::RunInterpreterCore( + const std::vector &feed_names, bool need_fetch) { + platform::ScopedFlushDenormal flush; +#ifdef PADDLE_WITH_NVTX + platform::CudaNvtxRangePush("model", platform::NvtxRangeColor::Yellow); +#endif + interpreter_core_->Run(feed_names, need_fetch); +#ifdef PADDLE_WITH_NVTX + platform::CudaNvtxRangePop(); +#endif +} + void NaiveExecutor::Run() { #ifdef PADDLE_WITH_DNNL platform::AttachPointerHashToMKLDNNKey(this, place_); @@ -190,6 +223,9 @@ phi::DenseTensor *NaiveExecutor::FindTensor(const std::string &name) { void NaiveExecutor::RegisterOutputHook(const HookFunc &hookfunc) { output_hookfuncs_.push_back(hookfunc); + if (interpreter_core_) { + interpreter_core_->SetOutputHooks(output_hookfuncs_); + } } void NaiveExecutor::RegisterInputHook(const HookFunc &hookfunc) { diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 85f98046285b3..7d937ea0f4b05 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -26,6 +26,11 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" +#include "paddle/fluid/framework/new_executor/interpretercore.h" + +#include "paddle/pir/core/program.h" + namespace paddle { namespace framework { @@ -52,6 +57,18 @@ class NaiveExecutor { int block_id, bool with_feed_fetch_ops); + void PrepareInterpreterCore( + Scope* scope, + const ProgramDesc& program_desc, + const framework::interpreter::ExecutionConfig& execution_config = + framework::interpreter::ExecutionConfig{}); + + void PrepareInterpreterCore( + Scope* scope, + const ::pir::Program& pir_program, + const framework::interpreter::ExecutionConfig& execution_config = + framework::interpreter::ExecutionConfig{}); + // Create variables before head. // Create parameters if persistable is true, or create the temporary variables // instead. @@ -63,6 +80,9 @@ class NaiveExecutor { // Run all the operators. void Run(); + void RunInterpreterCore(const std::vector& feed_names = {}, + bool need_fetch = false); + // Get an tensor to operating directly, without the need for feed_ops. phi::DenseTensor* FindTensor(const std::string& name); @@ -96,6 +116,8 @@ class NaiveExecutor { std::unordered_map> reuse_cache_; std::vector cluster_buffer_; + + std::unique_ptr interpreter_core_; }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 2716846b0e4de..e4b8ee2560c29 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -6,18 +6,17 @@ add_subdirectory(pir_adaptor) set(STANDALONE_EXECUTOR_SRCS feed_fetch_utils.cc interpretercore.cc new_executor_defs.cc - standalone_executor.cc program_interpreter.cc new_ir_interpreter.cc) + standalone_executor.cc program_interpreter.cc pir_interpreter.cc) set(STANDALONE_EXECUTOR_DEPS interpreter interpretercore_garbage_collector workqueue - pd_op_dialect - pd_op_to_kernel_pass + op_dialect_vjp + transform pir_adaptor program_translator instruction_base - pd_inplace_pass pir plan) diff --git a/paddle/fluid/framework/new_executor/feed_fetch_utils.cc b/paddle/fluid/framework/new_executor/feed_fetch_utils.cc index f9f922fc93264..dee86a8463d0f 100644 --- a/paddle/fluid/framework/new_executor/feed_fetch_utils.cc +++ b/paddle/fluid/framework/new_executor/feed_fetch_utils.cc @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h" - #include #include -#include "paddle/fluid/framework/new_executor/new_executor_defs.h" -#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" namespace paddle { namespace framework { @@ -26,6 +25,8 @@ namespace framework { void SetColAttrForFeedFetchOps(std::shared_ptr program_desc, const int64_t micro_batch_num, const int64_t micro_batch_id) { + if (micro_batch_num < 2) return; + const std::set& valid_feed_fetch_op_types = { "fetch", "fetch_v2", "feed"}; for (const auto& op_desc : program_desc->MutableBlock(0)->AllOps()) { @@ -48,5 +49,203 @@ void SetColAttrForFeedFetchOps(std::shared_ptr program_desc, } } +void SplitFeedTensors(const std::vector& feed_names, + const int64_t micro_batch_num, + Scope* scope, + std::vector>* out) { + std::vector feed_tensors; + for (size_t i = 0; i < feed_names.size(); ++i) { + auto feed_name = feed_names[i]; + auto feed_var = scope->GetVar(feed_name); + PADDLE_ENFORCE_NOT_NULL( + feed_var, + platform::errors::NotFound("Variable %s should not be nullptr.", + feed_names[i])); + feed_tensors.push_back(feed_var->Get()); + } + + out->resize(micro_batch_num); + if (micro_batch_num < 2) { + (*out)[0] = std::move(feed_tensors); + return; + } + + for (size_t i = 0; i < feed_tensors.size(); ++i) { + auto& feed_tensor = feed_tensors[i]; + int64_t numel_size = feed_tensor.dims()[0]; + PADDLE_ENFORCE_EQ(numel_size % micro_batch_num, + 0, + phi::errors::InvalidArgument( + "Split expects feed data (%s)'s dim[0] (%d) is " + "diviable by micro_batch_num (%d).", + feed_names[i], + numel_size, + micro_batch_num)); + int64_t split_size = numel_size / micro_batch_num; + VLOG(4) << "Split feed data:" << feed_names[i] << ", dims:(" + << feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num; + for (int64_t j = 0; j < micro_batch_num; ++j) { + (*out)[j].resize(i + 1); + (*out)[j][i].ShareDataWith( + feed_tensor.Slice(j * split_size, j * split_size + split_size)); + } + } +} + +void FetchTensors(const std::vector& job_fetch_names, + const std::vector& fetch_var_names, + const int64_t micro_batch_id, + Scope* scope, + FetchUnmergedList* fetch_list) { + PADDLE_ENFORCE_GT( + fetch_list->size(), + micro_batch_id, + phi::errors::Unavailable("The fetch list size (%lld) should be greater " + "than micro_batch_id (%lld)", + fetch_list->size(), + micro_batch_id)); + + fetch_list->at(micro_batch_id).resize(fetch_var_names.size()); + for (auto& var_name : job_fetch_names) { + int col = find(fetch_var_names.begin(), fetch_var_names.end(), var_name) - + fetch_var_names.begin(); + auto* var = scope->FindVar(var_name); + auto& src = var->Get(); + auto* dst = + &(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col])); + TensorCopy(src, platform::CPUPlace(), dst); + } +} + +void MergeFetchTensors(const FetchUnmergedList& fetch_list, + const int64_t micro_batch_num, + FetchList* out) { + if (fetch_list.size() == 0) return; + + PADDLE_ENFORCE_EQ( + fetch_list.size(), + micro_batch_num, + phi::errors::Unavailable("The fetch_list size (%lld) shoule be equal to " + "the micro_batch_num (%lld)", + fetch_list.size(), + micro_batch_num)); + + if (micro_batch_num < 2) { + *out = std::move(fetch_list[0]); + return; + } + + out->resize(fetch_list[0].size()); + for (size_t i = 0; i < fetch_list[0].size(); ++i) { + std::vector tensors_ptr; + for (auto micro_batch_id = 0; micro_batch_id < micro_batch_num; + ++micro_batch_id) { + tensors_ptr.push_back( + &PADDLE_GET_CONST(phi::DenseTensor, fetch_list[micro_batch_id][i])); + } + phi::DenseTensor merged_tensor; + MergeTensors(tensors_ptr, platform::CPUPlace(), &merged_tensor); + out->at(i) = std::move(merged_tensor); + } +} + +void MergeTensors(const std::vector& tensors, + const platform::Place dst_place, + phi::DenseTensor* target) { + PADDLE_ENFORCE_EQ( + tensors.empty(), + false, + phi::errors::InvalidArgument("The tensors to be merged are empty.")); + + DDim new_dim = tensors[0]->dims(); + proto::VarType::Type new_type = proto::VarType::FP32; + phi::DataLayout new_layout = tensors[0]->layout(); + for (auto* t : tensors) { + if (t->numel() && t->IsInitialized()) { + new_dim = t->dims(); + new_type = framework::TransToProtoVarType(t->dtype()); + new_layout = t->layout(); + break; + } + } + + auto rank = tensors[0]->dims().size(); + if (rank == 0) { + std::vector init_shape = {1}; + new_dim = new_dim.reshape(init_shape); + } + + for (size_t i = 1; i < tensors.size(); ++i) { + auto* t = tensors[i]; + if (t->numel() && t->IsInitialized()) { + PADDLE_ENFORCE_EQ( + new_type, + framework::TransToProtoVarType(t->dtype()), + phi::errors::InvalidArgument( + "phi::DenseTensor data type does not match, expected type is %s, " + "actual " + "type is %s.", + DataTypeToString(new_type), + DataTypeToString(framework::TransToProtoVarType(t->dtype())))); + PADDLE_ENFORCE_EQ( + new_layout, + t->layout(), + phi::errors::InvalidArgument( + "phi::DenseTensor layout does not match, expected layout is %s, " + "actual layout is %s.", + phi::DataLayoutToString(new_layout), + phi::DataLayoutToString(t->layout()))); + if (rank > 0) { + auto tensor_dims = t->dims(); + PADDLE_ENFORCE_EQ(tensor_dims.size(), + new_dim.size(), + phi::errors::InvalidArgument( + "dimensions of DenseTensor does not match")); + for (int j = 1; j < t->dims().size(); j++) { + PADDLE_ENFORCE_EQ( + tensor_dims[j], + new_dim[j], + phi::errors::InvalidArgument( + "DenseTensor.ddim[%d] should eaqual to %d, but is %d", + j, + new_dim[j], + tensor_dims[j])); + } + new_dim[0] += t->dims()[0]; + } else if (rank == 0) { + auto tensor_dims = t->dims(); + PADDLE_ENFORCE_EQ(tensor_dims.size(), + 0, + phi::errors::InvalidArgument( + "dimensions of DenseTensor does not match")); + PADDLE_ENFORCE_EQ(new_dim.size(), + 1, + phi::errors::InvalidArgument( + "dimensions of DenseTensor does not match")); + new_dim[0] += 1; + } + } + } + + target->Resize(new_dim); + target->set_layout(new_layout); + target->mutable_data(dst_place, TransToPhiDataType(new_type)); + + int begin = 0; + for (auto* src : tensors) { + int src_dim = 1; + if (src->dims()[0] > 0) { + src_dim = src->dims()[0]; + } + int end = static_cast(begin + src_dim); + if (end == begin) { + continue; + } + auto dst = target->Slice(begin, end); + TensorCopy(*src, dst_place, &dst); + begin = end; + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/feed_fetch_utils.h b/paddle/fluid/framework/new_executor/feed_fetch_utils.h index 65f1f6cab74bd..d1eff750d4a8f 100644 --- a/paddle/fluid/framework/new_executor/feed_fetch_utils.h +++ b/paddle/fluid/framework/new_executor/feed_fetch_utils.h @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/new_executor/interpreter/plan.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" namespace paddle { namespace framework { @@ -26,5 +27,24 @@ void SetColAttrForFeedFetchOps(std::shared_ptr program_desc, const int64_t micro_batch_num, const int64_t micro_batch_id); +void SplitFeedTensors(const std::vector& feed_names, + const int64_t micro_batch_num, + Scope* scope, + std::vector>* out); + +void FetchTensors(const std::vector& job_fetch_names, + const std::vector& fetch_var_names, + const int64_t micro_batch_id, + Scope* scope, + FetchUnmergedList* fetch_list); + +void MergeFetchTensors(const FetchUnmergedList& fetch_list, + const int64_t micro_batch_num, + FetchList* out); + +void MergeTensors(const std::vector& tensors, + const platform::Place dst_place, + phi::DenseTensor* target); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc index 42f5d7e765106..166853e2b18da 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc +++ b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc @@ -17,32 +17,10 @@ #include "paddle/fluid/framework/new_executor/garbage_collector/event_garbage_collector.h" #include "paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.h" #include "paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.h" -#include "paddle/phi/core/flags.h" - -PHI_DECLARE_bool(fast_eager_deletion_mode); -PHI_DECLARE_bool(new_executor_use_cuda_graph); namespace paddle { namespace framework { -bool IsInterpretercoreFastGCEnabled() { - // When using cuda graph, fast GC must be used. Because - // `EventQuery` method in event GC cannot be used in - // cuda graph. - PADDLE_ENFORCE_EQ(memory::allocation::AllocatorFacade::Instance() - .IsStreamSafeCUDAAllocatorUsed() == false && - FLAGS_new_executor_use_cuda_graph, - false, - platform::errors::InvalidArgument( - "When FLAGS_new_executor_use_cuda_graph is true, " - "IsStreamSafeCUDAAllocatorUsed must be true, but " - "got false.")); - return (memory::allocation::AllocatorFacade::Instance() - .IsStreamSafeCUDAAllocatorUsed() && - FLAGS_fast_eager_deletion_mode) || - FLAGS_new_executor_use_cuda_graph; -} - InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() { garbages_ = std::make_unique(); max_memory_size_ = static_cast(GetEagerDeletionThreshold()); diff --git a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h index fb697e887216e..dc84e88cdcd98 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h +++ b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h @@ -16,11 +16,14 @@ #include #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" -#include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/memory/allocation/spin_lock.h" #include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" +#include "paddle/phi/core/flags.h" + +PHI_DECLARE_bool(fast_eager_deletion_mode); +PHI_DECLARE_bool(new_executor_use_cuda_graph); namespace paddle { namespace framework { @@ -46,7 +49,23 @@ class InterpreterCoreGarbageCollector { memory::SpinLock spinlock_; }; -bool IsInterpretercoreFastGCEnabled(); +inline bool IsInterpretercoreFastGCEnabled() { + // When using cuda graph, fast GC must be used. Because + // `EventQuery` method in event GC cannot be used in + // cuda graph. + PADDLE_ENFORCE_EQ(memory::allocation::AllocatorFacade::Instance() + .IsStreamSafeCUDAAllocatorUsed() == false && + FLAGS_new_executor_use_cuda_graph, + false, + platform::errors::InvalidArgument( + "When FLAGS_new_executor_use_cuda_graph is true, " + "IsStreamSafeCUDAAllocatorUsed must be true, but " + "got false.")); + return (memory::allocation::AllocatorFacade::Instance() + .IsStreamSafeCUDAAllocatorUsed() && + FLAGS_fast_eager_deletion_mode) || + FLAGS_new_executor_use_cuda_graph; +} std::unique_ptr CreateInterpreterCoreGarbageCollector( diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 8841103213400..e549b243f87ec 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -17,102 +17,111 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/hlir/framework/instruction.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/cinn/runtime/cuda/cuda_util.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/framework/paddle2cinn/transform_type.h" namespace paddle { namespace framework { -// TODO(Aurelius84): Think deeply what's the responsibility is it. -// Currently it assumes CinnLaunchContext role. -class JitContext { +class CinnJitInstruction::FnPtrImpl { + using CUDAJITInfo = cinn::hlir::framework::pir::CUDAJITInfo; + public: - cinn_buffer_t* GetCinnBufferOfVar(const std::string& name) { - auto res = paddle2argument_.find(name); - PADDLE_ENFORCE_NE( - res, - paddle2argument_.end(), - platform::errors::NotFound( - "Variable(%s) not found in compilation result", name)); - return static_cast(res->second); - } + explicit FnPtrImpl(const CUDAJITInfo& cuda_jit_info) + : cuda_jit_info_(cuda_jit_info) {} + void Run(const std::vector& kernel_args, void* stream) { + func_args_.clear(); + ptr_storage_.resize(kernel_args.size()); + for (size_t i = 0; i < kernel_args.size(); ++i) { + ptr_storage_[i] = kernel_args[i]->data(); + func_args_.push_back(ptr_storage_.data() + i); + } - // NOTE(Aurelius84): Before running each instruction, we should share Tensor - // memory from paddle scope with cinn_buffer_t from cinn scope including - // inputs and outputs. - void ShareMemToCinn(const std::string& var_name, - const phi::Place& place, - Scope* scope) { - cinn_buffer_t* buffer = GetCinnBufferOfVar(var_name); - auto* tensor = scope->GetVar(var_name)->GetMutable(); - // TODO(Aurelius84): Maybe we should consider to unify the Scope - // structure between paddle and cinn, so that we don't need to develop - // the glue code. - buffer->memory = reinterpret_cast(tensor->mutable_data( - place, paddle2cinn::TransToPaddleDataType(buffer->type))); + CUDA_DRIVER_CALL( + cuLaunchKernel(static_cast(cuda_jit_info_.fn_ptr), + cuda_jit_info_.grid_dims[0], + cuda_jit_info_.grid_dims[1], + cuda_jit_info_.grid_dims[2], + cuda_jit_info_.block_dims[0], + cuda_jit_info_.block_dims[1], + cuda_jit_info_.block_dims[2], + 0, // share memory + static_cast(stream), + func_args_.data(), + nullptr)) } - // TODO(Aurelius84): Add logic to parse stream for different device. - void* GetStream() { return nullptr; } - private: - // because a cinn_pod_value_t does not own a cinn_buffer_t object, - // an extra stroage is necessary to keep those objects and they can - // not be released until the runtime program finish execution. - std::vector> hold_buffers_; - // this map saves all execution arguments with their cinn names as key, - // and it is passed to the Execute interface of a cinn runtime program. - std::map name2argument_; - // this map saves all execution arguments with paddle variables as key, - // this map conbine name2argument_ and paddle2cinn_varmap_ - std::map paddle2argument_; -}; + CUDAJITInfo cuda_jit_info_; -// TODO(Aurelius84): Impl should hold JitContext instance to -// deliver the device context for 'instr->Run' and responsible -// to deal with inner buffer_t shareing between framework::Scope -// and cinn::Scope. -class CinnJitInstruction::Impl { - using Instruction = cinn::hlir::framework::Instruction; - - public: - explicit Impl(Instruction* instr) : instr_(instr) {} - // TODO(Aurelus84): Support to specify name2podargs and stream arguments. - void Run() { - PADDLE_ENFORCE_NOT_NULL( - instr_, platform::errors::NotFound("instr_ should not be NULL")); - instr_->Run(/*name2podargs=*/nullptr, - false, - /*stream=*/nullptr, - /*use_cache=*/true); - } - const Instruction* pointer() const { return instr_; } - - private: - Instruction* instr_{nullptr}; + std::vector ptr_storage_; + std::vector func_args_; }; -CinnJitInstruction::CinnJitInstruction(size_t id, - const platform::Place& place, - ::pir::Operation* op, - Scope* scope) +CinnJitInstruction::CinnJitInstruction( + size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo& value_exec_info) : InstructionBase(id, place) { - // TODO(Aurelius84): We shall simplify members of JitKernelOp to make it - // only hold related function ptrs. Impl is the real runtime data structure - // responsible to construct hlir::framework::Instruction. auto jit_kernel_op = op->dyn_cast(); - impl_ = std::make_shared(jit_kernel_op.instruction()); + fn_ptr_impl_ = std::make_shared(jit_kernel_op.cuda_jit_info()); op_ = op; + + place_ = place; + + InitInputsOutputsIds(op, value_exec_info); + + for (size_t i = 0; i < op->num_operands(); ++i) { + auto in = op->operand_source(i); + + auto var_name = value_exec_info.GetVarName(in); + + auto tensor = value_exec_info.GetScope() + ->Var(var_name) + ->GetMutable(); + + tensor_args_.push_back(tensor); + } + + dev_ctx_ = phi::DeviceContextPool::Instance().Get(place_); + + for (size_t i = 0; i < op->num_results(); ++i) { + pir::Value result = op->result(i); + auto var_name = value_exec_info.GetVarName(result); + + auto tensor = value_exec_info.GetScope() + ->Var(var_name) + ->GetMutable(); + + tensor_args_.push_back(tensor); + + out_tensor_ = tensor; + + auto alloc_tensor_type = + result.type().dyn_cast(); + tensor->set_type( + paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype())); + tensor->Resize(alloc_tensor_type.dims()); + } } void CinnJitInstruction::Run() { - VLOG(6) << "Run cinn jit_kernel_op : " << Name(); - impl_->Run(); + auto gpu_ctx = static_cast(dev_ctx_); + + auto stream = gpu_ctx->stream(); + for (size_t i = 0; i < tensor_args_.size(); ++i) { + gpu_ctx->Alloc(tensor_args_[i], tensor_args_[i]->dtype()); + } + + fn_ptr_impl_->Run(tensor_args_, static_cast(stream)); } const std::string& CinnJitInstruction::Name() const { - // TODO(Aurelius84): Consider the case for instrucitons constaning - // multipule function ptrs and function names. - return impl_->pointer()->function_name(); + static const std::string name = "cinn_jit"; + return name; } } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h index 5f5e4f74e8884..ceb4014f044a6 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h @@ -30,7 +30,7 @@ class CinnJitInstruction : public InstructionBase { CinnJitInstruction(size_t id, const platform::Place& place, ::pir::Operation* op, - Scope* scope); + const ValueExecutionInfo& value_exec_info); // TODO(Aurelius84): Only implement core interface and need implement GC and // Event logic. @@ -41,8 +41,17 @@ class CinnJitInstruction : public InstructionBase { ::pir::Operation* Operation() const override { return op_; } private: - class Impl; - std::shared_ptr impl_{nullptr}; + class FnPtrImpl; + + std::shared_ptr fn_ptr_impl_{nullptr}; + + platform::Place place_; + + phi::DeviceContext* dev_ctx_; + + phi::DenseTensor* out_tensor_; + + std::vector tensor_args_; ::pir::Operation* op_{nullptr}; // not owned }; diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc index 780219e406bff..c66e10a822056 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" @@ -62,51 +62,17 @@ CondInstruction::CondInstruction(size_t id, } VLOG(6) << "finish process cond_var and output_vars"; - auto true_branch_block = if_op.true_block(); - auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block); - Scope* true_scope = &(value_exec_info->GetScope()->NewScope()); - true_branch_inter_ = - new NewIRInterpreter(place, - {}, - true_branch_block, - true_scope, - value_exec_info->NewChild(true_scope), - {}); - - std::set true_skip_gc_names_set; - for (auto value : true_branch_yied_inputs) { - true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); - true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); - } - true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set); - VLOG(6) << "finish process true branch interpreter"; - - auto false_branch_block = if_op.false_block(); - auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block); - Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); - false_branch_inter_ = - new NewIRInterpreter(place, - {}, - false_branch_block, - false_scope, - value_exec_info->NewChild(false_scope), - {}); - - std::set false_skip_gc_names_set; - for (auto value : false_branch_yied_inputs) { - false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); - } - false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); - VLOG(6) << "finish process false branch interpreter"; - // NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is // OpOperand of IfOp, and the other is external Values used in true_block or // false_block. + auto true_branch_block = if_op.true_block(); + auto false_branch_block = if_op.false_block(); std::unordered_map> inputs; GetInputIds(op, *value_exec_info, &inputs); - GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); - GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); + auto true_outside_inputs = + GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); + auto false_outside_inputs = + GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); SetInputs(inputs); std::unordered_map> outputs; @@ -125,6 +91,51 @@ CondInstruction::CondInstruction(size_t id, } SetOutputs(outputs); VLOG(6) << "finish process inputs outputs index"; + + Scope* true_scope = &(value_exec_info->GetScope()->NewScope()); + true_branch_inter_ = new PirInterpreter(place, + {}, + true_branch_block, + true_scope, + value_exec_info->NewChild(true_scope), + {}); + + std::set true_skip_gc_names_set; + for (auto value : GetYiedOpInputs(true_branch_block)) { + true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); + } + // NOTE(zhangbo): According to the concept of control flow, child scopes + // should not control the lifecycle of parent scope variables. + for (auto value : true_outside_inputs) { + true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); + } + true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set); + VLOG(6) << "finish process true branch interpreter"; + + Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); + false_branch_inter_ = + new PirInterpreter(place, + {}, + false_branch_block, + false_scope, + value_exec_info->NewChild(false_scope), + {}); + + std::set false_skip_gc_names_set; + for (auto value : GetYiedOpInputs(false_branch_block)) { + false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); + } + for (auto value : false_outside_inputs) { + false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); + } + false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); + VLOG(6) << "finish process false branch interpreter"; } CondInstruction::~CondInstruction() { @@ -137,7 +148,7 @@ CondInstruction::~CondInstruction() { } void CondInstruction::CopyBranchOutput( - const std::vector& var_names, const NewIRInterpreter* inter) { + const std::vector& var_names, const PirInterpreter* inter) { for (size_t i = 0; i < var_names.size(); ++i) { auto* inner_var = inter->InnerScope()->GetVar(var_names[i]); @@ -150,10 +161,10 @@ void CondInstruction::Run() { DeviceContext().Wait(); if (cond_var_->Get().data()[0]) { true_branch_inter_->Run({}, false); - CopyBranchOutput(true_skip_gc_names_, true_branch_inter_); + CopyBranchOutput(true_branch_outputs_, true_branch_inter_); } else { false_branch_inter_->Run({}, false); - CopyBranchOutput(false_skip_gc_names_, false_branch_inter_); + CopyBranchOutput(false_branch_outputs_, false_branch_inter_); } // copy ouptut diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h index 1cdc4a388126a..79af374ecdd32 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h @@ -24,7 +24,7 @@ namespace paddle { namespace framework { class Scope; class Value; -class NewIRInterpreter; +class PirInterpreter; class ValueExecutionInfo; class CondInstruction : public InstructionBase { @@ -44,7 +44,7 @@ class CondInstruction : public InstructionBase { private: void CopyBranchOutput(const std::vector& var_names, - const NewIRInterpreter* inter); + const PirInterpreter* inter); ::pir::Operation* op_; @@ -54,9 +54,13 @@ class CondInstruction : public InstructionBase { std::vector output_vars_; - NewIRInterpreter* true_branch_inter_; + PirInterpreter* true_branch_inter_; - NewIRInterpreter* false_branch_inter_; + PirInterpreter* false_branch_inter_; + + std::vector true_branch_outputs_; + + std::vector false_branch_outputs_; // TODO(zhangbo): Currently, only the output of IfOp is included. In the // future, need to consider how to support IfGradOp using IfOp value. diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc index 0b494c29dea86..62419acffc099 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc @@ -217,8 +217,11 @@ void InstructionBase::SetOutputs( void InstructionBase::InitInputsOutputsIds( ::pir::Operation* op, const ValueExecutionInfo& value_exec_info) { auto op_attributes = op->attributes(); - auto op_name = - op_attributes.at("op_name").dyn_cast().AsString(); + std::string op_name; + if (op_attributes.count("op_name ")) { + op_name = + op_attributes.at("op_name").dyn_cast().AsString(); + } std::unordered_map> inputs; for (size_t i = 0; i < op->num_operands(); i++) { pir::Value value = op->operand_source(i); @@ -257,8 +260,7 @@ void InstructionBase::InitInputsOutputsIds( std::string InstructionBase::DebugStringEx( const paddle::framework::Scope* scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name) - const { + ValueExecutionInfo* value_exe_info) const { std::stringstream ss; ss << "Op(" << Name() << "), inputs:{"; @@ -268,7 +270,7 @@ std::string InstructionBase::DebugStringEx( auto& input = *it; bool is_no_need_buffer_var = (!no_need_buffer_vars.empty() && no_need_buffer_vars.count(input.first) > 0); - auto var_name = value_2_var_name.at(input.first); + auto var_name = value_exe_info->GetVarName(input.first); ss << var_name; if (scope) { if (!VarInited(*scope, var_name)) { @@ -296,7 +298,7 @@ std::string InstructionBase::DebugStringEx( ss << "}, outputs:{"; for (auto it = Outputs().begin(); it != Outputs().end();) { auto& output = *it; - auto var_name = value_2_var_name.at(output.first); + auto var_name = value_exe_info->GetVarName(output.first); ss << var_name; if (scope) { if (!VarInited(*scope, var_name)) { diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h b/paddle/fluid/framework/new_executor/instruction/instruction_base.h index 6079742611915..5dd7ff3e4d2a5 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h @@ -144,10 +144,8 @@ class InstructionBase { const ValueExecutionInfo& value_exec_info); // if scope is not null, also show dimensions of arguments - virtual std::string DebugStringEx( - const paddle::framework::Scope* scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name) - const; + virtual std::string DebugStringEx(const paddle::framework::Scope* scope, + ValueExecutionInfo* value_exe_info) const; protected: size_t id_; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index cf845ca482437..42e595159f217 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -31,7 +31,9 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/pir/core/block_argument.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" #include "paddle/phi/core/distributed/comm_context_manager.h" @@ -221,7 +223,7 @@ void GetInputIds(pir::Operation* op, } } -void GetOutsideOpInputs( +std::vector GetOutsideOpInputs( pir::Block* block, const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids) { @@ -231,7 +233,11 @@ void GetOutsideOpInputs( inner_outputs.insert(op->result(i)); } } + for (size_t arg_id = 0; arg_id < block->args_size(); ++arg_id) { + inner_outputs.insert(block->argument(arg_id)); + } + std::vector outside_op_inputs; for (auto op : (*block)) { for (size_t i = 0; i < op->num_operands(); ++i) { pir::Value value = op->operand_source(i); @@ -244,9 +250,30 @@ void GetOutsideOpInputs( i, op->name())); input_ids->emplace(value, GetValueIds(value, value_exec_info)); + outside_op_inputs.push_back(value); } } } + return outside_op_inputs; +} + +bool GetCondData(const phi::DenseTensor& cond) { + if (paddle::platform::is_cpu_place(cond.place())) { + return cond.data()[0]; + } + // when platform::is_gpu_place(cond.place()) or + // platform::is_xpu_place(cond.place()) is true + std::unique_ptr cpu_cond{new phi::DenseTensor()}; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_CUSTOM_DEVICE) + paddle::framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get()); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "This version of PaddlePaddle does NOT support GPU/XPU but got " + "GPU/XPU tensor Cond in WhileOp. Please compile WITH_GPU or " + "WITH_XPU option.")); +#endif + return cpu_cond->data()[0]; } } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h index fdc0e8774c1c5..14937ca5dc70c 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h @@ -49,10 +49,11 @@ void GetInputIds(pir::Operation* op, const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids); -void GetOutsideOpInputs( +std::vector GetOutsideOpInputs( pir::Block* block, const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids); +bool GetCondData(const phi::DenseTensor& cond); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc index 97bda34777008..e4e3aef27e7c7 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc @@ -90,7 +90,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( phi::errors::PreconditionNotMet( "can not find OpYamlInfoInterface from [%s]", legacy_op_name_)); paddle::dialect::OpYamlInfoParser yaml_info_parser( - yaml_interface->get_op_info_()); + yaml_interface->get_op_info_(), paddle::dialect::IsLegacyOp(op_name)); VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { @@ -166,12 +166,12 @@ LegacyKernelInstruction::~LegacyKernelInstruction() { } void LegacyKernelInstruction::Run() { + VLOG(6) << "Run op " << legacy_op_name_ << " infer meta."; if (infer_meta_interface_) { infer_meta_interface_->infer_meta_(&(infer_meta_context_)); } - VLOG(6) << "Run op " << legacy_op_name_ << " infer meta."; - (*(phi_kernel_))((kernel_context_)); VLOG(6) << "Run op " << legacy_op_name_ << " kernel."; + (*(phi_kernel_))((kernel_context_)); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index 3f93161a363fa..f690b7290107b 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -94,7 +94,7 @@ PhiKernelInstruction::PhiKernelInstruction( phi::errors::PreconditionNotMet( "can not find OpYamlInfoInterface from [%s]", phi_op_name_)); paddle::dialect::OpYamlInfoParser yaml_info_parser( - yaml_interface->get_op_info_()); + yaml_interface->get_op_info_(), paddle::dialect::IsLegacyOp(op_name)); VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc index 6108e5e73455f..38d8124bb28aa 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" @@ -53,8 +53,6 @@ WhileInstruction::WhileInstruction(size_t id, SetKernelType(AnalyseOpFuncType(op, place)); VLOG(6) << "finish process analyse kernel type"; - Scope* inner_scope = local_scope == nullptr ? scope : local_scope; - VLOG(6) << "finish process inputs outputs index"; PADDLE_ENFORCE(op->isa(), @@ -63,41 +61,25 @@ WhileInstruction::WhileInstruction(size_t id, auto while_op = op->dyn_cast(); - cond_var_ = inner_scope->GetVar( + cond_var_ = parent_exe_info->GetScope()->FindVar( parent_exe_info->GetValue2VarName().at(while_op.operand_source(0))); + for (size_t i = 1; i < while_op.num_operands(); ++i) { - while_op_inputs_.push_back(inner_scope->GetVar( + inputs_.push_back(parent_exe_info->GetScope()->FindVar( parent_exe_info->GetValue2VarName().at(while_op.operand_source(i)))); } for (size_t i = 0; i < while_op.num_results(); ++i) { - while_op_outputs_.push_back(inner_scope->GetVar( + outputs_.push_back(parent_exe_info->GetScope()->FindVar( parent_exe_info->GetValue2VarName().at(while_op.result(i)))); } body_block_ = while_op.body_block(); - auto body_block_outputs = GetYiedOpInputs(body_block_); - - Scope* body_scope = &(parent_exe_info->GetScope()->NewScope()); - auto body_exe_info = parent_exe_info->NewChild(body_scope); - for (size_t i = 0; i < body_block_->args_size(); ++i) { - auto var_name = "body_block_arg_" + std::to_string(i); - body_scope->Var(var_name); - body_exe_info->Add(body_block_->argument(i), var_name); - } - body_inter_ = std::unique_ptr(new NewIRInterpreter( - place, {}, body_block_, body_scope, body_exe_info, {})); - - std::set body_skip_gc_names_set; - for (auto value : body_block_outputs) { - body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); - body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); - } - body_inter_->SetSkipGcVars(body_skip_gc_names_set); std::unordered_map> inputs; GetInputIds(op, *parent_exe_info, &inputs); - + auto body_outside_inputs = + GetOutsideOpInputs(body_block_, *parent_exe_info, &inputs); SetInputs(inputs); std::unordered_map> outputs; @@ -116,12 +98,35 @@ WhileInstruction::WhileInstruction(size_t id, } } SetOutputs(outputs); + + Scope* body_scope = &(parent_exe_info->GetScope()->NewScope()); + auto body_exe_info = parent_exe_info->NewChild(body_scope); + for (size_t i = 0; i < body_block_->args_size(); ++i) { + auto var_name = "body_block_arg_" + std::to_string(i); + body_scope->Var(var_name); + body_exe_info->Add(body_block_->argument(i), var_name); + } + body_inter_ = std::unique_ptr(new PirInterpreter( + place, {}, body_block_, body_scope, body_exe_info, {})); + + std::set body_skip_gc_names_set; + auto body_block_outputs = GetYiedOpInputs(body_block_); + for (auto value : body_block_outputs) { + body_outputs_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); + } + for (auto value : body_outside_inputs) { + body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); + } + body_inter_->SetSkipGcVars(body_skip_gc_names_set); } void WhileInstruction::CopyInputsToOutputs() { - for (size_t i = 0; i < while_op_outputs_.size(); ++i) { - while_op_outputs_[i]->GetMutable()->ShareDataWith( - while_op_inputs_[i]->Get()); + for (size_t i = 0; i < outputs_.size(); ++i) { + outputs_[i]->GetMutable()->ShareDataWith( + inputs_[i]->Get()); } } @@ -131,25 +136,26 @@ void WhileInstruction::PassArgsToBodyBlock() { auto var_name = body_inter_->GetNameByValue(block_arg); auto* inner_var = body_inter_->local_scope()->GetVar(var_name); inner_var->GetMutable()->ShareDataWith( - while_op_outputs_[i]->Get()); + outputs_[i]->Get()); } } void WhileInstruction::GetValueFromBodyBlock() { cond_var_->GetMutable()->ShareDataWith( body_inter_->local_scope() - ->GetVar(body_skip_gc_names_[0]) + ->GetVar(body_outputs_[0]) ->Get()); - for (size_t i = 0; i < while_op_outputs_.size(); ++i) { - auto& out_var_name = body_skip_gc_names_[i + 1]; + for (size_t i = 0; i < outputs_.size(); ++i) { + auto& out_var_name = body_outputs_[i + 1]; auto* out_var = body_inter_->local_scope()->GetVar(out_var_name); - while_op_outputs_[i]->GetMutable()->ShareDataWith( + outputs_[i]->GetMutable()->ShareDataWith( out_var->Get()); } } + void WhileInstruction::Run() { CopyInputsToOutputs(); - while (cond_var_->Get().data()[0]) { + while (GetCondData(cond_var_->Get())) { PassArgsToBodyBlock(); body_inter_->Run({}, false); GetValueFromBodyBlock(); diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.h b/paddle/fluid/framework/new_executor/instruction/while_instruction.h index efe09b2f3e3f5..1c9cfabbde286 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.h @@ -24,7 +24,7 @@ namespace paddle { namespace framework { class Scope; class Value; -class NewIRInterpreter; +class PirInterpreter; class ValueExecutionInfo; /// The execute semantics of while op ['output' = while_op('cond', 'intput')] @@ -44,7 +44,7 @@ class WhileInstruction : public InstructionBase { void Run() override; - const std::string& Name() const override { return cond_name_; } + const std::string& Name() const override { return name_; } ::pir::Operation* Operation() const override { return op_; } @@ -58,13 +58,15 @@ class WhileInstruction : public InstructionBase { // Get return value from body_block after each execution. void GetValueFromBodyBlock(); - std::string cond_name_{"while_instruction"}; + std::string name_{"while_instruction"}; Variable* cond_var_; - std::vector while_op_inputs_; - std::vector while_op_outputs_; - std::unique_ptr body_inter_; + std::vector inputs_; + std::vector outputs_; + + std::unique_ptr body_inter_; + std::vector body_outputs_; std::vector body_skip_gc_names_; ::pir::Block* body_block_; diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc index 4ce8c411a10b2..5dc9550dbadaa 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc @@ -547,13 +547,13 @@ void DependencyBuilder::UpdateVarMinRwOp( /// ======================== /// /// For new ir /// /// ======================== /// -NewIrDependencyBuilder::NewIrDependencyBuilder() { +PirDependencyBuilder::PirDependencyBuilder() { is_build_ = false; op_downstream_map_ = std::make_shared>>(); op_happens_before_ = std::make_shared>>(); } -const std::map>& NewIrDependencyBuilder::Build( +const std::map>& PirDependencyBuilder::Build( std::vector instructions) { if (is_build_) { return *op_downstream_map_; @@ -590,7 +590,7 @@ const std::map>& NewIrDependencyBuilder::Build( return *op_downstream_map_; } -void NewIrDependencyBuilder::BuildDownstreamMap() { +void PirDependencyBuilder::BuildDownstreamMap() { auto var2min_rw_op = std::map>(); // # map from variable id to read // write op id. @@ -664,8 +664,8 @@ void NewIrDependencyBuilder::BuildDownstreamMap() { } } -void NewIrDependencyBuilder::ShareDependencyFrom( - const NewIrDependencyBuilder& src) { +void PirDependencyBuilder::ShareDependencyFrom( + const PirDependencyBuilder& src) { std::tie(op_downstream_map_, op_happens_before_) = src.GetDependency(); is_build_ = true; } diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h index 18a26ea770cec..f3b2ae305f3f9 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h @@ -103,9 +103,9 @@ class DependencyBuilder { /// ======================== /// /// For new ir /// /// ======================== /// -class NewIrDependencyBuilder : public DependencyBuilder { +class PirDependencyBuilder : public DependencyBuilder { public: - NewIrDependencyBuilder(); + PirDependencyBuilder(); // build op dependencies and return the mapping from op to its downstream-op // set @@ -114,7 +114,7 @@ class NewIrDependencyBuilder : public DependencyBuilder { void BuildDownstreamMap(); - void ShareDependencyFrom(const NewIrDependencyBuilder& src); + void ShareDependencyFrom(const PirDependencyBuilder& src); private: std::vector instructions_; // not_owned diff --git a/paddle/fluid/framework/new_executor/interpreter/execution_config.h b/paddle/fluid/framework/new_executor/interpreter/execution_config.h index 828678fa59da1..def76235331f1 100644 --- a/paddle/fluid/framework/new_executor/interpreter/execution_config.h +++ b/paddle/fluid/framework/new_executor/interpreter/execution_config.h @@ -29,6 +29,7 @@ struct ExecutionConfig { bool used_for_cinn{false}; bool used_for_control_flow_op{false}; bool used_for_jit{false}; + bool used_for_inference{false}; size_t device_num_threads{0}; size_t host_num_threads{0}; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index fef6b91f95026..ee8c8cd2ec79d 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -46,10 +46,6 @@ #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/backends/device_manager.h" #endif -PADDLE_DEFINE_EXPORTED_bool( - new_executor_log_memory_stats, - false, - "Log memory stats after each op runs, just used for debug."); PHI_DECLARE_bool(use_mkldnn); PHI_DECLARE_bool(check_nan_inf); @@ -602,6 +598,11 @@ void BuildOpFuncList(const platform::Place& place, for (size_t i = 0; i < ops.size(); ++i) { auto op = ops[i].get(); const std::string& op_type = op->Type(); + if (execution_config.used_for_inference) { + if (op_type == "feed" || op_type == "fetch") { + continue; + } + } VLOG(6) << "Build OpFuncNode from : " << op_type; @@ -986,7 +987,7 @@ void BuildOpFuncList(const platform::Place& place, // gc--------------------------------------------- auto iter = unused_var_map.find(op); if (iter == unused_var_map.end()) { - interpreter::LogDeviceMemoryStats(place); + memory::LogDeviceMemoryStats(place, op_type); continue; } @@ -1036,7 +1037,7 @@ void BuildOpFuncList(const platform::Place& place, } delete garbages; // free mem - interpreter::LogDeviceMemoryStats(place); + memory::LogDeviceMemoryStats(place, op_type); } } @@ -1131,21 +1132,6 @@ void BuildVariableScope(const framework::BlockDesc& block, } } -void LogDeviceMemoryStats(const platform::Place& place) { - if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) { - VLOG(0) << "memory_allocated: " - << static_cast(memory::DeviceMemoryStatCurrentValue( - "Allocated", place.device)) / - 1024 / 1024 - << " MB"; - VLOG(0) << "max_memory_allocated: " - << static_cast(memory::DeviceMemoryStatPeakValue( - "Allocated", place.device)) / - 1024 / 1024 - << " MB"; - } -} - void SetDeviceCommContext(framework::OperatorBase* operator_base, platform::DeviceContext* dev_ctx) { if (operator_base->HasAttr("ring_id")) { @@ -1221,7 +1207,7 @@ const paddle::framework::Variable* GetVariableByName( return nullptr; } -std::vector GetOriginInputNames(std::string op_name) { +std::vector GetOriginInputNames(const std::string& op_name) { std::vector ret; pir::IrContext* ctx = pir::IrContext::Instance(); pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); @@ -1234,7 +1220,7 @@ std::vector GetOriginInputNames(std::string op_name) { return ret; } -std::vector GetOriginOutputNames(std::string op_name) { +std::vector GetOriginOutputNames(const std::string& op_name) { std::vector ret; pir::IrContext* ctx = pir::IrContext::Instance(); pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index 49bcd8de0b4b1..5f7d7216c003e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -110,7 +110,8 @@ void BuildVariableScope(const framework::BlockDesc& block, void BuildId2VarName(const std::map& var_name_2_id, std::unordered_map* id_2_var_name); -void LogDeviceMemoryStats(const platform::Place& place); +void LogDeviceMemoryStats(const platform::Place& place, + const std::string& op_name); void SetDeviceCommContext(framework::OperatorBase* operator_base, platform::DeviceContext* dev_ctx); @@ -125,9 +126,9 @@ const paddle::framework::Variable* GetVariableByName( const std::unordered_map& variable_2_var_name); -std::vector GetOriginInputNames(std::string op_name); +std::vector GetOriginInputNames(const std::string& op_name); -std::vector GetOriginOutputNames(std::string op_name); +std::vector GetOriginOutputNames(const std::string& op_name); void PrintValuesAndVariables( const pir::Block& block, diff --git a/paddle/fluid/framework/new_executor/interpreter/job.h b/paddle/fluid/framework/new_executor/interpreter/job.h index 493063f9e1516..952702d6e2f0a 100644 --- a/paddle/fluid/framework/new_executor/interpreter/job.h +++ b/paddle/fluid/framework/new_executor/interpreter/job.h @@ -31,27 +31,10 @@ class Job final { const std::string& Type() const { return type_; } - int ColAttrForFetchOp(int fetch_op_id) const { - return fetch_op_id_to_col_attr_.at(fetch_op_id); - } - int64_t MicroBatchId() const { return micro_batch_id_; } std::set SkipGcVars() const { return skip_gc_vars_; } - std::vector AllFetchOpIds() const { - std::vector fetch_op_ids; - fetch_op_ids.reserve(fetch_op_id_to_col_attr_.size()); - for (auto& item : fetch_op_id_to_col_attr_) { - fetch_op_ids.push_back(item.first); - } - return fetch_op_ids; - } - - void SetColAttrForFetchOp(int fetch_op_id, int col_attr) { - fetch_op_id_to_col_attr_[fetch_op_id] = col_attr; - } - void SetMicroBatchId(int64_t micro_batch_id) { PADDLE_ENFORCE_GE( micro_batch_id, @@ -71,11 +54,17 @@ class Job final { skip_gc_vars_ = skip_gc_vars; } + void SetFetchVarName(const std::string& fetch_var_name) { + fetch_var_names_.push_back(fetch_var_name); + } + + std::vector FetchVarNames() { return fetch_var_names_; } + private: const std::string type_; int64_t micro_batch_id_; - std::unordered_map fetch_op_id_to_col_attr_; std::set skip_gc_vars_; + std::vector fetch_var_names_; }; } // namespace interpreter diff --git a/paddle/fluid/framework/new_executor/interpreter/plan.cc b/paddle/fluid/framework/new_executor/interpreter/plan.cc index ab05c6216426a..ee2f5dc57acc6 100644 --- a/paddle/fluid/framework/new_executor/interpreter/plan.cc +++ b/paddle/fluid/framework/new_executor/interpreter/plan.cc @@ -21,7 +21,8 @@ namespace framework { namespace interpreter { Plan::Plan(const std::vector>& job_list, - const std::unordered_map& type_to_program) + const std::unordered_map>& + type_to_program) : job_list_(job_list), type_to_program_(type_to_program), micro_batch_num_(1) { @@ -65,7 +66,16 @@ const std::vector>& Plan::JobList() const { return job_list_; } -const ProgramDesc* Plan::Program(const std::string& job_type) const { +const std::vector Plan::JobTypes() const { + std::vector res; + for (auto kv : type_to_ir_program_) { + res.emplace_back(kv.first); + } + return res; +} + +const std::shared_ptr Plan::Program( + const std::string& job_type) const { return type_to_program_.at(job_type); } @@ -74,8 +84,8 @@ std::shared_ptr<::pir::Program> Plan::IrProgram( return type_to_ir_program_.at(job_type); } -void Plan::UpdateIrProgram(const std::string& job_type, - std::shared_ptr<::pir::Program> ir_prog) { +void Plan::SetIrProgram(const std::string& job_type, + std::shared_ptr<::pir::Program> ir_prog) { type_to_ir_program_[job_type] = ir_prog; } diff --git a/paddle/fluid/framework/new_executor/interpreter/plan.h b/paddle/fluid/framework/new_executor/interpreter/plan.h index 389eb5c9df84e..beb2c176f94ad 100644 --- a/paddle/fluid/framework/new_executor/interpreter/plan.h +++ b/paddle/fluid/framework/new_executor/interpreter/plan.h @@ -31,7 +31,8 @@ namespace interpreter { class Plan final { public: Plan(const std::vector>& job_list, - const std::unordered_map& type_to_program); + const std::unordered_map>& + type_to_program); Plan(const std::vector>& job_list, const std::unordered_map>& type_to_ir_program); @@ -39,18 +40,20 @@ class Plan final { ~Plan() = default; const std::vector>& JobList() const; + const std::vector JobTypes() const; - const ProgramDesc* Program(const std::string& job_type) const; + const std::shared_ptr Program(const std::string& job_type) const; std::shared_ptr<::pir::Program> IrProgram(const std::string& job_type) const; - void UpdateIrProgram(const std::string& job_type, - std::shared_ptr<::pir::Program> ir_prog); + void SetIrProgram(const std::string& job_type, + std::shared_ptr<::pir::Program> ir_prog); int64_t MicroBatchNum() const; private: const std::vector> job_list_; - const std::unordered_map type_to_program_; + const std::unordered_map> + type_to_program_; std::unordered_map> type_to_ir_program_; int64_t micro_batch_num_; diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index bebeb142d473f..2cf615d99b1ba 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -50,6 +50,7 @@ std::set OpsCanSkipedFakeAllocInStaticBuild = { "create_py_reader", "depend", "fetch_v2", + "print", "send_v2", "nop"}; @@ -170,10 +171,10 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) { std::stringstream ss; ss << "The following OPs are unable to static build:\n"; for (auto& item : invalid_ops) { - ss << item.first << " [in_black_list = " << (item.second >> 6 & 1) - << ", is_operator_base = " << (item.second >> 5 & 1) - << ", is_custom_op = " << (item.second >> 4 & 1) - << ", use_mkldnn = " << (item.second >> 3 & 1) + ss << item.first << " [in_black_list = " << (item.second >> 5 & 1) + << ", is_operator_base = " << (item.second >> 4 & 1) + << ", is_custom_op = " << (item.second >> 3 & 1) + << ", use_mkldnn = " << (item.second >> 2 & 1) << ", sub_block_can_not_static_build = " << (item.second >> 1 & 1) << "]\n"; } diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index bbbaf4c0dd75f..5b60205fbc529 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -431,6 +431,7 @@ void analyse_event_info_for_two_instructions( if (has_data_dependency( instructions[cur_instr_id], instructions[next_instr_id]) || + !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || instructions[next_instr_id]->OpBase()->Type() == "depend") { waiter_instr_ids->insert(next_instr_id); return; @@ -490,6 +491,7 @@ void analyse_event_info_for_two_instructions< if (has_data_dependency( instructions[cur_instr_id], instructions[next_instr_id]) || + !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || instructions[next_instr_id]->Name() == "pd_op.depend") { waiter_instr_ids->insert(next_instr_id); return; @@ -651,7 +653,7 @@ void StreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) { /// ======================== /// /// For new ir /// /// ======================== /// -void NewIrStreamAnalyzer::ConstructEvents( +void PirStreamAnalyzer::ConstructEvents( const std::vector>& instructions) { if (!is_event_info_build_) { @@ -664,7 +666,7 @@ void NewIrStreamAnalyzer::ConstructEvents( cross_step_merged_instructions_ptr.emplace_back(instr.get()); } - NewIrDependencyBuilder dependency_builder; + PirDependencyBuilder dependency_builder; dependency_builder.Build(cross_step_merged_instructions_ptr); const std::map>& downstream_map = dependency_builder.OpDownstreamMap(); @@ -729,7 +731,7 @@ void NewIrStreamAnalyzer::ConstructEvents( } } -void NewIrStreamAnalyzer::AnalyseAllRunType( +void PirStreamAnalyzer::AnalyseAllRunType( const std::vector& instructions, const std::map>& downstream_map, std::vector>>* run_type_info) const { @@ -737,7 +739,7 @@ void NewIrStreamAnalyzer::AnalyseAllRunType( instructions, downstream_map, place_, run_type_info); } -void NewIrStreamAnalyzer::AnalyseAllEventInfo( +void PirStreamAnalyzer::AnalyseAllEventInfo( const std::vector& instructions, const std::vector>>& run_type_info, std::map>>* @@ -746,14 +748,14 @@ void NewIrStreamAnalyzer::AnalyseAllEventInfo( instructions, run_type_info, event_info); } -void NewIrStreamAnalyzer::ShrinkEventInfo( - const NewIrDependencyBuilder& dependency_builder, +void PirStreamAnalyzer::ShrinkEventInfo( + const PirDependencyBuilder& dependency_builder, std::map>>* event_info_map) const { - shrink_event_info(dependency_builder, event_info_map); + shrink_event_info(dependency_builder, event_info_map); } -platform::DeviceType NewIrStreamAnalyzer::GetWaiterType( +platform::DeviceType PirStreamAnalyzer::GetWaiterType( const paddle::framework::InstructionBase* instr) const { if (instr->KernelType() == OpFuncType::kCpuSync) { return platform::kCPU; @@ -767,14 +769,14 @@ platform::DeviceType NewIrStreamAnalyzer::GetWaiterType( } } -void NewIrStreamAnalyzer::ShareEventInfoFrom(const NewIrStreamAnalyzer& src) { +void PirStreamAnalyzer::ShareEventInfoFrom(const PirStreamAnalyzer& src) { event_info_ = src.GetEventInfo(); is_event_info_build_ = true; } std::shared_ptr< std::map>>> -NewIrStreamAnalyzer::GetEventInfo() const { +PirStreamAnalyzer::GetEventInfo() const { return event_info_; } diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index 8f2ee33ca4ed5..6af94ea71a5a5 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -127,17 +127,17 @@ class StreamAnalyzer { /// ======================== /// /// For new ir /// /// ======================== /// -class NewIrStreamAnalyzer { +class PirStreamAnalyzer { public: using DeviceContext = platform::DeviceContext; using Place = platform::Place; - explicit NewIrStreamAnalyzer(const Place& place) : place_(place) { + explicit PirStreamAnalyzer(const Place& place) : place_(place) { event_info_ = std::make_shared< std::map>>>(); } - ~NewIrStreamAnalyzer() {} + ~PirStreamAnalyzer() {} void ConstructEvents( const std::vector>& @@ -146,7 +146,7 @@ class NewIrStreamAnalyzer { platform::DeviceType GetWaiterType( const paddle::framework::InstructionBase* instr) const; - void ShareEventInfoFrom(const NewIrStreamAnalyzer& src); + void ShareEventInfoFrom(const PirStreamAnalyzer& src); std::shared_ptr< std::map>>> @@ -165,7 +165,7 @@ class NewIrStreamAnalyzer { event_info) const; void ShrinkEventInfo( - const NewIrDependencyBuilder& dependency_builder, + const PirDependencyBuilder& dependency_builder, std::map>>* event_info_map) const; diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h index 369216e0078c4..bea61cdeeec84 100644 --- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h +++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h @@ -47,7 +47,7 @@ PHI_DECLARE_bool(check_nan_inf); PD_DECLARE_bool(benchmark); PHI_DECLARE_uint64(executor_log_deps_every_microseconds); PHI_DECLARE_bool(new_executor_use_cuda_graph); -PHI_DECLARE_bool(enable_new_ir_in_executor); +PHI_DECLARE_bool(enable_pir_in_executor); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PHI_DECLARE_bool(sync_nccl_allreduce); #endif @@ -67,7 +67,8 @@ class InterpreterBaseImpl { virtual ~InterpreterBaseImpl() = default; virtual paddle::framework::FetchList Run( const std::vector& feed_names, - const std::vector& feed_tensors) = 0; + const std::vector& feed_tensors, + bool need_fetch = true) = 0; virtual paddle::framework::FetchList Run( const std::vector& feed_names, bool need_fetch = true) = 0; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 8e052d3b2685e..d7efd510535e8 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/new_executor/interpretercore.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/new_executor/program_interpreter.h" #include "paddle/pir/core/program.h" #include "paddle/pir/core/value.h" @@ -54,7 +54,7 @@ InterpreterCore::InterpreterCore( framework::Scope* scope, const ExecutionConfig& execution_config) { VLOG(4) << "InterpreterCore(): " << this << " on " << place; - impl_ = std::make_unique( + impl_ = std::make_unique( place, fetch_var_names, ir_block, scope, execution_config); } @@ -65,8 +65,9 @@ InterpreterCore::~InterpreterCore() { FetchList InterpreterCore::Run( const std::vector& feed_names, - const std::vector& feed_tensors) { - return impl_->Run(feed_names, feed_tensors); + const std::vector& feed_tensors, + bool need_fetch) { + return impl_->Run(feed_names, feed_tensors, need_fetch); } FetchList InterpreterCore::Run(const std::vector& feed_names, diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index d21bd9e1fc378..022bc0c06f5b2 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -47,7 +47,8 @@ class InterpreterCore { paddle::framework::FetchList Run( const std::vector& feed_names, - const std::vector& feed_tensors); + const std::vector& feed_tensors, + bool need_fetch = true); paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index f4ca2f20d01ae..a336e2c377dfd 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -19,6 +19,8 @@ #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h" #include "paddle/fluid/platform/profiler/event_tracing.h" namespace paddle { @@ -237,7 +239,8 @@ const std::vector& Instruction::GCCheckVars() const { } void Instruction::ResetContext(const VariableValueMap& in_vars, - const VariableValueMap& out_vars) { + const VariableValueMap& out_vars, + const std::string& op_name) { runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars)); infershape_ctx_.reset( new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get())); @@ -246,16 +249,37 @@ void Instruction::ResetContext(const VariableValueMap& in_vars, static framework::Scope scope_; execution_ctx_.reset( new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get())); + + auto op_with_kernel = + dynamic_cast(OpBase()); + if (op_with_kernel != nullptr && op_with_kernel->Info().infer_meta_) { + if (infershape_ctx_->HasRuntimeAttributes() == false) { + compat_infermeta_ctx_ = paddle::framework::BuildInferMetaContext( + infershape_ctx_.get(), op_name); + can_use_infermeta_ctx_ = true; + } + } } void Instruction::ResetContextWithScope(const VariableValueMap& in_vars, const VariableValueMap& out_vars, - const framework::Scope& scope) { + const framework::Scope& scope, + const std::string& op_name) { runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars)); infershape_ctx_.reset( new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get())); execution_ctx_.reset( new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get())); + + auto op_with_kernel = + dynamic_cast(OpBase()); + if (op_with_kernel != nullptr && op_with_kernel->Info().infer_meta_) { + if (infershape_ctx_->HasRuntimeAttributes() == false) { + compat_infermeta_ctx_ = paddle::framework::BuildInferMetaContext( + infershape_ctx_.get(), op_name); + can_use_infermeta_ctx_ = true; + } + } } std::shared_ptr Instruction::InnerRuntimeContext() const { @@ -267,6 +291,10 @@ std::shared_ptr Instruction::InnerInferShapeContext() return infershape_ctx_; } +const phi::InferMetaContext* Instruction::InnerCompatInferMetaContext() const { + return &compat_infermeta_ctx_; +} + std::shared_ptr Instruction::InnerExecutionContext() const { return execution_ctx_; } @@ -286,5 +314,40 @@ void Instruction::AddInplace(Variable* in, Variable* out) { void Instruction::ClearInplace() { vec_inplace_in_to_out_.clear(); } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +void Instruction::UpdataRecordStreamForGcInfo() { + if (!IsInterpretercoreFastGCEnabled() || + KernelType() != OpFuncType::kGpuAsync) { + return; + } + if (DeviceContext().GetPlace().GetType() == phi::AllocationType::CUSTOM) { + return; + } + need_record_stream_for_gc_ = true; + + stream_ = reinterpret_cast(DeviceContext()).stream(); +// TODO(lizhiyu): Only analyse the 'send_v2' for GPT pp strategy right now. +// To support all the operators for communicating in the future. +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto operator_base_ptr = OpBase(); + if ((operator_base_ptr->Type() == "send_v2") && + (operator_base_ptr->Attr("use_calc_stream") == false)) { + int ring_id = operator_base_ptr->Attr("ring_id"); + if (FLAGS_dynamic_static_unified_comm) { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + stream_ = static_cast( + comm_context_manager.Get(std::to_string(ring_id))) + ->GetStream(); + } else { + stream_ = platform::NCCLCommContext::Instance() + .Get(ring_id, DeviceContext().GetPlace()) + ->stream(); + } + } +#endif +} +#endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index ee9f17034a45f..91522969b0ced 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -18,12 +18,21 @@ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/event.h" +#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/utils/rw_lock.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); +#endif #define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_); #define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_); @@ -262,16 +271,20 @@ class Instruction { const std::vector& GCCheckVars() const; void ResetContext(const VariableValueMap& in_vars, - const VariableValueMap& out_vars); + const VariableValueMap& out_vars, + const std::string& op_name); void ResetContextWithScope(const VariableValueMap& in_vars, const VariableValueMap& out_vars, - const framework::Scope& scope); + const framework::Scope& scope, + const std::string& op_name); std::shared_ptr InnerRuntimeContext() const; std::shared_ptr InnerInferShapeContext() const; + const phi::InferMetaContext* InnerCompatInferMetaContext() const; + std::shared_ptr InnerExecutionContext() const; const platform::DeviceContext& DeviceContext() const; @@ -290,6 +303,15 @@ class Instruction { const OpFuncNode* OpFunc() const { return &op_func_node_; } + // record stream for gc +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + bool need_record_stream_for_gc_ = false; + gpuStream_t stream_{nullptr}; + void UpdataRecordStreamForGcInfo(); +#endif + + bool can_use_infermeta_ctx_ = false; + private: bool is_artificial_; // Instruction is artificial means that it is only used // to assist scheduling and no need to be executed. @@ -307,6 +329,7 @@ class Instruction { std::shared_ptr runtime_ctx_; std::shared_ptr infershape_ctx_; + paddle::framework::CompatInferMetaContext compat_infermeta_ctx_; std::shared_ptr execution_ctx_; std::vector gc_check_vars_; diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/CMakeLists.txt b/paddle/fluid/framework/new_executor/pir_adaptor/CMakeLists.txt index f66f96b7409fe..9236249f421d7 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/pir_adaptor/CMakeLists.txt @@ -4,4 +4,4 @@ file(GLOB PIR_ADAPTOR_SRCS "*.cc") cc_library( pir_adaptor SRCS ${PIR_ADAPTOR_SRCS} - DEPS program_translator pd_kernel_dialect) + DEPS program_translator op_dialect) diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 409532aa59560..714a65b1d9261 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -50,10 +50,15 @@ std::shared_ptr ValueExecutionInfo::NewChild(Scope* scope) { std::shared_ptr info = std::make_shared(scope); info->parent_ = this; + info->value_2_var_name_ = this->value_2_var_name_; + info->var_2_var_name_ = this->var_2_var_name_; + info->var_name_2_id_ = this->var_name_2_id_; + info->id_2_var_name_ = this->id_2_var_name_; + info->var_list_ = this->var_list_; return info; } -void ValueExecutionInfo::Add(::pir::Value value, std::string var_name) { +void ValueExecutionInfo::Add(::pir::Value value, const std::string& var_name) { auto* var = scope_->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL( var, platform::errors::NotFound("Cannot find %s in scope.", var_name)); @@ -79,8 +84,8 @@ void ValueExecutionInfo::Add(::pir::Value value, std::string var_name) { } void ValueExecutionInfo::Rename(pir::Value value, - std::string new_name, - std::string orig_name) { + const std::string& new_name, + const std::string& orig_name) { value_2_var_name_[value] = new_name; for (auto kv : value_2_var_name_) { @@ -157,54 +162,15 @@ void ValueExecutionInfo::ResetVarList(int id, Variable* var) { var_list_[id] = var; } -bool ValueExecutionInfo::HasValue(::pir::Value value) const { - return HasValueInternal(value); -} - -bool ValueExecutionInfo::HasLocalValue(::pir::Value value) const { - return HasValueLocally(value); -} - -std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { - return GetVarNameInternal(value); -} - -std::string ValueExecutionInfo::GetVarName(const Variable* var) const { - return GetVarNameInternal(var); -} - -std::string ValueExecutionInfo::GetLocalVarName(::pir::Value value) const { - return GetVarNameLocally(value); -} - -std::string ValueExecutionInfo::GetLocalVarName(const Variable* var) const { - return GetVarNameLocally(var); -} - -int ValueExecutionInfo::GetVarId(::pir::Value value) const { - return GetVarIdInternal(value); -} - -int ValueExecutionInfo::GetVarId(const Variable* var) const { - return GetVarIdInternal(var); -} - -int ValueExecutionInfo::GetLocalVarId(::pir::Value value) const { - return GetVarIdLocally(value); -} - -int ValueExecutionInfo::GetLocalVarId(const Variable* var) const { - return GetVarIdLocally(var); -} - -bool ValueExecutionInfo::HasValueInternal(::pir::Value value) const { - if (HasValueLocally(value)) { +bool ValueExecutionInfo::HasVar(const std::string& var_name) const { + auto it = var_name_2_id_.find(var_name); + if (it != var_name_2_id_.end()) { return true; } - return (parent_ == nullptr) ? false : parent_->HasValueInternal(value); + return false; } -bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { +bool ValueExecutionInfo::HasValue(::pir::Value value) const { auto it = value_2_var_name_.find(value); if (it != value_2_var_name_.end()) { return true; @@ -212,15 +178,7 @@ bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { return false; } -std::string ValueExecutionInfo::GetVarNameInternal(::pir::Value value) const { - auto name = GetVarNameLocally(value); - if (name != "") { - return name; - } - return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(value); -} - -std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { +std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { auto it = value_2_var_name_.find(value); if (it != value_2_var_name_.end()) { return it->second; @@ -228,15 +186,7 @@ std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { return ""; } -std::string ValueExecutionInfo::GetVarNameInternal(const Variable* var) const { - auto name = GetVarNameLocally(var); - if (name != "") { - return name; - } - return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(var); -} - -std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { +std::string ValueExecutionInfo::GetVarName(const Variable* var) const { auto it = var_2_var_name_.find(var); if (it != var_2_var_name_.end()) { return it->second; @@ -244,16 +194,8 @@ std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { return ""; } -int ValueExecutionInfo::GetVarIdInternal(::pir::Value value) const { - auto id = GetVarIdLocally(value); - if (id != -1) { - return id; - } - return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(value); -} - -int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { - auto var_name = GetVarNameLocally(value); +int ValueExecutionInfo::GetVarId(::pir::Value value) const { + auto var_name = GetVarName(value); auto it = var_name_2_id_.find(var_name); if (it != var_name_2_id_.end()) { return it->second; @@ -261,16 +203,8 @@ int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { return -1; } -int ValueExecutionInfo::GetVarIdInternal(const Variable* var) const { - auto id = GetVarIdLocally(var); - if (id != -1) { - return id; - } - return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(var); -} - -int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const { - auto var_name = GetVarNameLocally(var); +int ValueExecutionInfo::GetVarId(const Variable* var) const { + auto var_name = GetVarName(var); auto it = var_name_2_id_.find(var_name); if (it != var_name_2_id_.end()) { return it->second; @@ -368,6 +302,7 @@ void BuildValue(pir::Value value, var->GetMutable(); } else if (value.type().isa()) { auto tensor_array = var->GetMutable(); + tensor_array->clear(); for (size_t i = 0; i < value.type().dyn_cast().size(); i++) { PADDLE_ENFORCE(value.type() @@ -409,9 +344,7 @@ void HandleForSpecialOp(pir::Operation* op, auto value = op->result(0); value_exe_info->Add(value, fetch_var_name); - } - - if (op_name == "pd_op.feed" || op_name == "pd_op.data") { + } else if (op_name == "pd_op.feed" || op_name == "pd_op.data") { VLOG(6) << "Handle for" << op_name; auto value = op->result(0); VLOG(6) << "link feed output to feed in variable" @@ -425,9 +358,7 @@ void HandleForSpecialOp(pir::Operation* op, "The variable %s shoud exist", name)); value_exe_info->Add(value, name); - } - - if (op_name == "builtin.combine") { + } else if (op_name == "builtin.combine") { auto out_value = op->result(0); Variable* var = nullptr; @@ -451,9 +382,7 @@ void HandleForSpecialOp(pir::Operation* op, tensor_array->emplace_back( value_exe_info->GetScope()->FindVar(value_2_var_name.at(value))); } - } - - if (op_name == "builtin.set_parameter") { + } else if (op_name == "builtin.set_parameter") { VLOG(6) << "Handle for builtin.set_parameter:"; auto param_name = op->attributes() .at("parameter_name") @@ -478,8 +407,7 @@ void HandleForSpecialOp(pir::Operation* op, } value_exe_info->Rename(value, param_name, orig_name); - } - if (op_name.compare(pir::ShadowOutputOp::name()) == 0) { + } else if (op_name == "builtin.shadow_output") { VLOG(6) << "Handle for builtin.shadow_ouptut"; auto var_name = op->attributes() .at("output_name") @@ -490,15 +418,15 @@ void HandleForSpecialOp(pir::Operation* op, // change opreand name to param_name auto orig_name = value_exe_info->GetValue2VarName().at(value); - if (value_exe_info->GetScope()->FindVar(var_name) == nullptr) { - const_cast(value_exe_info->GetScope()) - ->Rename(orig_name, var_name); + if (value_exe_info->GetScope()->FindVar(var_name) != nullptr) { + const_cast(value_exe_info->GetScope())->EraseVars({var_name}); + VLOG(1) << "var " << var_name << " has been removed from scope"; } + const_cast(value_exe_info->GetScope())->Rename(orig_name, var_name); + VLOG(8) << "var " << orig_name << " has been renamed to " << var_name; value_exe_info->Rename(value, var_name, orig_name); - } - - if (op_name == "builtin.get_parameter") { + } else if (op_name == "builtin.get_parameter") { VLOG(6) << "Handle for builtin.get_parameter:"; auto param_name = op->attributes() .at("parameter_name") @@ -507,9 +435,7 @@ void HandleForSpecialOp(pir::Operation* op, auto value = op->result(0); value_exe_info->Add(value, param_name); - } - - if (op_name == "builtin.slice") { + } else if (op_name == "builtin.slice") { VLOG(6) << "Handle for builtin.slice"; auto out_value = op->result(0); auto in_value = op->operand_source(0); @@ -534,9 +460,7 @@ void HandleForSpecialOp(pir::Operation* op, std::string var_name = value_exe_info->GetVar2VarName().at(variable_array[index]); value_exe_info->AddValue2VarName(out_value, var_name); - } - - if (op_name == "builtin.split") { + } else if (op_name == "builtin.split") { VLOG(6) << "Handle for builtin.split"; auto in_value = op->operand_source(0); PADDLE_ENFORCE_EQ(value_exe_info->GetValue2VarName().count(in_value), @@ -560,17 +484,13 @@ void HandleForSpecialOp(pir::Operation* op, value_exe_info->GetVar2VarName().at(variable_array[idx]); value_exe_info->AddValue2VarName(out_value, var_name); } - } - - if (op_name == "pd_op.if") { + } else if (op_name == "pd_op.if") { auto if_op = op->dyn_cast(); for (size_t i = 0; i < if_op->num_results(); ++i) { auto if_op_out_value = if_op->result(i); BuildValue(if_op_out_value, var_name_prefix, value_exe_info); } - } - - if (op_name == "pd_op.while") { + } else if (op_name == "pd_op.while") { auto while_op = op->dyn_cast(); for (size_t i = 0; i < while_op->num_results(); ++i) { @@ -594,7 +514,8 @@ void HandleForInplaceOp(pir::Operation* op, pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); paddle::dialect::OpYamlInfoParser yaml_parser( op_info.GetInterfaceImpl() - ->get_op_info_()); + ->get_op_info_(), + paddle::dialect::IsLegacyOp(op_name)); for (size_t i = 0; i < op->num_results(); ++i) { pir::Value value = op->result(i); @@ -608,8 +529,7 @@ void HandleForInplaceOp(pir::Operation* op, const std::string& inplace_name = yaml_parser.InplaceName(value_name); pir::Value inplace_value = op->operand_source(yaml_parser.InputName2Id().at(inplace_name)); - std::string var_name = - value_exe_info->GetValue2VarName().at(inplace_value); + std::string var_name = value_exe_info->GetVarName(inplace_value); VLOG(4) << "inplace: " << value_name << " -> " << inplace_name << " (var: " << var_name << ")"; value_exe_info->AddValue2VarName(value, var_name); @@ -618,8 +538,7 @@ void HandleForInplaceOp(pir::Operation* op, pir::Value view_value = op->operand_source(yaml_parser.InputName2Id().at(view_name)); // const std::string& var_name = value_2_var_name->at(view_value); - const std::string& var_name = - value_exe_info->GetValue2VarName().at(view_value); + std::string var_name = value_exe_info->GetVarName(view_value); VLOG(4) << "view: " << value_name << " -> " << view_name << " (var: " << var_name << ")"; value_exe_info->AddValue2VarName(value, var_name); @@ -897,7 +816,6 @@ std::shared_ptr BuildOperatorBase( "pir::vector type")); } } - auto& op_info = OpInfoMap::Instance().Get(fluid_op_name); auto ptr = op_info.Creator()(fluid_op_name, in_name_map, out_name_map, attr_map); diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index e0337313da260..3de1cb4fb2988 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -56,9 +56,11 @@ class ValueExecutionInfo { Scope* GetScope() const { return scope_; } - void Add(::pir::Value value, std::string var_name); + void Add(::pir::Value value, const std::string& var_name); - void Rename(pir::Value value, std::string new_name, std::string orig_name); + void Rename(pir::Value value, + const std::string& new_name, + const std::string& orig_name); int GetIdByName(const std::string& name) const; @@ -79,49 +81,19 @@ class ValueExecutionInfo { void ResetVarList(int id, Variable* var); - /// Check a value exist in the ValueExecutionInfo or any of its ancestors. - bool HasValue(::pir::Value value) const; + bool HasVar(const std::string& var_name) const; - /// Check a value exist in the ValueExecutionInfo. - bool HasLocalValue(::pir::Value value) const; + bool HasValue(::pir::Value value) const; std::string GetVarName(::pir::Value value) const; std::string GetVarName(const Variable* var) const; - std::string GetLocalVarName(::pir::Value value) const; - - std::string GetLocalVarName(const Variable* var) const; - int GetVarId(::pir::Value value) const; int GetVarId(const Variable* var) const; - int GetLocalVarId(::pir::Value value) const; - - int GetLocalVarId(const Variable* var) const; - private: - bool HasValueInternal(::pir::Value value) const; - - bool HasValueLocally(::pir::Value value) const; - - std::string GetVarNameInternal(::pir::Value value) const; - - std::string GetVarNameLocally(::pir::Value value) const; - - std::string GetVarNameInternal(const Variable* var) const; - - std::string GetVarNameLocally(const Variable* var) const; - - int GetVarIdInternal(::pir::Value value) const; - - int GetVarIdLocally(::pir::Value value) const; - - int GetVarIdInternal(const Variable* var) const; - - int GetVarIdLocally(const Variable* var) const; - std::shared_ptr NewChild(Scope* scope); ValueExecutionInfo* parent_{nullptr}; // not owned @@ -409,6 +381,7 @@ void BuildPhiContext(pir::Operation* op, } // EmplaceBackOutputs + VLOG(8) << "ctx->EmplaceBackOutput: "; for (size_t i = 0; i < op->num_results(); ++i) { pir::Value out_ptr = op->result(i); if (!IsInvalid(out_ptr)) { @@ -429,11 +402,15 @@ void BuildPhiContext(pir::Operation* op, ctx->EmplaceBackOutput(OutType(const_cast( &(inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) ->Get())))); + VLOG(8) << "ctx->EmplaceBackOutput DenseTensor: " + << value_exec_info.GetVarName(out_ptr); } else if (out_ptr.type() .isa()) { ctx->EmplaceBackOutput(OutType(const_cast( &(inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) ->Get())))); + VLOG(8) << "ctx->EmplaceBackOutput SelectedRows: " + << value_exec_info.GetVarName(out_ptr); } else if (out_ptr.type().isa()) { OutListType outputs; auto& variable_array = @@ -453,6 +430,8 @@ void BuildPhiContext(pir::Operation* op, variable_array[i]->Type())); } } + VLOG(8) << "ctx->EmplaceBackOutput VariableRefArray: " + << value_exec_info.GetVarName(out_ptr); ctx->EmplaceBackOutputs(outputs); } else { PADDLE_THROW( diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc similarity index 82% rename from paddle/fluid/framework/new_executor/new_ir_interpreter.cc rename to paddle/fluid/framework/new_executor/pir_interpreter.cc index 50af034414d6f..730193cbbf438 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" +#include #include #include "paddle/utils/flags.h" @@ -54,19 +55,25 @@ #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/core/builtin_attribute.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); +#endif -PHI_DECLARE_bool(enable_new_ir_in_executor); -PHI_DECLARE_bool(enable_new_ir_in_executor_trace_run); +PHI_DECLARE_bool(enable_pir_in_executor); +PHI_DECLARE_bool(enable_pir_in_executor_trace_run); namespace paddle { namespace framework { -NewIRInterpreter::NewIRInterpreter( - const platform::Place& place, - const std::vector& fetch_var_names, - const ::pir::Block* ir_block, - framework::Scope* scope, - const ExecutionConfig& execution_config) +PirInterpreter::PirInterpreter(const platform::Place& place, + const std::vector& fetch_var_names, + const ::pir::Block* ir_block, + framework::Scope* scope, + const ExecutionConfig& execution_config) : place_(place), execution_config_(execution_config), var_scope_(scope), @@ -74,7 +81,7 @@ NewIRInterpreter::NewIRInterpreter( ir_block_(ir_block), ir_stream_analyzer_(place), fetch_var_names_(fetch_var_names) { - VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; + VLOG(4) << "PirInterpreter(): " << this << " on " << place_; static_build_ = FLAGS_new_executor_static_build && !FLAGS_new_executor_use_cuda_graph && @@ -118,11 +125,12 @@ NewIRInterpreter::NewIRInterpreter( value_exe_info_ = std::make_shared(InnerScope()); std::stringstream ss; - ss << this; + ss << this + << std::chrono::high_resolution_clock::now().time_since_epoch().count(); BuildScope(*ir_block_, ss.str(), value_exe_info_.get()); } -NewIRInterpreter::NewIRInterpreter( +PirInterpreter::PirInterpreter( const platform::Place& place, const std::vector& fetch_var_names, const ::pir::Block* ir_block, @@ -136,7 +144,7 @@ NewIRInterpreter::NewIRInterpreter( ir_block_(ir_block), ir_stream_analyzer_(place), fetch_var_names_(fetch_var_names) { - VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; + VLOG(4) << "PirInterpreter(): " << this << " on " << place_; static_build_ = FLAGS_new_executor_static_build && !FLAGS_new_executor_use_cuda_graph && @@ -184,11 +192,11 @@ NewIRInterpreter::NewIRInterpreter( BuildScope(*ir_block_, ss.str(), value_exe_info_.get()); } -NewIRInterpreter::~NewIRInterpreter() { +PirInterpreter::~PirInterpreter() { // cancle gc's thread gc_.reset(nullptr); async_work_queue_.reset(); - VLOG(4) << "~NewIRInterpreter(): " << this << " on " << place_; + VLOG(4) << "~PirInterpreter(): " << this << " on " << place_; #ifdef PADDLE_WITH_DNNL // Clear mkl-dnn cache, @@ -197,13 +205,12 @@ NewIRInterpreter::~NewIRInterpreter() { #endif } -void NewIRInterpreter::SetCopyProgram(std::shared_ptr prog) { +void PirInterpreter::SetCopyProgram(std::shared_ptr prog) { PADDLE_THROW(platform::errors::Unimplemented( - "SetCopyProgram is not implemented in NewIRInterpreter.")); + "SetCopyProgram is not implemented in PirInterpreter.")); } -void NewIRInterpreter::SetSkipGcVars( - const std::set& skip_gc_vars) { +void PirInterpreter::SetSkipGcVars(const std::set& skip_gc_vars) { PADDLE_ENFORCE_EQ( execution_config_.skip_gc_vars.empty(), true, @@ -214,7 +221,7 @@ void NewIRInterpreter::SetSkipGcVars( execution_config_.skip_gc_vars = skip_gc_vars; } -void NewIRInterpreter::SetJitInputVars( +void PirInterpreter::SetJitInputVars( const std::set& jit_input_vars) { PADDLE_ENFORCE_EQ( execution_config_.jit_input_vars.empty(), @@ -226,15 +233,15 @@ void NewIRInterpreter::SetJitInputVars( execution_config_.jit_input_vars = jit_input_vars; } -const std::set& NewIRInterpreter::JitInputVars() const { +const std::set& PirInterpreter::JitInputVars() const { return execution_config_.jit_input_vars; } -const VariableScope* NewIRInterpreter::GetVariableScope() const { +const VariableScope* PirInterpreter::GetVariableScope() const { return &var_scope_; } -void NewIRInterpreter::reset_scope(Scope* new_scope) { +void PirInterpreter::reset_scope(Scope* new_scope) { var_scope_.SetScope(new_scope); scope_ = new_scope; for (size_t i = 0; i < value_exe_info_->GetVarList().size(); i++) { @@ -244,7 +251,7 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { } // The index should be assured valid, cause the InterpreterCore may not be // fully built, but was still cached and used. For example, see unit test - // `test_assert.py`, it may exit before `NewIRInterpreter::Convert`, + // `test_assert.py`, it may exit before `PirInterpreter::Convert`, // but still was cached and used by later tests. for (size_t i = 0; i < std::min(refs_.size(), value_exe_info_->GetVarList().size()); @@ -253,49 +260,49 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { } } -const Scope* NewIRInterpreter::local_scope() const { return local_scope_; } +const Scope* PirInterpreter::local_scope() const { return local_scope_; } -void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { - async_work_queue_ = reinterpret_cast(src)->GetWorkQueue(); +void PirInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { + async_work_queue_ = reinterpret_cast(src)->GetWorkQueue(); VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src << ") to InterpreterCore(" << this << ")"; } -void NewIRInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { - const NewIRInterpreter& impl = dynamic_cast(src); +void PirInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { + const PirInterpreter& impl = dynamic_cast(src); if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) { return; } // share op dependency - ir_dependency_builder_.ShareDependencyFrom(impl.GetNewIrDependencyBuilder()); + ir_dependency_builder_.ShareDependencyFrom(impl.GetPirDependencyBuilder()); dependecy_count_ = impl.GetDependencyCount(); // share event analysis - ir_stream_analyzer_.ShareEventInfoFrom(impl.GetNewIrStreamAnalyzer()); + ir_stream_analyzer_.ShareEventInfoFrom(impl.GetPirStreamAnalyzer()); is_shared_results_build_ = true; VLOG(8) << "Share Build Results from InterpreterCore(" << &impl << ") to InterpreterCore(" << this << ")"; } -const interpreter::NewIrDependencyBuilder& -NewIRInterpreter::GetNewIrDependencyBuilder() const { +const interpreter::PirDependencyBuilder& +PirInterpreter::GetPirDependencyBuilder() const { return ir_dependency_builder_; } -std::shared_ptr> NewIRInterpreter::GetDependencyCount() +std::shared_ptr> PirInterpreter::GetDependencyCount() const { return dependecy_count_; } -const interpreter::NewIrStreamAnalyzer& -NewIRInterpreter::GetNewIrStreamAnalyzer() const { +const interpreter::PirStreamAnalyzer& PirInterpreter::GetPirStreamAnalyzer() + const { return ir_stream_analyzer_; } -bool NewIRInterpreter::IsSharedResultsBuild() const { +bool PirInterpreter::IsSharedResultsBuild() const { return is_shared_results_build_; } -std::shared_ptr NewIRInterpreter::GetWorkQueue() { +std::shared_ptr PirInterpreter::GetWorkQueue() { if (async_work_queue_ == nullptr) { async_work_queue_ = std::make_shared( execution_config_.host_num_threads, @@ -305,7 +312,7 @@ std::shared_ptr NewIRInterpreter::GetWorkQueue() { return async_work_queue_; } -void NewIRInterpreter::PrepareForCUDAGraphCapture() { +void PirInterpreter::PrepareForCUDAGraphCapture() { if (!FLAGS_new_executor_use_cuda_graph) return; #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_EQ( @@ -330,7 +337,7 @@ void NewIRInterpreter::PrepareForCUDAGraphCapture() { #endif } -void NewIRInterpreter::CheckCUDAGraphBeforeRun( +void PirInterpreter::CheckCUDAGraphBeforeRun( const std::vector& feed_names) { #ifdef PADDLE_WITH_CUDA if (platform::IsCUDAGraphCapturing()) { @@ -354,7 +361,7 @@ void NewIRInterpreter::CheckCUDAGraphBeforeRun( #endif } -void NewIRInterpreter::ClearLoDTensorArrayInLocalScope() { +void PirInterpreter::ClearLoDTensorArrayInLocalScope() { auto vars = local_scope_->LocalVars(); for (auto var : vars) { if (var->IsType()) { @@ -364,7 +371,7 @@ void NewIRInterpreter::ClearLoDTensorArrayInLocalScope() { } } -std::string NewIRInterpreter::GetDepsString() const { +std::string PirInterpreter::GetDepsString() const { std::stringstream ss; auto downstream_map = ir_dependency_builder_.OpDownstreamMap(); ss << "Note: when static_dep is 1, it is ok that the dynamic_dep will not " @@ -383,17 +390,17 @@ std::string NewIRInterpreter::GetDepsString() const { return ss.str(); } -bool NewIRInterpreter::HasLocalScope() const { return local_scope_ != nullptr; } +bool PirInterpreter::HasLocalScope() const { return local_scope_ != nullptr; } -Scope* NewIRInterpreter::InnerScope() const { +Scope* PirInterpreter::InnerScope() const { return local_scope_ != nullptr ? local_scope_ : scope_; } -std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const { - return value_exe_info_->GetValue2VarName().at(value); +std::string PirInterpreter::GetNameByValue(::pir::Value value) const { + return value_exe_info_->GetVarName(value); } -void NewIRInterpreter::UpdateSyncOpNum() { +void PirInterpreter::UpdateSyncOpNum() { int64_t sync_op_num = 0; for (auto& ins : vec_instruction_base_) { if (ins->KernelType() == OpFuncType::kCpuSync || @@ -405,7 +412,7 @@ void NewIRInterpreter::UpdateSyncOpNum() { VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_; } -void NewIRInterpreter::UpdateNcclOpNum() { +void PirInterpreter::UpdateNcclOpNum() { static std::set nccl_op_set = { "pd_op.c_softmax_with_cross_entropy", "pd_op.c_allgather", @@ -418,7 +425,6 @@ void NewIRInterpreter::UpdateNcclOpNum() { "pd_op.c_reduce_prod", "pd_op.c_reducescatter", "pd_op.c_broadcast", - "pd_op.c_broadcast_", "pd_op.c_scatter", "pd_op.partial_send", "pd_op.partial_recv", @@ -432,7 +438,6 @@ void NewIRInterpreter::UpdateNcclOpNum() { "pd_op.distributed_fused_lamb", "pd_op.margin_cross_entropy", "pd_op.sync_batch_norm", - "pd_op.sync_batch_norm_", "pd_op.data_norm", "pd_op.class_center_sample", "pd_op.all_to_all", @@ -467,7 +472,6 @@ void NewIRInterpreter::UpdateNcclOpNum() { "pd_op.global_gather_grad", "pd_op.distributed_fused_lamb_grad", "pd_op.margin_cross_entropy_grad", - "pd_op.margin_cross_entropy_grad_", "pd_op.sync_batch_norm_grad", "pd_op.data_norm_grad", "pd_op.class_center_sample_grad", @@ -479,7 +483,77 @@ void NewIRInterpreter::UpdateNcclOpNum() { "pd_op.p_send_grad", "pd_op.reduce_scatter_grad", "pd_op.all_reduce_grad", - "pd_op.reduce_grad"}; + "pd_op.reduce_grad", + "pd_op.c_softmax_with_cross_entropy_", + "pd_op.c_allgather_", + "pd_op.c_allreduce_max_", + "pd_op.c_allreduce_min_", + "pd_op.c_allreduce_sum_", + "pd_op.c_allreduce_prod_", + "pd_op.c_reduce_max_", + "pd_op.c_reduce_min_", + "pd_op.c_reduce_prod_", + "pd_op.c_reducescatter_", + "pd_op.c_broadcast_", + "pd_op.c_scatter_", + "pd_op.partial_send_", + "pd_op.partial_recv_", + "pd_op.partial_allgather_", + "pd_op.recv_v2_", + "pd_op.send_v2_", + "pd_op.mp_allreduce_sum_", + "pd_op.barrier_", + "pd_op.alltoall_", + "pd_op.global_gather_", + "pd_op.distributed_fused_lamb_", + "pd_op.margin_cross_entropy_", + "pd_op.sync_batch_norm_", + "pd_op.data_norm_", + "pd_op.class_center_sample_", + "pd_op.all_to_all_", + "pd_op.dist_concat_", + "pd_op.all_gather_", + "pd_op.broadcast_", + "pd_op.p_recv_", + "pd_op.p_send_", + "pd_op.reduce_scatter_", + "pd_op.all_reduce_", + "pd_op.reduce_", + "pd_op.c_softmax_with_cross_entropy_grad_", + "pd_op.c_allgather_grad_", + "pd_op.c_allreduce_max_grad_", + "pd_op.c_allreduce_min_grad_", + "pd_op.c_allreduce_sum_grad_", + "pd_op.c_allreduce_prod_grad_", + "pd_op.c_reduce_max_grad_", + "pd_op.c_reduce_min_grad_", + "pd_op.c_reduce_prod_grad_", + "pd_op.c_reducescatter_grad_", + "pd_op.c_broadcast_grad_", + "pd_op.c_scatter_grad_", + "pd_op.partial_send_grad_", + "pd_op.partial_recv_grad_", + "pd_op.partial_allgather_grad_", + "pd_op.recv_v2_grad_", + "pd_op.send_v2_grad_", + "pd_op.mp_allreduce_sum_grad_", + "pd_op.barrier_grad_", + "pd_op.alltoall_grad_", + "pd_op.global_gather_grad_", + "pd_op.distributed_fused_lamb_grad_", + "pd_op.margin_cross_entropy_grad_", + "pd_op.sync_batch_norm_grad_", + "pd_op.data_norm_grad_", + "pd_op.class_center_sample_grad_", + "pd_op.all_to_all_grad_", + "pd_op.dist_concat_grad_", + "pd_op.all_gather_grad_", + "pd_op.broadcast_grad_", + "pd_op.p_recv_grad_", + "pd_op.p_send_grad_", + "pd_op.reduce_scatter_grad_", + "pd_op.all_reduce_grad_", + "pd_op.reduce_grad_"}; int64_t nccl_op_num = 0; for (auto& ins : vec_instruction_base_) { if (nccl_op_set.count(ins->Name())) { @@ -496,7 +570,7 @@ void NewIRInterpreter::UpdateNcclOpNum() { // ->(sync_run)-> OP(B) OP(O) ->(direct_run)-> OP(C) ->(direct_run)-> OP(D) If B // is run before C, B may always block to wait for A to finish executing, but in // fact, C can be executed first during this time. -void NewIRInterpreter::AnalyseExecuteOrderForTrace( +void PirInterpreter::AnalyseExecuteOrderForTrace( std::map> op_downstream_map, InstructionSchedulingPriorityLess compare) { VLOG(4) << "Analyze the execution order of Trace scheduling mode."; @@ -556,7 +630,7 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace( /// For new ir /// /// ======================== /// -void NewIRInterpreter::BuildInstruction() { +void PirInterpreter::BuildInstruction() { VLOG(6) << "Build Instructions for new ir ... "; vec_instruction_base_.clear(); size_t op_idx = 0; @@ -603,8 +677,8 @@ void NewIRInterpreter::BuildInstruction() { } #ifdef PADDLE_WITH_CINN } else if (op->dialect()->name() == "cinn_runtime") { - vec_instruction_base_.emplace_back( - std::make_unique(op_idx++, place_, op, scope_)); + vec_instruction_base_.emplace_back(std::make_unique( + op_idx++, place_, op, *(value_exe_info_.get()))); #endif } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -613,7 +687,7 @@ void NewIRInterpreter::BuildInstruction() { } } -std::string NewIRInterpreter::DebugValueInfo() { +std::string PirInterpreter::DebugValueInfo() { std::stringstream os; os << "value info of interpretercore " << this << "\n" << "value -> var_name -> id -> variable*" @@ -627,7 +701,7 @@ std::string NewIRInterpreter::DebugValueInfo() { PADDLE_ENFORCE((bool)kv.first, platform::errors::PreconditionNotMet( "vlaue(%s) should not be nullptr", kv.second)); - PADDLE_ENFORCE(value_exe_info_->GetVarName2Id().count(kv.second) > 0, + PADDLE_ENFORCE(value_exe_info_->HasVar(kv.second), platform::errors::PreconditionNotMet( "var(%s) should exist in var_name_2_id_", kv.second)); auto* var = InnerScope()->FindVar(kv.second); @@ -636,13 +710,12 @@ std::string NewIRInterpreter::DebugValueInfo() { platform::errors::PreconditionNotMet( "var(%s) should exist in scope (%p)", kv.second, InnerScope())); os << kv.first.impl() << " -> " << kv.second << " -> " - << value_exe_info_->GetVarName2Id().at(kv.second) << " -> " << var - << "\n"; + << value_exe_info_->GetVarId(kv.first) << " -> " << var << "\n"; } return os.str(); } -void NewIRInterpreter::BuildInstructionDependences() { +void PirInterpreter::BuildInstructionDependences() { // analysis the dependences between instructions, add next_instr_list to each // instr, and set the dependecy_count_ size_t instr_num = vec_instruction_base_.size(); @@ -698,7 +771,7 @@ void NewIRInterpreter::BuildInstructionDependences() { } } -void NewIRInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { +void PirInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { // NOTE(zhiqiu): hot fix for jit input var if (instr_node->Name() == "pd_op.memcpy_d2h") { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); @@ -713,7 +786,7 @@ void NewIRInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { } } -void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { +void PirInterpreter::RecordStreamForGC(InstructionBase* instr) { #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) PADDLE_THROW(platform::errors::Unimplemented( "RecordStreamForGC is only implemented when compiled with GPU.")); @@ -731,6 +804,29 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { gpuStream_t stream = reinterpret_cast(instr->DeviceContext()).stream(); +// TODO(lizhiyu): Only analyse the 'send_v2' for GPT pp strategy right now. +// To support all the operators for communicating in the future. +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + if (instr->Name() == "pd_op.send_v2") { + ::pir::Operation* op = instr->Operation(); + if (op->HasAttribute("use_calc_stream") && + op->attribute<::pir::BoolAttribute>("use_calc_stream").data() == + false) { + int ring_id = op->attribute<::pir::Int32Attribute>("ring_id").data(); + if (FLAGS_dynamic_static_unified_comm) { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + stream = static_cast( + comm_context_manager.Get(std::to_string(ring_id))) + ->GetStream(); + } else { + stream = platform::NCCLCommContext::Instance() + .Get(ring_id, instr->DeviceContext().GetPlace()) + ->stream(); + } + } + } +#endif auto TensorRecordStream = [&stream](phi::DenseTensor& tensor) { auto allocation = tensor.Holder(); if (allocation == nullptr) { @@ -828,7 +924,7 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { #endif } -void NewIRInterpreter::CheckGC(InstructionBase* instr) { +void PirInterpreter::CheckGC(InstructionBase* instr) { platform::RecordEvent record( "CheckGC", platform::TracerEventType::UserDefined, 10); @@ -856,7 +952,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) { } } -void NewIRInterpreter::CalculateLastLiveOps() { +void PirInterpreter::CalculateLastLiveOps() { + VLOG(4) << "PirInterpreter(): " << this << " start CalculateLastLiveOps"; // calculate last_live_ops_ for (size_t op_idx = 0; op_idx < vec_instruction_base_.size(); ++op_idx) { InstructionBase* instr = vec_instruction_base_[op_idx].get(); @@ -882,11 +979,16 @@ void NewIRInterpreter::CalculateLastLiveOps() { gc_check_vars.insert(var_id); } } + VLOG(4) << "get gc check vars for: " << instr->Name(); for (auto var_id : gc_check_vars) { Scope* inner_scope = InnerScope(); paddle::framework::Variable* var = inner_scope->FindVar( value_exe_info_->GetNameById(static_cast(var_id))); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound("Var(id=%d) should not be nullptr.", + static_cast(var_id))); if (var->IsType() || var->IsType() || var->IsType() || var->IsType() || @@ -899,6 +1001,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { << framework::ToTypeName(var->Type()); } } + VLOG(4) << "update last_live_ops for: " << instr->Name(); } // clear the last_live_ops list for all vars in skip_gc_vars for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) { @@ -908,7 +1011,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { VLOG(8) << "Skip gc for var: " << skip_gc_var; } } - VLOG(4) << "calculate last_live_ops_"; + VLOG(4) << "clear the last_live_ops list for all vars in skip_gc_vars"; // shrink, find the downstream op that has no other op in the // downstream list happens before it @@ -949,6 +1052,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { last_live_ops_[i] = minumum_last_live_ops; var_ref_count_[i] = static_cast(last_live_ops_[i].size()); } + VLOG(4) << "shrink the last_live_ops list for all vars in skip_gc_vars"; for (auto& dep : *dependecy_count_) { deps_.emplace_back(std::make_shared(dep)); @@ -957,9 +1061,10 @@ void NewIRInterpreter::CalculateLastLiveOps() { refs_.emplace_back(std::make_shared( var_ref_count_[i], value_exe_info_->GetVarList()[i])); } + VLOG(4) << "done CalculateLastLiveOps"; } -void NewIRInterpreter::ConstructEventForJitInput() { +void PirInterpreter::ConstructEventForJitInput() { for (size_t i = 0; i < dependecy_count_->size(); ++i) { if ((*dependecy_count_)[i] == 0) { InstructionBase* inst = vec_instruction_base_[i].get(); @@ -983,9 +1088,10 @@ void NewIRInterpreter::ConstructEventForJitInput() { } } -paddle::framework::FetchList NewIRInterpreter::Run( +paddle::framework::FetchList PirInterpreter::Run( const std::vector& feed_names, - const std::vector& feed_tensors) { + const std::vector& feed_tensors, + bool need_fetch) { auto FeedInput = [&] { VLOG(4) << "Feed inputs"; for (size_t i = 0; i < feed_names.size(); ++i) { @@ -1030,7 +1136,8 @@ paddle::framework::FetchList NewIRInterpreter::Run( VLOG(4) << "Done PreAnalysis"; // Run - if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || + if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 || + execution_config_.used_for_inference || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " @@ -1045,7 +1152,8 @@ paddle::framework::FetchList NewIRInterpreter::Run( is_build_ = true; is_shared_results_build_ = true; } else { - if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || + if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 || + execution_config_.used_for_inference || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { TraceRunImpl(); @@ -1057,41 +1165,24 @@ paddle::framework::FetchList NewIRInterpreter::Run( if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } + // return Fetch Tensors Scope* inner_scope = InnerScope(); - if (FLAGS_enable_new_ir_in_executor) { - framework::FetchList fetch_res; - + framework::FetchList fetch_res; + if (need_fetch) { for (auto& var_name : fetch_var_names_) { auto* var = inner_scope->FindVar(var_name); - VLOG(0) << "fetch " << var_name << "[" << var << "]"; + VLOG(4) << "fetch " << var_name << "[" << var << "]"; fetch_res.push_back(var->Get()); } - - VLOG(4) << "get fetch list size: " << fetch_res.size(); - return fetch_res; - } else { - auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); - if (fetch_var) { - auto fetch_list = - std::move(*fetch_var->GetMutable()); -#ifdef PADDLE_WITH_CUDA - if (platform::IsCUDAGraphCapturing()) { - PADDLE_ENFORCE_EQ(fetch_list.empty(), - true, - platform::errors::InvalidArgument( - "Cannot fetch data when using CUDA Graph.")); - } -#endif - return fetch_list; - } else { - return {}; - } } + + VLOG(4) << "get fetch list size: " << fetch_res.size(); + return fetch_res; } -FetchList NewIRInterpreter::Run(const std::vector& feed_names, - bool need_fetch) { +FetchList PirInterpreter::Run(const std::vector& feed_names, + bool need_fetch) { SetDeviceId(place_); CheckCUDAGraphBeforeRun(feed_names); @@ -1119,7 +1210,8 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, VLOG(4) << "Done PreAnalysis"; // Run - if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || + if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 || + execution_config_.used_for_inference || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " @@ -1134,7 +1226,8 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, is_build_ = true; is_shared_results_build_ = true; } else { - if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || + if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 || + execution_config_.used_for_inference || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { TraceRunImpl(); @@ -1146,41 +1239,24 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } - // return Fetch Tensors - Scope* inner_scope = InnerScope(); - if (FLAGS_enable_new_ir_in_executor) { - framework::FetchList fetch_res; - - if (need_fetch) { - for (auto& var_name : fetch_var_names_) { - auto* var = inner_scope->FindVar(var_name); - VLOG(0) << "fetch " << var_name << "[" << var << "]"; - fetch_res.push_back(var->Get()); - } + + framework::FetchList fetch_res; + if (need_fetch) { + // return Fetch Tensors + Scope* inner_scope = InnerScope(); + + for (auto& var_name : fetch_var_names_) { + auto* var = inner_scope->FindVar(var_name); + VLOG(4) << "fetch " << var_name << "[" << var << "]"; + fetch_res.push_back(var->Get()); } + VLOG(4) << "get fetch list size: " << fetch_res.size(); - return fetch_res; - } else { - auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); - if (fetch_var && need_fetch) { - auto fetch_list = - std::move(*fetch_var->GetMutable()); -#ifdef PADDLE_WITH_CUDA - if (platform::IsCUDAGraphCapturing()) { - PADDLE_ENFORCE_EQ(fetch_list.empty(), - true, - platform::errors::InvalidArgument( - "Cannot fetch data when using CUDA Graph.")); - } -#endif - return fetch_list; - } else { - return {}; - } } + return fetch_res; } -void NewIRInterpreter::TraceRunImpl() { +void PirInterpreter::TraceRunImpl() { // lazy initialization of gc, do not create gc is the program only run once if (!gc_) { gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_base_); @@ -1193,7 +1269,7 @@ void NewIRInterpreter::TraceRunImpl() { VLOG(4) << "Done TraceRunInstructionList"; } -void NewIRInterpreter::MultiThreadRunImpl() { +void PirInterpreter::MultiThreadRunImpl() { // lazy initialization of gc, do not create gc is the program only run once if (!gc_) { gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_base_); @@ -1207,7 +1283,7 @@ void NewIRInterpreter::MultiThreadRunImpl() { VLOG(4) << "Done MultiThreadRunInstructionList"; } -void NewIRInterpreter::TraceRunInstructionList( +void PirInterpreter::TraceRunInstructionList( const std::vector>& vec_instr) { unfinished_op_number_ = vec_instr.size(); if (unfinished_op_number_ == 0) { @@ -1251,7 +1327,7 @@ void NewIRInterpreter::TraceRunInstructionList( VLOG(4) << "Done TraceRunInstructionList"; } -void NewIRInterpreter::MultiThreadRunInstructionList( +void PirInterpreter::MultiThreadRunInstructionList( const std::vector>& vec_instr) { unfinished_op_number_ = vec_instr.size(); if (unfinished_op_number_ == 0) { @@ -1332,7 +1408,7 @@ void NewIRInterpreter::MultiThreadRunInstructionList( } } -void NewIRInterpreter::RunInstructionBaseAsync(size_t instr_id) { +void PirInterpreter::RunInstructionBaseAsync(size_t instr_id) { // NOTE(Ruibiao): Due to the uncertain order in multi-threading asynchronous // scheduling, the priority order involved cross-thread scheduling is not // guaranteed. Only Ops scheduled by the same AddTask call have the guarantee @@ -1366,8 +1442,8 @@ void NewIRInterpreter::RunInstructionBaseAsync(size_t instr_id) { } } -void NewIRInterpreter::RunNextInstructions(InstructionBase* instr, - SchedulingQueue* reserved_next_ops) { +void PirInterpreter::RunNextInstructions(InstructionBase* instr, + SchedulingQueue* reserved_next_ops) { platform::RecordEvent record( "RunNextInstructions", platform::TracerEventType::UserDefined, 10); @@ -1392,14 +1468,15 @@ void NewIRInterpreter::RunNextInstructions(InstructionBase* instr, } } -void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { +void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { platform::RecordEvent instruction_event( instr_node->Name(), platform::TracerEventType::Operator, 1); - SetDeviceId(instr_node->DeviceContext().GetPlace()); + auto cur_place = instr_node->DeviceContext().GetPlace(); + SetDeviceId(cur_place); try { - instr_node->WaitEvent(place_); + instr_node->WaitEvent(cur_place); VLOG(4) << "begin to run op " << instr_node->Name(); VLOG(4) << "begin: " << __func__ << " OP id:" << instr_node->Id() << " name:" << instr_node->Name() << " type:" @@ -1409,9 +1486,9 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { ? "kGpuSync" : "kGpuAsync")) << " runs on " << platform::GetCurrentThreadName(); - VLOG(4) << place_ << " " - << instr_node->DebugStringEx(scope_, - value_exe_info_->GetValue2VarName()); + + VLOG(4) << cur_place << " Before:" + << instr_node->DebugStringEx(scope_, value_exe_info_.get()); if (!instr_node->IsArtificial()) { instr_node->Run(); @@ -1432,15 +1509,15 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { ? "kGpuSync" : "kGpuAsync")) << " runs on " << platform::GetCurrentThreadName(); - VLOG(4) << place_ << " " - << instr_node->DebugStringEx(scope_, - value_exe_info_->GetValue2VarName()); + + VLOG(4) << cur_place << " After:" + << instr_node->DebugStringEx(scope_, value_exe_info_.get()); CheckGC(instr_node); VLOG(4) << "done CheckGC"; - interpreter::LogDeviceMemoryStats(place_); + memory::LogDeviceMemoryStats(cur_place, instr_node->Name()); } VLOG(5) << "after run kernel"; - instr_node->RecordEvent(place_); + instr_node->RecordEvent(cur_place); } catch (platform::EnforceNotMet& ex) { auto* op = instr_node->Operation(); const std::vector op_callstack_attr = @@ -1461,7 +1538,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { } } -void NewIRInterpreter::PreAnalysis() { +void PirInterpreter::PreAnalysis() { BuildInstructionDependences(); VLOG(4) << "Done BuildInstructionDependences"; @@ -1487,14 +1564,14 @@ void NewIRInterpreter::PreAnalysis() { VLOG(4) << "Done UpdateNcclOpNum"; } -void NewIRInterpreter::Build( +void PirInterpreter::Build( const std::vector& feed_names, std::vector* op_func_nodes) { PADDLE_THROW(platform::errors::Unimplemented( - "Build is not implemented in NewIRInterpreter.")); + "Build is not implemented in PirInterpreter.")); } -::pir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { +::pir::Value PirInterpreter::GetValueByName(const std::string& var_name) { for (auto kv : value_exe_info_->GetValue2VarName()) { if (kv.second == var_name) { return kv.first; @@ -1503,7 +1580,7 @@ ::pir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { return nullptr; } -void NewIRInterpreter::SolvePersisableVarNames() { +void PirInterpreter::SolvePersisableVarNames() { VLOG(6) << "SolvePersisableVarNames"; for (auto kv : value_exe_info_->GetValue2VarName()) { ::pir::Value value = kv.first; diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h similarity index 85% rename from paddle/fluid/framework/new_executor/new_ir_interpreter.h rename to paddle/fluid/framework/new_executor/pir_interpreter.h index 3a128791cdfce..e75817f5e9393 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/pir_interpreter.h @@ -25,7 +25,7 @@ class Block; namespace paddle { namespace framework { class ValueExecutionInfo; -class NewIRInterpreter : public InterpreterBaseImpl { +class PirInterpreter : public InterpreterBaseImpl { using ExecutionConfig = interpreter::ExecutionConfig; using InstructionSchedulingPriorityLess = std::function; using SchedulingQueue = @@ -34,24 +34,25 @@ class NewIRInterpreter : public InterpreterBaseImpl { InstructionSchedulingPriorityLess>; public: - NewIRInterpreter(const platform::Place& place, - const std::vector& fetch_var_names, - const ::pir::Block* ir_block, - Scope* scope, - const ExecutionConfig& execution_config = ExecutionConfig()); + PirInterpreter(const platform::Place& place, + const std::vector& fetch_var_names, + const ::pir::Block* ir_block, + Scope* scope, + const ExecutionConfig& execution_config = ExecutionConfig()); - NewIRInterpreter(const platform::Place& place, - const std::vector& fetch_var_names, - const ::pir::Block* ir_block, - Scope* scope, - std::shared_ptr value_exe_info, - const ExecutionConfig& execution_config = ExecutionConfig()); + PirInterpreter(const platform::Place& place, + const std::vector& fetch_var_names, + const ::pir::Block* ir_block, + Scope* scope, + std::shared_ptr value_exe_info, + const ExecutionConfig& execution_config = ExecutionConfig()); - ~NewIRInterpreter(); + ~PirInterpreter(); paddle::framework::FetchList Run( const std::vector& feed_names, - const std::vector& feed_tensors) override; + const std::vector& feed_tensors, + bool need_fetch = true) override; paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true) override; @@ -203,9 +204,9 @@ class NewIRInterpreter : public InterpreterBaseImpl { void SolvePersisableVarNames(); - const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder() const; + const interpreter::PirDependencyBuilder& GetPirDependencyBuilder() const; - const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer() const; + const interpreter::PirStreamAnalyzer& GetPirStreamAnalyzer() const; InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; @@ -218,9 +219,9 @@ class NewIRInterpreter : public InterpreterBaseImpl { std::vector var_ref_count_; - interpreter::NewIrDependencyBuilder ir_dependency_builder_; + interpreter::PirDependencyBuilder ir_dependency_builder_; - interpreter::NewIrStreamAnalyzer ir_stream_analyzer_; + interpreter::PirStreamAnalyzer ir_stream_analyzer_; std::vector fetch_var_names_; diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index f1646f50471a4..2978e1bf81c41 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -40,6 +40,9 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif +PD_DECLARE_bool(enable_host_event_recorder_hook); +PD_DECLARE_bool(log_memory_stats); + namespace paddle { namespace framework { @@ -110,8 +113,9 @@ void ProgramInterpreter::RunImpl() { interpreter::ResetAtomicGuard guard(&deps_, &refs_); - if ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && - (sync_op_num_ == 0)) { + if (execution_config_.used_for_inference || + ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && + (sync_op_num_ == 0))) { VLOG(4) << "Tracing Instruction List"; TraceInstructionList(vec_instruction_); } else { @@ -153,24 +157,26 @@ FetchList ProgramInterpreter::Run(const std::vector& feed_names, ClearLoDTensorArrayInLocalScope(); } - // return Fetch Tensors - Scope* inner_scope = - HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); - auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); - if (fetch_var && need_fetch) { - auto fetch_list = std::move(*fetch_var->GetMutable()); + if (need_fetch) { + // return Fetch Tensors + Scope* inner_scope = + HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); + auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); + if (fetch_var) { + auto fetch_list = + std::move(*fetch_var->GetMutable()); #ifdef PADDLE_WITH_CUDA - if (platform::IsCUDAGraphCapturing()) { - PADDLE_ENFORCE_EQ(fetch_list.empty(), - true, - platform::errors::InvalidArgument( - "Cannot fetch data when using CUDA Graph.")); - } + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ(fetch_list.empty(), + true, + platform::errors::InvalidArgument( + "Cannot fetch data when using CUDA Graph.")); + } #endif - return fetch_list; - } else { - return {}; + return fetch_list; + } } + return {}; } void ProgramInterpreter::Build( @@ -202,7 +208,8 @@ void ProgramInterpreter::Build( FetchList ProgramInterpreter::Run( const std::vector& feed_names, - const std::vector& feed_tensors) { + const std::vector& feed_tensors, + bool need_fetch) { SetDeviceId(place_); CheckCUDAGraphBeforeRun(feed_names); @@ -221,24 +228,27 @@ FetchList ProgramInterpreter::Run( ClearLoDTensorArrayInLocalScope(); } - // return Fetch Tensors - Scope* inner_scope = - HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); - auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); - if (fetch_var) { - auto fetch_list = std::move(*fetch_var->GetMutable()); + if (need_fetch) { + // return Fetch Tensors + Scope* inner_scope = + HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); + auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); + if (fetch_var) { + auto fetch_list = + std::move(*fetch_var->GetMutable()); #ifdef PADDLE_WITH_CUDA - if (platform::IsCUDAGraphCapturing()) { - PADDLE_ENFORCE_EQ(fetch_list.empty(), - true, - platform::errors::InvalidArgument( - "Cannot fetch data when using CUDA Graph.")); - } + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ(fetch_list.empty(), + true, + platform::errors::InvalidArgument( + "Cannot fetch data when using CUDA Graph.")); + } #endif - return fetch_list; - } else { - return {}; + return fetch_list; + } } + + return {}; } void ProgramInterpreter::SetCopyProgram(std::shared_ptr prog) { @@ -399,9 +409,10 @@ void ProgramInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) { // in kernel Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope() : var_scope_.GetMutableScope(); - instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope); + instr_node->ResetContextWithScope( + ins_map, outs_map, *local_scope, instr_node->OpBase()->Type()); } else { - instr_node->ResetContext(ins_map, outs_map); + instr_node->ResetContext(ins_map, outs_map, instr_node->OpBase()->Type()); } } @@ -656,6 +667,10 @@ void ProgramInterpreter::Convert( } #endif vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + vec_instruction_.back().UpdataRecordStreamForGcInfo(); +#endif } BuildOperatorDependences(); @@ -789,7 +804,10 @@ void ProgramInterpreter::Convert( << var_scope_.GetNameById(static_cast(i)) << " : " << item << " " << vec_instruction_[item].OpBase()->Type(); minumum_last_live_ops.insert(item); - vec_instruction_[item].AddGCCheckVar(i); + if (!(var_scope_.VarDesc(static_cast(i)) && + var_scope_.VarDesc(static_cast(i))->Persistable())) { + vec_instruction_[item].AddGCCheckVar(i); + } } } last_live_ops_[i] = minumum_last_live_ops; @@ -857,6 +875,10 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) { : var_scope_.GetMutableScope(); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope); + if (op->Type() == "while") { + op->SetOutputHooks(hookfuncs_); + } + auto op_with_kernel = dynamic_cast(op); { // If it is OperatorBase, InferShape do nothing. @@ -870,15 +892,21 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) { // see OperatorWithKernel::RunImpl in operator.cc for why if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) && op_with_kernel->Attr(kAllKernelsMustComputeRuntimeShape))) { - op_with_kernel->Info().infer_shape_( - instr_node.InnerInferShapeContext().get()); + if (instr_node.can_use_infermeta_ctx_) { + op_with_kernel->Info().infer_meta_(const_cast( + instr_node.InnerCompatInferMetaContext())); + } else { + op_with_kernel->Info().infer_shape_( + instr_node.InnerInferShapeContext().get()); + } + } + if (FLAGS_enable_host_event_recorder_hook) { + platform::RecordOpInfoSupplement(op->Type(), + op->Attrs(), + *(instr_node.InnerInferShapeContext()), + *(instr_node.InnerRuntimeContext()), + op->Id()); } - infershape_event.End(); - platform::RecordOpInfoSupplement(op->Type(), - op->Attrs(), - *(instr_node.InnerInferShapeContext()), - *(instr_node.InnerRuntimeContext()), - op->Id()); } } if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) { @@ -1016,7 +1044,9 @@ void ProgramInterpreter::RunInstruction(const Instruction& instr_node) { if (!instr_node.IsArtificial()) { RunOperator(instr_node); CheckGC(instr_node); - interpreter::LogDeviceMemoryStats(place_); + if (FLAGS_log_memory_stats) { + memory::LogDeviceMemoryStats(place_, instr_node.OpBase()->Type()); + } } instr_node.RecordEvent(place_); @@ -1200,42 +1230,11 @@ void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) { PADDLE_THROW(platform::errors::Unimplemented( "RecordStreamForGC is only implemented when compiled with GPU.")); #else - if (!IsInterpretercoreFastGCEnabled() || - instr.KernelType() != OpFuncType::kGpuAsync) { - return; - } - - if (instr.DeviceContext().GetPlace().GetType() == - phi::AllocationType::CUSTOM) { - return; - } - platform::RecordEvent record( "RecordStreamForGC", platform::TracerEventType::UserDefined, 10); - gpuStream_t stream = - reinterpret_cast(instr.DeviceContext()).stream(); -// TODO(lizhiyu): Only analyse the 'send_v2' for GPT pp strategy right now. -// To support all the operators for communicating in the future. -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto operator_base_ptr = instr.OpBase(); - if ((operator_base_ptr->Type() == "send_v2") && - (operator_base_ptr->Attr("use_calc_stream") == false)) { - int ring_id = operator_base_ptr->Attr("ring_id"); - if (FLAGS_dynamic_static_unified_comm) { - const auto& comm_context_manager = - phi::distributed::CommContextManager::GetInstance(); - stream = static_cast( - comm_context_manager.Get(std::to_string(ring_id))) - ->GetStream(); - } else { - stream = platform::NCCLCommContext::Instance() - .Get(ring_id, instr.DeviceContext().GetPlace()) - ->stream(); - } - } -#endif - auto TensorRecordStream = [&stream](phi::DenseTensor& tensor) { + auto TensorRecordStream = [](phi::DenseTensor& tensor, + const gpuStream_t& stream) { auto allocation = tensor.Holder(); if (allocation == nullptr) { return; @@ -1283,19 +1282,13 @@ void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) { VLOG(4) << "GC sync " << var_scope_.GetNameById(var_id) << " " << var_scope_.VarDesc(var_id); - // persistable var will be ignore while GC - if (var_scope_.VarDesc(var_id) && - var_scope_.VarDesc(var_id)->Persistable()) { - continue; - } - paddle::framework::Variable* var = var_scope_.VarRef(var_id); if (var == nullptr) { continue; } if (var->IsType()) { - TensorRecordStream(*(var->GetMutable())); + TensorRecordStream(*(var->GetMutable()), instr.stream_); } else if ( var->IsType< operators::reader:: @@ -1303,24 +1296,30 @@ void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) { // do nothing } else if (var->IsType()) { TensorRecordStream( - *(var->GetMutable()->mutable_value())); + *(var->GetMutable()->mutable_value()), + instr.stream_); } else if (var->IsType()) { auto* tensor_arr = var->GetMutable(); for (auto& tensor : *tensor_arr) { - TensorRecordStream(tensor); + TensorRecordStream(tensor, instr.stream_); } } else if (var->IsType()) { TensorRecordStream( - *(var->GetMutable()->mutable_indices())); + *(var->GetMutable()->mutable_indices()), + instr.stream_); TensorRecordStream( - *(var->GetMutable()->mutable_values())); + *(var->GetMutable()->mutable_values()), + instr.stream_); } else if (var->IsType()) { TensorRecordStream( - *(var->GetMutable()->mutable_cols())); + *(var->GetMutable()->mutable_cols()), + instr.stream_); TensorRecordStream( - *(var->GetMutable()->mutable_crows())); + *(var->GetMutable()->mutable_crows()), + instr.stream_); TensorRecordStream( - *(var->GetMutable()->mutable_values())); + *(var->GetMutable()->mutable_values()), + instr.stream_); } else if (var->IsType>()) { // do nothing } else { @@ -1336,7 +1335,9 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) { platform::RecordEvent record( "CheckGC", platform::TracerEventType::UserDefined, 10); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - RecordStreamForGC(instr); + if (instr.need_record_stream_for_gc_) { + RecordStreamForGC(instr); + } #endif auto& var_scope = var_scope_; @@ -1344,13 +1345,6 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) { VLOG(4) << "GC:" << var_scope_.GetNameById(static_cast(var_id)) << ", id:" << var_id << ", ref:" << refs_[var_id]->DynamicRef(); bool is_ready = refs_[var_id]->CheckAndDecrease(); - // ignore all persistable var while GC - if (var_scope.VarDesc(static_cast(var_id)) && - var_scope.VarDesc(static_cast(var_id))->Persistable()) { - VLOG(4) << "Skip persistable var: " - << var_scope_.GetNameById(static_cast(var_id)); - continue; - } if (is_ready) { VLOG(6) << "Async delete variable with name : " << var_scope.GetNameById(static_cast(var_id)); @@ -1446,7 +1440,7 @@ bool ProgramInterpreter::HasLocalScope() const { // miss. When a model is all KQueueAsync type OPs, all OPs will be distributed // to the DeviceThread for execution, and the multithreading scheduling will not // have any benefits. Therefore, in the dynamic to static, when the number of -// KQueueAsync Ops is 0, we choose Trace mode. +// KQueueSync Ops is 0, we choose Trace mode. void ProgramInterpreter::TraceInstructionList( const std::vector& vec_instr) { unfinished_op_number_ = vec_instr.size(); diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h index bef6385c211fb..9c4b8f9bf1c9b 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.h +++ b/paddle/fluid/framework/new_executor/program_interpreter.h @@ -43,7 +43,8 @@ class ProgramInterpreter : public InterpreterBaseImpl { paddle::framework::FetchList Run( const std::vector& feed_names, - const std::vector& feed_tensors) override; + const std::vector& feed_tensors, + bool need_fetch = true) override; paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true) override; diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index f06bee2c884e3..aa97dab1fefe5 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -27,9 +27,9 @@ #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" -PHI_DECLARE_bool(enable_new_ir_in_executor); +PHI_DECLARE_bool(enable_pir_in_executor); PHI_DECLARE_bool(enable_pir_api); -PHI_DECLARE_bool(new_ir_apply_inplace_pass); +PHI_DECLARE_bool(pir_apply_inplace_pass); namespace paddle { namespace framework { @@ -51,11 +51,12 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, VLOG(6) << ss.str(); const auto& jobs = plan_.JobList(); - for (const auto& job : jobs) { + for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) { + const auto& job = jobs[job_idx]; const std::string& job_type = job->Type(); std::shared_ptr program = nullptr; std::shared_ptr<::pir::Program> ir_program = nullptr; - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { ir_program = plan_.IrProgram(job_type); } else { program = std::make_shared(*(plan_.Program(job_type))); @@ -69,7 +70,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, micro_batch_id, micro_batch_num)); - if (micro_batch_num > 1 && !FLAGS_enable_pir_api) { + if (!FLAGS_enable_pir_api && !FLAGS_enable_pir_in_executor) { SetColAttrForFeedFetchOps(program, micro_batch_num, micro_batch_id); } @@ -78,12 +79,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, execution_config.skip_gc_vars = job->SkipGcVars(); // TODO(phlrain) we only support cpu for now - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { std::shared_ptr<::pir::Program> base_program = ir_program; - if (!FLAGS_enable_pir_api) { - VLOG(6) << "begin to translate" << std::endl; - base_program = paddle::TranslateLegacyProgramToProgram(*program); - } auto block = base_program->block(); for (auto it = block->begin(); it != block->end(); ++it) { if ((*it)->isa()) { @@ -103,14 +100,15 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, .dyn_cast() .AsString() + "@fetch"; + job->SetFetchVarName(fetch_var_names_[index]); } } auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place); std::shared_ptr shared_program = std::move(kernel_program); - plan_.UpdateIrProgram("base", shared_program); + plan_.SetIrProgram("job_" + std::to_string(job_idx), shared_program); - if (FLAGS_new_ir_apply_inplace_pass) { + if (FLAGS_pir_apply_inplace_pass) { pir::PassManager pm(pir::IrContext::Instance(), 3); pm.AddPass(pir::CreateInplacePass()); pm.Run(shared_program.get()); @@ -118,9 +116,9 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, interpretercores_.emplace_back( std::make_shared(place_, - fetch_var_names_, + job->FetchVarNames(), shared_program->block(), - scope_, + micro_batch_scopes_[micro_batch_id], execution_config)); } else { interpretercores_.emplace_back( @@ -179,6 +177,12 @@ paddle::framework::FetchList StandaloneExecutor::Run( is_interpretercore_build_result_shared_ = true; } + std::vector> splited_feeds; + if (FLAGS_enable_pir_in_executor) { + SplitFeedTensors(feed_names, plan_.MicroBatchNum(), scope_, &splited_feeds); + } + + fetch_list_.resize(plan_.MicroBatchNum()); for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) { const auto& job = jobs[job_idx]; const std::string& job_type = job->Type(); @@ -192,27 +196,36 @@ paddle::framework::FetchList StandaloneExecutor::Run( // Note(sonder): Share build results don't work for new IR now. if (type_to_first_id.count(job_type) != 0 && - !FLAGS_enable_new_ir_in_executor) { + !FLAGS_enable_pir_in_executor) { interpretercores_[job_idx]->ShareBuildResultsFrom( interpretercores_[type_to_first_id[job_type]]); } - // TODO(zhaoyinglia): use a more general method - if (jobs.size() > 1 && job_type != "forward") { - const std::vector tmp_feed_names = {}; - interpretercores_[job_idx]->Run(tmp_feed_names, /*need_fetch = */ false); + + if (FLAGS_enable_pir_in_executor) { + interpretercores_[job_idx]->Run(feed_names, + splited_feeds[job->MicroBatchId()], + /*need_fetch = */ false); + + FetchTensors(job->FetchVarNames(), + fetch_var_names_, + job->MicroBatchId(), + micro_batch_scopes_[job->MicroBatchId()], + &fetch_list_); } else { - interpretercores_[job_idx]->Run(feed_names, /*need_fetch = */ false); + if (jobs.size() > 1 && job_type != "forward") { + const std::vector tmp_feed_names = {}; + interpretercores_[job_idx]->Run(tmp_feed_names, + /*need_fetch = */ false); + } else { + interpretercores_[job_idx]->Run(feed_names, /*need_fetch = */ false); + } } } // return Fetch Tensors - if (FLAGS_enable_new_ir_in_executor) { + if (FLAGS_enable_pir_in_executor) { framework::FetchList fetch_res; - for (auto& var_name : fetch_var_names_) { - auto* var = scope_->FindVar(var_name); - fetch_res.push_back(var->Get()); - } - + MergeFetchTensors(fetch_list_, plan_.MicroBatchNum(), &fetch_res); return fetch_res; } else { auto* fetch_var = scope_->FindVar(interpreter::kFetchVarName); diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index cb10648855181..8feef6e5b2f91 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -45,13 +45,13 @@ class StandaloneExecutor { bool is_interpretercore_build_result_shared_{false}; const platform::Place place_; interpreter::Plan plan_; - - std::vector micro_batch_scopes_; std::vector> interpretercores_; Scope* scope_; + std::vector micro_batch_scopes_; std::vector fetch_var_names_; + FetchUnmergedList fetch_list_; std::vector>> vec_force_events_to_wait_; diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index e1bc5be8c64f9..869b30ff1c754 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" #include "paddle/utils/flat_hash_map.h" - +#include "paddle/utils/test_macros.h" namespace paddle { namespace framework { @@ -48,6 +48,7 @@ class OpInfo { OpAttrChecker* checker_{nullptr}; InferVarTypeFN infer_var_type_; InferShapeFN infer_shape_; + InferMetaFN infer_meta_; InferInplaceOpFN infer_inplace_; InferNoNeedBufferVarsFN infer_no_need_buffer_vars_; DygraphGradOpMakerFN dygraph_grad_op_maker_; @@ -128,7 +129,7 @@ class OpInfo { } }; -class OpInfoMap { +class TEST_API OpInfoMap { public: static OpInfoMap& Instance(); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 17d5f6c4f356a..3484c5cc05940 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -607,6 +607,28 @@ RuntimeInferShapeContext::GetPhiDefaultKernelSignature() const { void RuntimeInferShapeContext::SetSkipLoD(bool skip) { can_skip_lod_ = skip; } +bool RuntimeInferShapeContext::HasRuntimeAttributes() const { + bool is_runtime = false; + if (phi::DefaultKernelSignatureMap::Instance().Has(op_.Type())) { + auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( + GetPhiDefaultKernelSignature()->name); + if (!phi_kernels.empty()) { + const auto& args_def = phi_kernels.cbegin()->second.args_def(); + const auto& attr_defs = args_def.attribute_defs(); + for (size_t i = 0; i < attr_defs.size(); ++i) { + if (attr_defs[i].type_index == phi::AttributeType::SCALAR || + attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { + is_runtime = true; + break; + } + } + } + } else { + is_runtime = true; + } + return is_runtime; +} + std::vector RuntimeInferShapeContext::GetOutputsLod( const std::string& out) const { auto out_it = ctx_.outputs.find(out); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 421f8b2c2e772..535992085451e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -233,6 +233,8 @@ class RuntimeInferShapeContext : public InferShapeContext { std::vector GetOutputsDim(const std::string& name) const; + bool HasRuntimeAttributes() const; + protected: DDim GetDim(Variable* var) const; diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index 67153a7001ece..d1eb5558c5454 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/init_default_kernel_signature_map.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" @@ -60,8 +61,6 @@ class KernelArgsNameMaker { virtual const paddle::small_vector& GetAttrsArgsNames() = 0; }; -TEST_API void InitDefaultKernelSignatureMap(); - // TODO(Wilber): support others device context. template struct ConvertToPhiContext { diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 4ad1bcb80c4bc..98cdeac691fca 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/imperative/type_defs.h" #include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/value.h" #include "paddle/utils/blank.h" @@ -102,6 +103,7 @@ using InferVarTypeFN = std::function; using InferShapeFN = std::function; +using InferMetaFN = std::function; using InplacePair = std::unordered_map; using InferInplaceOpFN = std::function; diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index da39c21e84c03..df964a5139303 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -69,13 +69,8 @@ if(WIN32 AND WITH_GPU) cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API} ${utils_modules}) else() - if(WIN32) - create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} - ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules}) - else() - create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} - ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules}) - endif() + create_static_lib(paddle_inference ${phi_modules} ${fluid_modules} + ${ir_targets} ${STATIC_INFERENCE_API} ${utils_modules}) endif() if(NOT APPLE) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 2e74062bedff6..91bac8e7c0d0d 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -358,7 +358,9 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( // so we must find all the var_name+id. // https://github.com/PaddlePaddle/Paddle/pull/53184 for (auto *n : graph->Nodes()) { - if (n->IsVar() && input_names.count(n->Name())) { + if (n->IsVar() && + find(graph_params.begin(), graph_params.end(), n->Name()) != + graph_params.end()) { input_names_with_id.insert( RenameVarBeUnique(n->Name(), std::to_string(n->id()))); } @@ -586,6 +588,13 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( auto inspector_serialize = Get("inspector_serialize"); auto disable_trt_plugin_fp16 = Get("disable_trt_plugin_fp16"); auto context_memory_sharing = Get("context_memory_sharing"); + if (context_memory_sharing && TRT_VERSION < 7200) { + // https://forums.developer.nvidia.com/t/nvinfer1-createexecutioncontextwithoutdevicememory-returns-nullptr/111878/2 + // when trt version less than 7.2, + // createExecutionContextWithoutDeviceMemory() has bug. + // so, we cannot enable engine context memory sharing. + context_memory_sharing = false; + } auto enable_low_precision_io = Get("enable_low_precision_io"); auto workspace_size = Get("workspace_size"); auto gpu_device_id = Get("gpu_device_id"); diff --git a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc index 126d16933fd82..acc8611ec917b 100644 --- a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc +++ b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc @@ -16,11 +16,17 @@ #include "paddle/fluid/inference/analysis/argument.h" +PHI_DECLARE_bool(enable_pir_in_executor); + namespace paddle { namespace inference { namespace analysis { void InferenceOpReplacePass::RunImpl(Argument* argument) { + if (FLAGS_enable_pir_in_executor) { + return; + } + std::unordered_map replaced_map{ {"conditional_block", "conditional_block_infer"}, {"merge_lod_tensor", "merge_lod_tensor_infer"}, diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index c3d4c3329016a..cc9c2c2e6f5f5 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -577,6 +577,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(apply_optim_); CP_MEMBER(skip_load_params_); + CP_MEMBER(use_new_executor_); + if (use_gpu_) { PADDLE_ENFORCE_EQ(use_xpu_, false, diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a098bc524f255..99b50c9b8ab28 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -98,6 +98,21 @@ #include "paddle/phi/backends/xpu/xpu_info.h" #endif +#ifdef PADDLE_WITH_NVTX +#include "paddle/fluid/platform/device/gpu/cuda/cuda_profiler.h" +#endif + +#include "paddle/fluid/ir_adaptor/translator/translate.h" +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" +#include "paddle/phi/core/flags.h" +#include "paddle/pir/pass/pass_manager.h" + +PHI_DECLARE_bool(enable_pir_in_executor); +PHI_DECLARE_bool(pir_apply_inplace_pass); + namespace paddle { namespace { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -334,15 +349,13 @@ bool AnalysisPredictor::Init( const std::shared_ptr &parent_scope, const std::shared_ptr &program) { VLOG(3) << "Predictor::init()"; +#ifdef PADDLE_WITH_NVTX if (config_.with_profile_) { LOG(WARNING) << "Profiler is activated, which might affect the performance"; - auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll - : platform::ProfilerState::kCPU; - platform::EnableProfiler(tracking_device); - } else { - VLOG(2) << "Profiler is deactivated, and no profiling report will be " - "generated."; + platform::CudaProfilerStart(); + platform::NvprofEnableRecordEvent(); } +#endif if (!status_is_cloned_) { root_predictor_id_ = predictor_id_; @@ -702,6 +715,45 @@ bool AnalysisPredictor::PrepareExecutor() { executor_->Prepare( sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_); + if (config_.new_executor_enabled()) { + framework::interpreter::ExecutionConfig execution_config; + execution_config.create_local_scope = false; + execution_config.used_for_inference = true; + auto input_names = GetInputNames(); + execution_config.skip_gc_vars.insert(input_names.begin(), + input_names.end()); + auto output_names = GetOutputNames(); + execution_config.skip_gc_vars.insert(output_names.begin(), + output_names.end()); + + if (FLAGS_enable_pir_in_executor) { + pir_program_ = std::move( + paddle::TranslateLegacyProgramToProgram(*inference_program_)); + + ::pir::PassManager pm(::pir::IrContext::Instance(), 2); + pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); + pm.AddPass(::pir::CreateDeadCodeEliminationPass()); + + pm.EnableIRPrinting(); + pm.Run(pir_program_.get()); + + pir_program_ = std::move( + paddle::dialect::PdOpLowerToKernelPass(pir_program_.get(), place_)); + + if (FLAGS_pir_apply_inplace_pass) { + ::pir::PassManager pm(::pir::IrContext::Instance(), 3); + pm.AddPass(::pir::CreateInplacePass()); + pm.Run(pir_program_.get()); + } + + executor_->PrepareInterpreterCore( + sub_scope_, *pir_program_, execution_config); + } else { + executor_->PrepareInterpreterCore( + sub_scope_, *inference_program_, execution_config); + } + } + if (config_.enable_memory_optim_) { auto *pass_res_info = inference::analysis::PassResultInfoForRuntime::Instance(); @@ -1082,8 +1134,6 @@ bool AnalysisPredictor::Run(const std::vector &inputs, if (config_.use_mkldnn_) MkldnnPreSet(inputs); #endif VLOG(3) << "Predictor::predict"; - inference::Timer timer; - timer.tic(); // set feed variable framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get(); PADDLE_ENFORCE_NOT_NULL( @@ -1107,9 +1157,13 @@ bool AnalysisPredictor::Run(const std::vector &inputs, HookCollectShapeRangeInfo(); } - // Run the inference program - // if share variables, we need not create variables - executor_->Run(); + if (config_.new_executor_enabled()) { + executor_->RunInterpreterCore(); + } else { + // Run the inference program + // if share variables, we need not create variables + executor_->Run(); + } // get fetch variable if (!GetFetch(output_data, scope)) { @@ -1117,8 +1171,6 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } - VLOG(3) << "predict cost: " << timer.toc() << "ms"; - // All the containers in the scope will be hold in inference, but the // operators assume that the container will be reset after each batch. // Here is a bugfix, collect all the container variables, and reset then to a @@ -1178,9 +1230,13 @@ bool AnalysisPredictor::Run(const std::vector &inputs, HookCollectShapeRangeInfo(); } - // Run the inference program - // if share variables, we need not create variables - executor_->Run(); + if (config_.new_executor_enabled()) { + executor_->RunInterpreterCore(); + } else { + // Run the inference program + // if share variables, we need not create variables + executor_->Run(); + } inference::DisplayMemoryInfo(place_, "after run"); @@ -2094,11 +2150,7 @@ bool AnalysisPredictor::ZeroCopyRun() { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) if (config_.dist_config().use_dist_model()) { VLOG(3) << "ZeroCopyRun will use the fleet executor."; - inference::Timer timer; - timer.tic(); fleet_exe_->Run(config_.dist_config().carrier_id()); - VLOG(3) << "Fleet executor inf runs once use: " - << std::to_string(timer.toc()) << "ms"; return true; } #endif @@ -2155,7 +2207,11 @@ bool AnalysisPredictor::ZeroCopyRun() { } #endif - executor_->Run(); + if (config_.new_executor_enabled()) { + executor_->RunInterpreterCore(); + } else { + executor_->Run(); + } inference::DisplayMemoryInfo(place_, "after run"); #ifdef PADDLE_WITH_XPU @@ -2607,10 +2663,12 @@ AnalysisPredictor::~AnalysisPredictor() { // NOLINT SaveTrtCalibToDisk(); } #endif +#ifdef PADDLE_WITH_NVTX if (config_.with_profile_) { - platform::DisableProfiler(platform::EventSortingKey::kTotal, - "./profile.log"); + platform::NvprofDisableRecordEvent(); + platform::CudaProfilerStop(); } +#endif if (sub_scope_) { if (framework::global_transfer_scope_key().find(sub_scope_) != framework::global_transfer_scope_key().end()) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index beecfc9743b10..11e40ed74921a 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -13,15 +13,13 @@ // limitations under the License. #pragma once + #include #include #include #include #include -#include "paddle/phi/common/data_type.h" -#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" -#endif + #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/op_compatible_info.h" #include "paddle/fluid/inference/analysis/analyzer.h" @@ -32,12 +30,20 @@ #include "paddle/fluid/inference/api/resource_manager.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/string/printf.h" -#include "paddle/phi/core/dense_tensor.h" + +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#endif + #ifdef PADDLE_WITH_TESTING #include #include #endif +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/pir/core/program.h" + namespace paddle_infer { namespace experimental { class InternalUtils; @@ -564,15 +570,13 @@ class AnalysisPredictor : public PaddlePredictor { std::shared_ptr scope_; framework::Scope *sub_scope_{nullptr}; std::shared_ptr inference_program_; - framework::OpCompatibleMap op_compatible_map_; + std::shared_ptr pir_program_; std::vector feeds_; std::map feed_names_; // Sorted according to the idx. std::map idx2feeds_; std::vector fetches_; std::map idx2fetches_; - std::once_flag register_input_hook_flag_; - std::once_flag register_output_hook_flag_; phi::DataType model_precision_{phi::DataType::FLOAT32}; @@ -592,16 +596,14 @@ class AnalysisPredictor : public PaddlePredictor { details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; // A mutex help to make Clone thread safe. std::mutex clone_mutex_; + static int clone_num_; - // For memory optimization. - const size_t max_shape_collect_count_{1000}; - int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true. - std::vector>> batch_var_shapes_; int predictor_id_; int root_predictor_id_{-1}; private: - std::vector hookfuncs_; + std::once_flag register_input_hook_flag_; + std::once_flag register_output_hook_flag_; std::vector output_hookfuncs_; std::vector input_hookfuncs_; // Some status here that help to determine the status inside the predictor. @@ -609,7 +611,6 @@ class AnalysisPredictor : public PaddlePredictor { std::map>> shape_info_; std::map>> shape_tensor_value_; - static int clone_num_; bool private_context_{false}; void *predictor_stream_{nullptr}; diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.h b/paddle/fluid/inference/api/mkldnn_quantizer.h index a44da8085f35b..17fe7fff3aa21 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.h +++ b/paddle/fluid/inference/api/mkldnn_quantizer.h @@ -21,7 +21,6 @@ #include #include -#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/api_impl.h" diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 4f9982f0a6d40..94215dddc6cce 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -880,6 +880,10 @@ struct PD_INFER_DECL AnalysisConfig { /// int tensorrt_optimization_level() { return trt_optimization_level_; } + void EnableNewExecutor(bool x = true) { use_new_executor_ = x; } + + bool new_executor_enabled() const { return use_new_executor_; } + void EnableDlnne( int min_subgraph_size = 3, int max_batch_size = 1, @@ -1291,7 +1295,7 @@ struct PD_INFER_DECL AnalysisConfig { // memory reuse related. bool enable_memory_optim_{false}; - bool trt_engine_memory_sharing_{false}; + bool trt_engine_memory_sharing_{true}; int trt_engine_memory_sharing_identifier_{0}; std::unordered_set trt_ops_run_float_; @@ -1305,6 +1309,8 @@ struct PD_INFER_DECL AnalysisConfig { bool use_feed_fetch_ops_{true}; bool ir_debug_{false}; + bool use_new_executor_{false}; + bool specify_input_name_{false}; int cpu_math_library_num_threads_{1}; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index c7f3f87a4d192..25c2e0988c419 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -507,6 +507,8 @@ void CpuPassStrategy::EraseFcMkldnnPasses() { XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { passes_.assign({ + "delete_quant_dequant_linear_op_pass", + "delete_weight_dequant_linear_op_pass", "delete_assign_op_pass", "delete_dropout_op_pass", "delete_concat_op_pass", @@ -559,9 +561,11 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fast_where_xpu_fuse_pass", "elementwise_mul_add_fuse_pass", "link_xpu_op_max_pass", - "delete_isolated_node_pass", // "auto_mixed_precision_pass", "cast_mixed_precision_op_fuse_pass", + "xpu_quantize_op_pass", + "xpu_quantize_squash_pass", + "delete_isolated_node_pass", "inplace_op_var_pass", }); use_xpu_ = true; diff --git a/paddle/fluid/inference/paddle_inference.map b/paddle/fluid/inference/paddle_inference.map index 191f5934166c4..6d47d6ca11cf4 100644 --- a/paddle/fluid/inference/paddle_inference.map +++ b/paddle/fluid/inference/paddle_inference.map @@ -82,6 +82,7 @@ *Pass*; *profile*; *phi*; + *pir*; PD_*; *cinn*; local: diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 38510606fe68d..3eb01c0951e27 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -374,16 +374,18 @@ class OpConverter { void SupportFP32MixPrecision(const std::string& output_name, const std::string& op_type, nvinfer1::ILayer* layer) { -#if IS_TRT_VERSION_GE(8210) if (engine_->OpIsRunFloat(output_name) || engine_->OpIsRunFloat(op_type)) { +#if IS_TRT_VERSION_GE(8210) VLOG(3) << op_type << "(output: " << output_name << ")" << " is forced to run in FP32 precision."; layer->resetPrecision(); layer->setPrecision(nvinfer1::DataType::kFLOAT); - } #else - LOG(INFO) << "Set layer precision needs TensorRT version 8.2.1 and after."; + VLOG(3) + << op_type << "(output: " << output_name << ")" + << ": Set layer precision needs TensorRT version 8.2.1 and after."; #endif + } } nvinfer1::ITensor* Cast(nvinfer1::ITensor* input, nvinfer1::DataType dtype) { diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu index 1033dc65f2dcc..b3b0cd35fb300 100644 --- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu @@ -15,10 +15,10 @@ #include #include "glog/logging.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/utils.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/flags.h" diff --git a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu index b64aba25f89b6..ef6d0761bb636 100644 --- a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu @@ -31,10 +31,12 @@ void SwishPlugin::terminate() TRT_NOEXCEPT {} bool SwishPlugin::supportsFormat( nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT { if (with_fp16_) { - return type == nvinfer1::DataType::kFLOAT || - type == nvinfer1::DataType::kHALF; + return (type == nvinfer1::DataType::kFLOAT || + type == nvinfer1::DataType::kHALF) && + (format == nvinfer1::TensorFormat::kLINEAR); } - return type == nvinfer1::DataType::kFLOAT; + return (type == nvinfer1::DataType::kFLOAT) && + (format == nvinfer1::TensorFormat::kLINEAR); } nvinfer1::Dims SwishPlugin::getOutputDimensions(int index, @@ -179,13 +181,11 @@ bool SwishPluginDynamic::supportsFormatCombination( if (with_fp16_) { bool res = (in.type == nvinfer1::DataType::kFLOAT || in.type == nvinfer1::DataType::kHALF); -// encounter trt crash bug -#if IS_TRT_VERSION_LT(8000) res = res && (in.format == nvinfer1::TensorFormat::kLINEAR); -#endif return res; } else { - return in.type == nvinfer1::DataType::kFLOAT; + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); } } const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; diff --git a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt index 4ac1dc065143f..af377c6682bad 100644 --- a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt +++ b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt @@ -20,4 +20,4 @@ file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc") cc_library( program_translator SRCS ${PD_PROGRAM_TRANSLATOR_SRCS} ${op_compat_source_file} - DEPS proto_desc pd_op_dialect pir framework_proto) + DEPS proto_desc op_dialect_vjp pir framework_proto) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 3b665d174df55..a52154ea8bea8 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -33,6 +33,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" @@ -72,6 +73,7 @@ using InputHandlerFn = std::function; using AttributeHandlerFn = std::function; +using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT @@ -259,6 +261,11 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, continue; } VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]); + IR_ENFORCE(var != nullptr, + "[op:%s] Input %s should not be null", + op_desc.Type(), + legacy_input_vars[0]); + if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR) { need_inputs_sig.emplace_back("dense"); } else if (var->GetType() == @@ -467,6 +474,10 @@ std::vector OpTranscriber::GenerateOperationInput( // Vector if (legacy_input_vars.size() == 1) { VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]); + IR_ENFORCE(var != nullptr, + "[op:%s] Input %s should not be null", + op_desc.Type(), + legacy_input_vars[0]); if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) { is_vector = false; @@ -599,6 +610,10 @@ OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx, continue; } VarDesc* var = block->FindVarRecursive(var_name); + IR_ENFORCE(var != nullptr, + "[op:%s] Output %s should not be null", + op_desc.Type(), + var_name); VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << info.name << " var: " << var_name << " type: " << var->GetType(); @@ -677,7 +692,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx, pir::OpResult value = operation->result(idx_in_op); bool generated_by_vector = value.type().isa(); - param_map->UpdateValue( + param_map->PushValue( arg_name, VariableDefiningInfo( value, @@ -911,9 +926,10 @@ pir::OpResult TranslateDropOutStateIn(pir::IrContext* ctx, // `DropoutState` is a tensor VarDesc* dropout_state = op_desc.Block()->FindVarRecursive(legacy_output_vars[0]); - if (dropout_state == nullptr) { - IR_THROW("Unexpected: Rnn Op should have a non-empty DropoutState"); - } + IR_ENFORCE(dropout_state != nullptr, + "[op:%s] Output %s should not be null", + op_desc.Type(), + legacy_output_vars[0]); auto& type_translator = TypeTranslator::instance(); pir::Type translated_var_type = type_translator[dropout_state->GetType()](ctx, *dropout_state); @@ -1011,7 +1027,10 @@ struct DataOpTranscriber : public FeedOpTranscriber { const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { - int allocate_type = paddle::get(op_desc.GetAttr("place")); + int allocate_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place")); + int var_dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); + auto phi_dtype = phi::TransToPhiDataType(var_dtype); + auto& attribute_translator = AttributeTranslator::instance(); pir::Attribute shape = attribute_translator( "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("shape")); @@ -1020,8 +1039,7 @@ struct DataOpTranscriber : public FeedOpTranscriber { pir::StrAttribute::get(ctx, op_desc.GetAttrIfExists("name"))}, {"shape", shape}, - {"dtype", - paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32)}, + {"dtype", paddle::dialect::DataTypeAttribute::get(ctx, phi_dtype)}, {"place", paddle::dialect::PlaceAttribute::get( ctx, phi::Place(static_cast(allocate_type)))}, @@ -1242,6 +1260,398 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber { } }; +using ValueInfo = + std::tuple, dialect::DenseTensorType, pir::OpResult>; + +ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc, + const std::vector& names, + TranslationContext* param_map, + const std::string& var_name) { + IR_ENFORCE(names.size() == 1, + "Expected op[%s]'s input %s has only 1 variable, but got %d", + op_desc.Type(), + var_name, + names.size()); + const auto& name = names[0]; + IR_ENFORCE(param_map->count(name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + name); + const auto& defining_info = param_map->at(name); + + pir::OpResult value = defining_info.value.dyn_cast(); + IR_ENFORCE( + value, "Expected op[%s]'s input %s is not null", op_desc.Type(), name); + const pir::Type& type = value.type(); + IR_ENFORCE(type.isa(), + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + name, + type); + dialect::DenseTensorType tensor_type = + type.dyn_cast(); + + std::vector shape = phi::vectorize(tensor_type.dims()); + + return std::make_tuple(shape, tensor_type, value); +} + +struct MulOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + const std::string& target_op_name = paddle::dialect::MatmulOp::name(); + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW("Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name); + } + return op_info; + } + + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, + const std::string& normalized_op_name, + const OpAttributeInfoList& op_attr_infos, + const OpDesc& op_desc) override { + pir::AttributeMap attribute_map = {}; + + attribute_map["transpose_x"] = pir::BoolAttribute::get(ctx, false); + attribute_map["transpose_y"] = pir::BoolAttribute::get(ctx, false); + + return attribute_map; + } + + std::vector GenerateOperationInput( + pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos, + pir::Block* block) override { + const int x_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")); + const int y_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims")); + + ValueInfo x_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("X", true), param_map, "X"); + + const auto& [x_shape, x_tensor_type, x_value] = x_info; + + IR_ENFORCE(x_num_col_dims <= static_cast(x_shape.size()), + "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " + "dim of input X %s, but got %d", + op_desc.Type(), + x_shape.size(), + x_num_col_dims); + + ValueInfo y_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Y", true), param_map, "Y"); + + const auto& [y_shape, y_tensor_type, y_value] = y_info; + + IR_ENFORCE(y_num_col_dims <= static_cast(y_shape.size()), + "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " + "dim of input Y %s, but got %d", + op_desc.Type(), + y_shape.size(), + y_num_col_dims); + + pir::Builder builder(ctx, block); + + std::vector x_new_shape({ + std::max(std::accumulate(x_shape.begin(), + x_shape.begin() + x_num_col_dims, + static_cast(1), + std::multiplies()), + static_cast(-1)), + std::max(std::accumulate(x_shape.begin() + x_num_col_dims, + x_shape.end(), + static_cast(1), + std::multiplies()), + static_cast(-1)), + }); + dialect::ReshapeOp reshape_op_x = + builder.Build(x_value, x_new_shape); + pir::OpResult x_new = reshape_op_x.out(); + VLOG(6) << "[" << op_desc.Type() << "] x_shape change from " + << x_tensor_type.dims() << " to " << phi::make_ddim(x_new_shape); + + std::vector y_new_shape( + {std::max(std::accumulate(y_shape.begin(), + y_shape.begin() + y_num_col_dims, + static_cast(1), + std::multiplies()), + static_cast(-1)), + std::max(std::accumulate(y_shape.begin() + y_num_col_dims, + y_shape.end(), + static_cast(1), + std::multiplies()), + static_cast(-1))}); + + dialect::ReshapeOp reshape_op_y = + builder.Build(y_value, y_new_shape); + pir::OpResult y_new = reshape_op_y.out(); + VLOG(6) << "[" << op_desc.Type() << "] y_shape change from " + << y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape); + + return {x_new, y_new}; + } + + void RecordOpResultMapping(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Operation* operation, + const OpOutputMapping& arg_to_idx) override { + OpTranscriber::RecordOpResultMapping( + ctx, param_map, op_desc, operation, arg_to_idx); + if (op_desc.HasOutput("Out")) { + ValueInfo out_info = GetTensorInfoByVarName( + op_desc, op_desc.Output("Out"), param_map, "Out"); + + const dialect::DenseTensorType& out_tensor_type = std::get<1>(out_info); + pir::OpResult& out_value = std::get<2>(out_info); + + const auto& output_vars = op_desc.Output("Out"); + const auto& output_name = output_vars[0]; + + const int x_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")); + const int y_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims")); + + ValueInfo x_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("X", true), param_map, "X"); + + const std::vector& x_shape = std::get<0>(x_info); + + ValueInfo y_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Y", true), param_map, "Y"); + + const std::vector& y_shape = std::get<0>(y_info); + + std::vector out_new_shape(x_shape.begin(), + x_shape.begin() + x_num_col_dims); + out_new_shape.insert( + out_new_shape.end(), y_shape.begin() + y_num_col_dims, y_shape.end()); + + pir::Builder builder(ctx, operation->GetParent()); + dialect::ReshapeOp reshape_op_out = + builder.Build(out_value, out_new_shape); + pir::OpResult out_new = reshape_op_out.out().dyn_cast(); + VLOG(6) << "[" << op_desc.Type() << "] out_shape change from " + << out_tensor_type.dims() << " to " + << phi::make_ddim(out_new_shape); + + param_map->PushValue(output_name, + VariableDefiningInfo(out_new, false, -1)); + } + } +}; + +struct MulGradOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + const std::string& target_op_name = paddle::dialect::MatmulGradOp::name(); + VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to " + << target_op_name; + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW("Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name); + } + return op_info; + } + + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, + const std::string& normalized_op_name, + const OpAttributeInfoList& op_attr_infos, + const OpDesc& op_desc) override { + pir::AttributeMap attribute_map = {}; + + attribute_map["transpose_x"] = pir::BoolAttribute::get(ctx, false); + attribute_map["transpose_y"] = pir::BoolAttribute::get(ctx, false); + + return attribute_map; + } + + std::vector GenerateOperationInput( + pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos, + pir::Block* block) override { + const int x_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")); + const int y_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims")); + + ValueInfo x_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("X", true), param_map, "X"); + + const auto& [x_shape, x_tensor_type, x_value] = x_info; + + IR_ENFORCE(x_num_col_dims <= static_cast(x_shape.size()), + "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " + "dim of input X %s, but got %d", + op_desc.Type(), + x_shape.size(), + x_num_col_dims); + + ValueInfo y_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Y", true), param_map, "Y"); + + const auto& [y_shape, y_tensor_type, y_value] = y_info; + + IR_ENFORCE(y_num_col_dims <= static_cast(y_shape.size()), + "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " + "dim of input Y %s, but got %d", + op_desc.Type(), + y_shape.size(), + y_num_col_dims); + + ValueInfo out_grad_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Out@GRAD", true), param_map, "Out@GRAD"); + + const dialect::DenseTensorType& out_grad_tensor_type = + std::get<1>(out_grad_info); + pir::OpResult& out_grad_value = std::get<2>(out_grad_info); + + pir::Builder builder(ctx, block); + + std::vector x_new_shape({ + std::max(std::accumulate(x_shape.begin(), + x_shape.begin() + x_num_col_dims, + static_cast(1), + std::multiplies()), + static_cast(-1)), + std::max(std::accumulate(x_shape.begin() + x_num_col_dims, + x_shape.end(), + static_cast(1), + std::multiplies()), + static_cast(-1)), + }); + dialect::ReshapeOp reshape_op_x = + builder.Build(x_value, x_new_shape); + pir::OpResult x_new = reshape_op_x.out(); + VLOG(6) << "[" << op_desc.Type() << "] x_shape change from " + << x_tensor_type.dims() << " to " << phi::make_ddim(x_new_shape); + + std::vector y_new_shape( + {std::max(std::accumulate(y_shape.begin(), + y_shape.begin() + y_num_col_dims, + static_cast(1), + std::multiplies()), + static_cast(-1)), + std::max(std::accumulate(y_shape.begin() + y_num_col_dims, + y_shape.end(), + static_cast(1), + std::multiplies()), + static_cast(-1))}); + + dialect::ReshapeOp reshape_op_y = + builder.Build(y_value, y_new_shape); + pir::OpResult y_new = reshape_op_y.out(); + VLOG(6) << "[" << op_desc.Type() << "] y_shape change from " + << y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape); + + std::vector out_grad_new_shape( + {x_new_shape.front(), y_new_shape.back()}); + + dialect::ReshapeOp reshape_op_out_grad = + builder.Build(out_grad_value, out_grad_new_shape); + pir::OpResult out_grad_new = reshape_op_out_grad.out(); + VLOG(6) << "[" << op_desc.Type() << "] out_grad_shape change from " + << out_grad_tensor_type.dims() << " to " + << phi::make_ddim(out_grad_new_shape); + + return {x_new, y_new, out_grad_new}; + } + + void RecordOpResultMapping(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Operation* operation, + const OpOutputMapping& arg_to_idx) override { + OpTranscriber::RecordOpResultMapping( + ctx, param_map, op_desc, operation, arg_to_idx); + + const auto& x_grad_output = op_desc.Output("X@GRAD"); + const auto& y_grad_output = op_desc.Output("Y@GRAD"); + if (x_grad_output.size() < 1 && y_grad_output.size() < 1) { + return; + } + + pir::Builder builder(ctx, operation->GetParent()); + + auto gradReshape = [&](const std::string& var_name) { + const auto& grad_output = op_desc.Output(var_name); + IR_ENFORCE(grad_output.size() == 1, + "Expected op[%s]'s output %s has only 1 variable, but got %d", + op_desc.Type(), + var_name, + grad_output.size()); + const auto& grad_var_name = grad_output[0]; + + auto idx_iter = arg_to_idx.find(grad_var_name); + if (idx_iter == arg_to_idx.end()) { + IR_THROW("op[%s] should have got its %s", op_desc.Type(), var_name); + } + auto [idx_in_op, idx_in_vec] = idx_iter->second; + VLOG(10) << "[output recording]" + << "[" << op_desc.Type() << "]" << grad_var_name << " " + << idx_in_op << " " << idx_in_vec; + + VarDesc* var_desc = + op_desc.Block()->FindVarRecursive(var_name.substr(0, 1)); + IR_ENFORCE(var_desc != nullptr, + "[op:%s] Input %s should not be null", + op_desc.Type(), + var_name.substr(0, 1)); + std::vector shape = var_desc->GetShape(); + DenseTensorTypeStorage::Dim dim = phi::make_ddim(shape); + + pir::OpResult value_res = operation->result(idx_in_op); + auto reshape_op = builder.Build(value_res, shape); + + IR_ENFORCE(value_res, + "Expected op[%s]'s input %s is not null", + op_desc.Type(), + grad_var_name); + pir::Type grad_type = value_res.type(); + IR_ENFORCE(grad_type.isa(), + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + grad_var_name, + grad_type); + dialect::DenseTensorType grad_tensor_type = + grad_type.dyn_cast(); + + VLOG(10) << "[" << op_desc.Type() << "] shape of " << var_name + << " change from " << grad_tensor_type.dims() << " to " << dim; + + param_map->PushValue(grad_var_name, + VariableDefiningInfo(reshape_op.out(), false, -1)); + }; + + if (x_grad_output.size()) { + gradReshape("X@GRAD"); + } + + if (y_grad_output.size() < 1) { + return; + } + + if (y_grad_output.size()) { + gradReshape("Y@GRAD"); + } + } +}; + struct FillConstant2FullTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { @@ -1468,6 +1878,34 @@ struct OneHotTranscriber : public OpTranscriber { }; }; +pir::Attribute TranslateDtypeForArange(pir::IrContext* ctx, + const OpDesc& op_desc, + const OpAttributeInfo& attr_info) { + IR_ENFORCE(op_desc.Input("Start").size() == 1, + "[op:%s] Input [Start]'s size should be equal to 1", + op_desc.Type()); + auto var_desc = op_desc.Block()->FindVarRecursive(op_desc.Input("Start")[0]); + IR_ENFORCE(var_desc != nullptr, + "[op:%s] Input %s should not be null", + op_desc.Type(), + op_desc.Input("Start")[0]); + auto start_proto_dtype = var_desc->GetDataType(); + auto start_phi_dtype = phi::TransToPhiDataType(start_proto_dtype); + auto dtype_attr = + paddle::dialect::DataTypeAttribute::get(ctx, start_phi_dtype); + return dtype_attr; +} + +struct ArangeOpTranscriber : public OpTranscriber { + AttributeHandlerFn GetSpecialAttributeHandlers( + const std::string& attr_name) override { + if (attr_name != "dtype") { + return nullptr; + } + return TranslateDtypeForArange; + } +}; + pir::Attribute TranslateReduceAll(pir::IrContext* ctx, const OpDesc& op_desc, const OpAttributeInfo& attr_info) { @@ -1490,8 +1928,8 @@ pir::Attribute TranslateReduceAll(pir::IrContext* ctx, struct ReduceOpTranscriber : public OpTranscriber { AttributeHandlerFn GetSpecialAttributeHandlers( - const std::string& input_name) override { - if (input_name != "axis") { + const std::string& attr_name) override { + if (attr_name != "axis") { return nullptr; } return TranslateReduceAll; @@ -1697,8 +2135,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber { pir::OpResult value = operation->result(idx_in_op); pir::Builder builder(ctx, operation->GetParent()); auto reshape_op = builder.Build(value, y_shape); - param_map->UpdateValue(y_grad_var_name, - VariableDefiningInfo(reshape_op.out(), false, -1)); + param_map->PushValue(y_grad_var_name, + VariableDefiningInfo(reshape_op.out(), false, -1)); } }; @@ -1865,8 +2303,8 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber { auto output_var = output_vars[0]; auto fused_feedforward_op = operation->dyn_cast(); - param_map->UpdateValue(output_var, - VariableDefiningInfo{fused_feedforward_op.out()}); + param_map->PushValue(output_var, + VariableDefiningInfo{fused_feedforward_op.out()}); } } }; @@ -1886,6 +2324,119 @@ struct ShareBufferOpTranscriber : public OpTranscriber { } }; +struct RandIntOpTranscriber : public OpTranscriber { + std::tuple GenerateOperationOutput( + pir::IrContext* ctx, + const OpDesc& op_desc, + const OpOutputInfoList& output_infos) { + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types = {}; + + auto& type_translator = TypeTranslator::instance(); + + const BlockDesc* block = op_desc.Block(); + std::string legacy_output_name = "Out"; + const auto& legacy_output_vars = op_desc.Output(legacy_output_name); + auto& var_name = legacy_output_vars[0]; + VarDesc* var = block->FindVarRecursive(var_name); + IR_ENFORCE(var != nullptr, + "[op:%s] Output %s should not be null", + op_desc.Type(), + var_name); + int dtype_attr_val = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); + + paddle::framework::proto::VarType::Type var_type = + static_cast(dtype_attr_val); + + pir::Type dtype = type_translator[var_type](ctx, *var); + paddle::dialect::DenseTensorTypeStorage::Dim dim = + phi::make_ddim(var->GetShape()); + paddle::dialect::DenseTensorTypeStorage::DataLayout layout = + paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED; + paddle::dialect::DenseTensorTypeStorage::LoD lod = {}; + size_t offset = 0; + pir::Type translated_var_type = paddle::dialect::DenseTensorType::get( + ctx, dtype, dim, layout, lod, offset); + arg_to_idx[var_name] = {0, 0}; + op_output_types.push_back(translated_var_type); + return {op_output_types, arg_to_idx}; + } +}; + +struct RepeatInterLeaveOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + std::string target_op_name; + if (op_desc.HasInput("RepeatsTensor") && + !op_desc.Input("RepeatsTensor").empty()) { + target_op_name = "pd_op.repeat_interleave_with_tensor_index"; + } else { + target_op_name = "pd_op.repeat_interleave"; + } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + return op_info; + } + + std::vector GenerateOperationInput( + pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos, + pir::Block* block) override { + std::vector op_inputs; + auto x_names = op_desc.Input("X", true); + auto input = param_map->at(x_names[0]).value; + op_inputs.push_back(input); + if (op_desc.HasInput("RepeatsTensor") && + !op_desc.Input("RepeatsTensor").empty()) { + auto repeats_names = op_desc.Input("RepeatsTensor", true); + input = param_map->at(repeats_names[0]).value; + op_inputs.push_back(input); + } + return op_inputs; + } +}; + +struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + std::string target_op_name; + if (op_desc.HasInput("RepeatsTensor") && + !op_desc.Input("RepeatsTensor").empty()) { + target_op_name = "pd_op.repeat_interleave_with_tensor_index_grad"; + } else { + target_op_name = "pd_op.repeat_interleave_grad"; + } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + return op_info; + } + + std::vector GenerateOperationInput( + pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos, + pir::Block* block) override { + std::vector op_inputs; + auto x_names = op_desc.Input("X", true); + auto input = param_map->at(x_names[0]).value; + op_inputs.push_back(input); + if (op_desc.HasInput("RepeatsTensor") && + !op_desc.Input("RepeatsTensor").empty()) { + auto repeats_names = op_desc.Input("RepeatsTensor", true); + input = param_map->at(repeats_names[0]).value; + op_inputs.push_back(input); + } + auto out_grad_names = op_desc.Input("Out@GRAD", true); + input = param_map->at(out_grad_names[0]).value; + op_inputs.push_back(input); + + return op_inputs; + } +}; + OpTranslator::OpTranslator() { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); @@ -1893,9 +2444,10 @@ OpTranslator::OpTranslator() { general_handler = OpTranscriber(); special_handlers["add_n"] = AddNOpTranscriber(); special_handlers["assign_value"] = AssignValueOpTranscriber(); + special_handlers["range"] = ArangeOpTranscriber(); special_handlers["cast"] = CastOpTranscriber(); - special_handlers["feed"] = FeedOpTranscriber(); special_handlers["data"] = DataOpTranscriber(); + special_handlers["feed"] = FeedOpTranscriber(); special_handlers["fetch"] = FetchOpTranscriber(); special_handlers["fetch_v2"] = FetchOpTranscriber(); special_handlers["fill_constant"] = FillConstantTranscriber(); @@ -1905,16 +2457,22 @@ OpTranslator::OpTranslator() { special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber(); special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber(); special_handlers["one_hot_v2"] = OneHotTranscriber(); + special_handlers["randint"] = RandIntOpTranscriber(); special_handlers["reduce_all"] = ReduceOpTranscriber(); special_handlers["reduce_any"] = ReduceOpTranscriber(); + special_handlers["repeat_interleave"] = RepeatInterLeaveOpTranscriber(); + special_handlers["repeat_interleave_grad"] = + RepeatInterLeaveGradOpTranscriber(); special_handlers["rnn"] = RnnOpTranscriber(); - special_handlers["shadow_output"] = ShadowOutputOpTranscriber(); - special_handlers["share_buffer"] = ShareBufferOpTranscriber(); special_handlers["set_value"] = LegacySetValueDispatcher(); special_handlers["set_value_grad"] = SetValueGradOpTranscriber(); + special_handlers["shadow_output"] = ShadowOutputOpTranscriber(); + special_handlers["share_buffer"] = ShareBufferOpTranscriber(); special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); + special_handlers["mul"] = MulOpTranscriber(); + special_handlers["mul_grad"] = MulGradOpTranscriber(); // special handler for elementwise ops with axis != -1 // note(lyk): maybe we should do this by a pass, which seems more reasonable diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index dba2bae8dc911..88b4e45a2ba9d 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -35,7 +35,7 @@ #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" namespace paddle { namespace translator { @@ -46,7 +46,7 @@ using VarDesc = ::paddle::framework::VarDesc; using TCKey = TranslationContext::Key; using TCValue = TranslationContext::Value; -using TCConatiner = TranslationContext::Conatiner; +using TCContainer = TranslationContext::Container; const std::unordered_set ProgramTranslator::no_cast_var_names = { "feed", @@ -60,40 +60,56 @@ const std::unordered_set ProgramTranslator::unsupported_ops = { static std::vector GetCondOpIds(const BlockDesc& src_block, uint64_t first_id) { - std::vector op_list = {first_id}; - if (((first_id + 1) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 1))->Type() == "logical_not")) { - op_list.emplace_back(first_id + 1); + uint64_t temp_id = first_id; + // add conditional_block + std::vector op_list = {temp_id}; + temp_id++; + // add logical_not + if ((temp_id < src_block.OpSize()) && + (src_block.Op(static_cast(temp_id))->Type() == "logical_not")) { + op_list.emplace_back(temp_id); + temp_id++; } - if (((first_id + 2) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 2))->Type() == + // add conditional_block + if ((temp_id < src_block.OpSize()) && + (src_block.Op(static_cast(temp_id))->Type() == "conditional_block")) { - op_list.emplace_back(first_id + 2); + op_list.emplace_back(temp_id); + temp_id++; } - if (((first_id + 3) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 3))->Type() == "cast")) { - op_list.emplace_back(first_id + 3); + // add cast + if ((temp_id < src_block.OpSize()) && + (src_block.Op(static_cast(temp_id))->Type() == "cast")) { + op_list.emplace_back(temp_id); + temp_id++; } // Note(zhangbo): Some output variables are input, without select_input op. - std::vector output_names = - src_block.Op(static_cast(first_id))->Output("Out"); - std::vector input_names = - src_block.Op(static_cast(first_id))->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - size_t output_size = diffs.size(); - for (size_t i = 0; i < output_size; i++) { - if (((first_id + 4 + i) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 4 + i))->Type() == - "select_input")) { - op_list.emplace_back(first_id + 4 + i); + std::vector init_op_list; + while (temp_id < src_block.OpSize()) { + if ((src_block.Op(static_cast(temp_id))->Type() == "fill_constant") || + (src_block.Op(static_cast(temp_id))->Type() == "assign_value")) { + init_op_list.emplace_back(temp_id); + temp_id++; + } else { + break; } } + std::vector select_input_op_list; + while (temp_id < src_block.OpSize()) { + if (src_block.Op(static_cast(temp_id))->Type() == "select_input") { + select_input_op_list.emplace_back(temp_id); + temp_id++; + } else { + break; + } + } + + if (select_input_op_list.size() > 0) { + op_list.insert(op_list.end(), init_op_list.begin(), init_op_list.end()); + } + op_list.insert( + op_list.end(), select_input_op_list.begin(), select_input_op_list.end()); + return op_list; } @@ -114,66 +130,86 @@ const std::string& ConditionBlockCombination::CondVarName() const { return op_list_[0]->Input("Cond")[0]; } -size_t ConditionBlockCombination::OutputSize() const { - std::vector output_names = op_list_[0]->Output("Out"); - std::vector input_names = op_list_[0]->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - return diffs.size(); -} - -std::vector<::paddle::framework::VarDesc*> +std::vector> ConditionBlockCombination::OutputVars() const { - std::vector<::paddle::framework::VarDesc*> outputs; - if (this->OutputSize() > 0) { - for (size_t i = 4; i < op_list_.size(); i++) { - outputs.emplace_back(op_list_[i]->Block()->FindVarRecursive( - op_list_[i]->Output("Out")[0])); + std::vector<::paddle::framework::VarDesc*> if_outputs; + std::vector<::paddle::framework::VarDesc*> true_block_outputs; + std::vector<::paddle::framework::VarDesc*> false_block_outputs; + for (::paddle::framework::OpDesc* op : op_list_) { + if (op->Type() == "select_input") { + if_outputs.emplace_back( + op->Block()->FindVarRecursive(op->Output("Out")[0])); + true_block_outputs.emplace_back( + op->Block()->FindVarRecursive(op->Input("X")[1])); + false_block_outputs.emplace_back( + op->Block()->FindVarRecursive(op->Input("X")[0])); } } - return outputs; + return {if_outputs, true_block_outputs, false_block_outputs}; +} + +size_t ConditionBlockCombination::MainOutputSize() const { + return OutputVars()[0].size(); } std::vector ConditionBlockCombination::TrueBlockOutputVarNames() const { - std::vector output_names = op_list_[0]->Output("Out"); - std::vector input_names = op_list_[0]->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - return diffs; + std::vector output_names; + for (::paddle::framework::OpDesc* op : op_list_) { + if (op->Type() == "select_input") { + output_names.emplace_back(op->Input("X")[1]); + } + } + return output_names; } -std::vector ConditionBlockCombination::FalseBlockOutputVarNames() - const { - if (op_list_.size() > 1) { - std::vector output_names = op_list_[2]->Output("Out"); - std::vector input_names = op_list_[2]->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - return diffs; - } - return {""}; +std::vector<::paddle::framework::OpDesc*> +ConditionBlockCombination::TrueBlockInitOps() const { + std::vector<::paddle::framework::OpDesc*> init_ops; + std::vector output_names = TrueBlockOutputVarNames(); + for (::paddle::framework::OpDesc* op : op_list_) { + if ((op->Type() == "fill_constant") || (op->Type() == "assign_value")) { + auto out_name = op->Output("Out")[0]; + if (std::find(output_names.begin(), output_names.end(), out_name) != + output_names.end()) { + init_ops.emplace_back(op); + } + } + } + return init_ops; } int ConditionBlockCombination::TrueBlockId() const { return op_list_[0]->GetBlockAttrId("sub_block"); } +std::vector ConditionBlockCombination::FalseBlockOutputVarNames() + const { + std::vector output_names; + for (::paddle::framework::OpDesc* op : op_list_) { + if (op->Type() == "select_input") { + output_names.emplace_back(op->Input("X")[0]); + } + } + return output_names; +} + +std::vector<::paddle::framework::OpDesc*> +ConditionBlockCombination::FalseBlockInitOps() const { + std::vector<::paddle::framework::OpDesc*> init_ops; + std::vector output_names = FalseBlockOutputVarNames(); + for (::paddle::framework::OpDesc* op : op_list_) { + if ((op->Type() == "fill_constant") || (op->Type() == "assign_value")) { + auto out_name = op->Output("Out")[0]; + if (std::find(output_names.begin(), output_names.end(), out_name) != + output_names.end()) { + init_ops.emplace_back(op); + } + } + } + return init_ops; +} + int ConditionBlockCombination::FalseBlockId() const { if (op_list_.size() > 1) { return op_list_[2]->GetBlockAttrId("sub_block"); @@ -210,10 +246,9 @@ bool ConditionBlockCombination::Verify( return false; } } else { - if (op_list[id]->Type() != "select_input") { - return false; - } - if (op_list[id]->Input("Mask")[0] != op_list[3]->Output("Out")[0]) { + if ((op_list[id]->Type() != "select_input") && + (op_list[id]->Type() != "fill_constant") && + (op_list[id]->Type() != "assign_value")) { return false; } } @@ -227,6 +262,9 @@ const TCValue& TranslationContext::operator[](const TCKey& key) const { const TCValue& TranslationContext::at(const TCKey& key) const { auto it = container_.find(key); + if (it == container_.end() && parent_) { + return parent_->at(key); + } PADDLE_ENFORCE_NE(it, container_.end(), platform::errors::InvalidArgument( @@ -243,12 +281,13 @@ const TCValue& TranslationContext::at(const TCKey& key) const { size_t TranslationContext::count(const TCKey& key) const { auto it = container_.find(key); if (it == container_.end()) { + if (parent_) return parent_->count(key); return 0u; } const auto& values = it->second; PADDLE_ENFORCE_NE( values.size(), - 0, + 0u, platform::errors::InvalidArgument( "param %s should have size > 0, but get:%d", key, values.size())); return values.size(); @@ -261,12 +300,9 @@ void TranslationContext::PopValue(const Key& key) { container_[key].pop_back(); } -void TranslationContext::UpdateValue(const Key& key, const Value& value) { - auto& vec = container_[key]; - if (vec.empty()) - vec.push_back(value); - else - vec.back() = value; +TranslationContext* TranslationContext::CreateInnerContext() { + sons_.emplace_back(std::make_unique(this)); + return sons_.back().get(); } ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, @@ -282,6 +318,7 @@ void ProgramTranslator::Translate() { TranslateBlock(legacy_program_->Block(0), 0, legacy_program_->Block(0).OpSize(), + ¶m_map_, program_->block()); SetParameterFromSingleBlock(legacy_program_->Block(0)); @@ -301,9 +338,11 @@ void ProgramTranslator::TranslateBlock( const BlockDesc& src_block, uint64_t start_id, uint64_t end_id, - pir::Block* dest_block, + TranslationContext* translation_ctx, + pir::Block* dst_block, bool for_cond_block, - std::vector skip_cond_assign) { + const std::vector& cond_sub_block_outputs, + const std::vector<::paddle::framework::OpDesc*>& cond_init_ops) { VLOG(8) << "=============>start to translate a block"; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), @@ -315,7 +354,7 @@ void ProgramTranslator::TranslateBlock( src_block.OpSize())); std::unordered_map translate_completed; - std::vector assign_inputs; + std::map assign_output_2_input; for (uint64_t op_id = start_id; op_id < end_id; op_id++) { if (translate_completed.count(op_id) && translate_completed.at(op_id)) { continue; @@ -330,58 +369,70 @@ void ProgramTranslator::TranslateBlock( "Not support translated %s op", op->Type())); if (op->Type() == "conditional_block") { - std::vector cond_op_list = {op}; std::vector cond_op_ids = GetCondOpIds(src_block, op_id); ConditionBlockCombination cond_op_combination(src_block, cond_op_ids); - pir::Operation* if_op = - TranslateCondIfOperation(cond_op_combination, dest_block); + pir::Operation* if_op = TranslateCondIfOperation( + cond_op_combination, translation_ctx, dst_block); for (auto cond_id : cond_op_ids) { translate_completed[cond_id] = true; } VLOG(10) << "[op translated][conditional_block]" << if_op; } else if (op->Type() == "while") { - TranslateWhileOperation(op, dest_block); + TranslateWhileOperation(op, translation_ctx, dst_block); } else { if (for_cond_block && op->Type() == "assign" && - std::count(skip_cond_assign.begin(), - skip_cond_assign.end(), + std::count(cond_sub_block_outputs.begin(), + cond_sub_block_outputs.end(), op->Output("Out")[0])) { - assign_inputs.push_back(op->Input("X")[0]); + assign_output_2_input[op->Output("Out")[0]] = op->Input("X")[0]; translate_completed[op_id] = true; } else { - TranslateGeneralOperation(op, dest_block); + TranslateGeneralOperation(op, translation_ctx, dst_block); translate_completed[op_id] = true; } } } + // NOTE(zhangbo): If conditional_block operator has output, the cf.yeild // operator needs to be inserted if (for_cond_block) { + // insert init ops + for (::paddle::framework::OpDesc* init_op : cond_init_ops) { + TranslateGeneralOperation(init_op, translation_ctx, dst_block); + } + // insert yeild op std::vector yeild_inputs; - for (size_t id = 0; id < assign_inputs.size(); id++) { - yeild_inputs.emplace_back(param_map_[assign_inputs[id]].value); + for (auto output_name : cond_sub_block_outputs) { + if (assign_output_2_input.count(output_name) != 0) { + yeild_inputs.emplace_back( + (*translation_ctx)[assign_output_2_input[output_name]].value); + } else { + yeild_inputs.emplace_back((*translation_ctx)[output_name].value); + } } pir::AttributeMap attribute_map; auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); pir::Operation* yeild_op = pir::Operation::Create(yeild_inputs, attribute_map, {}, yeild_info); - dest_block->push_back(yeild_op); + dst_block->push_back(yeild_op); } } pir::Operation* ProgramTranslator::TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, pir::Block* dest_block) { + const ConditionBlockCombination& cond_ops, + TranslationContext* translation_ctx, + pir::Block* dst_block) { auto& type_translator = TypeTranslator::instance(); auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); std::vector op_inputs = { - param_map_[cond_ops.CondVarName()].value}; + (*translation_ctx)[cond_ops.CondVarName()].value}; // NOTE(zhangbo): Now paddle::dialect::IfOp has 0 attribute pir::AttributeMap attribute_map; std::vector op_output_types; std::vector<::paddle::framework::VarDesc*> output_vardescs = - cond_ops.OutputVars(); + cond_ops.OutputVars()[0]; for (auto var_desc : output_vardescs) { IR_ENFORCE(var_desc != nullptr, "[control flow] Output should not be null"); pir::Type translated_var_type = @@ -394,11 +445,11 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( op_inputs, attribute_map, op_output_types, op_info, 2); for (size_t i = 0; i < output_vardescs.size(); i++) { - param_map_.PushValue(output_vardescs[i]->Name(), - VariableDefiningInfo(operation->result(i))); + translation_ctx->PushValue(output_vardescs[i]->Name(), + VariableDefiningInfo(operation->result(i))); } - dest_block->push_back(operation); + dst_block->push_back(operation); VLOG(4) << "[general op][conditional_block] IfOp creation end."; if (cond_ops.TrueBlockId() != -1) { @@ -406,12 +457,17 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( legacy_program_->Block(cond_ops.TrueBlockId()); pir::Region& true_region = operation->region(0); if (true_region.empty()) true_region.emplace_back(); + + auto* true_block_context = translation_ctx->CreateInnerContext(); + TranslateBlock(true_sub_block, 0, true_sub_block.OpSize(), + true_block_context, true_region.front(), true, - cond_ops.TrueBlockOutputVarNames()); + cond_ops.TrueBlockOutputVarNames(), + cond_ops.TrueBlockInitOps()); } VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; @@ -420,12 +476,15 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( legacy_program_->Block(cond_ops.FalseBlockId()); pir::Region& false_region = operation->region(1); if (false_region.empty()) false_region.emplace_back(); + auto* false_block_context = translation_ctx->CreateInnerContext(); TranslateBlock(false_sub_block, 0, false_sub_block.OpSize(), + false_block_context, false_region.front(), true, - cond_ops.FalseBlockOutputVarNames()); + cond_ops.FalseBlockOutputVarNames(), + cond_ops.FalseBlockInitOps()); } VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; @@ -434,8 +493,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( return operation; } -void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, - pir::Block* dest_block) { +void ProgramTranslator::TranslateWhileOperation( + const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dst_block) { VLOG(8) << "=============>Start to translate while op:" << op; auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); int index = static_cast(sub_block.OpSize()) - 1; @@ -443,7 +504,7 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, while (index >= 0) { auto sub_op = sub_block.Op(index); if (sub_op->Type() == "assign" && - param_map_.count(sub_op->Output("Out")[0]) > 0) { + translation_ctx->count(sub_op->Output("Out")[0]) > 0) { loop_vars_reverse.emplace_back(sub_op->Output("Out")[0], sub_op->Input("X")[0]); --index; @@ -460,55 +521,54 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, "condition var")); auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name()); std::vector op_inputs{ - param_map_.at(loop_vars_reverse[0].first).value}; + translation_ctx->at(loop_vars_reverse[0].first).value}; std::vector op_outputs_type; auto body_block = new pir::Block(); + auto* body_block_context = translation_ctx->CreateInnerContext(); for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { auto& name = loop_vars_reverse[idx].first; - auto val = param_map_.at(name).value; - auto val_type = val.type(); - op_inputs.push_back(val); + auto& tc_value = translation_ctx->at(name); + auto val_type = tc_value.value.type(); + op_inputs.push_back(tc_value.value); op_outputs_type.push_back(val_type); - param_map_.PushValue(name, body_block->AddArgument(val_type)); + body_block_context->PushValue(name, body_block->AddArgument(val_type)); } pir::Operation* while_op = pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1); - dest_block->push_back(while_op); + dst_block->push_back(while_op); while_op->region(0).push_back(body_block); - TranslateBlock(sub_block, 0, index + 1, body_block); + TranslateBlock(sub_block, 0, index + 1, body_block_context, body_block); auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); std::vector yeild_inputs{ - param_map_.at(loop_vars_reverse[0].second).value}; + body_block_context->at(loop_vars_reverse[0].second).value}; for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { auto& name = loop_vars_reverse[idx].second; - yeild_inputs.push_back(param_map_.at(name).value); + yeild_inputs.push_back(body_block_context->at(name).value); } body_block->push_back( pir::Operation::Create(yeild_inputs, {}, {}, yeild_info)); - - for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { - auto& name = loop_vars_reverse[idx].first; - param_map_.PopValue(name); - } auto name_iter = loop_vars_reverse.rbegin(); for (size_t idx = 0; idx < while_op->num_results(); ++idx) { - param_map_.UpdateValue(name_iter++->first, while_op->result(idx)); + translation_ctx->PushValue(name_iter++->first, while_op->result(idx)); } + while_op->Verify(); VLOG(8) << "=============>end to translate while op:" << op; } -void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op, - pir::Block* dest_block) { +void ProgramTranslator::TranslateGeneralOperation( + const OpDesc* src_op, + TranslationContext* translation_ctx, + pir::Block* dst_block) { auto& op_translator = OpTranslator::instance(); OpTranslateFn& fn = op_translator[src_op->Type()]; if (src_op->Type() == "shadow_output") { - if (!param_map_.count(src_op->Input("x")[0])) { + if (!translation_ctx->count(src_op->Input("x")[0])) { return; } } - pir::Operation* operation = fn(ctx_, ¶m_map_, *src_op, dest_block); + pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dst_block); VLOG(10) << "[op translated][general]" << operation << "end"; } @@ -597,20 +657,6 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { } } -void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { - auto& op_translator = OpTranslator::instance(); - for (auto op : block.AllOps()) { - OpTranslateFn& fn = op_translator[op->Type()]; - if (op->Type() == "shadow_output") { - if (!param_map_.count(op->Input("x")[0])) { - continue; - } - } - pir::Operation* operation = fn(ctx_, ¶m_map_, *op, program_->block()); - VLOG(10) << "[op translated][special]" << operation; - } -} - void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { const auto& ops = block.AllOps(); for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) { @@ -684,10 +730,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue( } for (const auto& value_info : value_list) { pir::OpResult value = value_info.value.dyn_cast(); - if (!value) { - PADDLE_THROW(phi::errors::PreconditionNotMet( - "Value of [%s] can not ber None", var_name)); - } + if (!value) continue; auto* defining_op = value.owner(); PADDLE_ENFORCE_NOT_NULL( defining_op, @@ -725,10 +768,7 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue( } for (const auto& value_info : value_list) { pir::OpResult value = value_info.value.dyn_cast(); - if (!value) { - PADDLE_THROW(phi::errors::PreconditionNotMet( - "Value of [%s] can not ber None", var_name)); - } + if (!value) continue; auto* defining_op = value.owner(); PADDLE_ENFORCE_NOT_NULL( defining_op, @@ -753,15 +793,17 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue( } } -std::unordered_map> -ProgramTranslator::VarDesc2Value() { - std::unordered_map> var_desc_2_value; +std::unordered_map> +ProgramTranslator::VarDesc2OpResult() { + std::unordered_map> + var_desc_2_opresult; for (const auto& [var_name, value_info_list] : param_map_) { for (const auto& value_info : value_info_list) { - var_desc_2_value[var_name].push_back(value_info.value); + var_desc_2_opresult[var_name].push_back( + value_info.value.dyn_cast()); } } - return var_desc_2_value; + return var_desc_2_opresult; } } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 9d9e1b99552af..e4c6a517d7dc1 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -18,6 +18,7 @@ #include #include #include + #include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/program_desc.h" @@ -49,14 +50,25 @@ class ConditionBlockCombination { public: ConditionBlockCombination(const ::paddle::framework::BlockDesc& src_block, const std::vector& op_ids); + const std::string& CondVarName() const; - int TrueBlockId() const; - int FalseBlockId() const; - size_t OutputSize() const; - std::vector<::paddle::framework::VarDesc*> OutputVars() const; + + std::vector> OutputVars() const; + + size_t MainOutputSize() const; + std::vector TrueBlockOutputVarNames() const; + + std::vector<::paddle::framework::OpDesc*> TrueBlockInitOps() const; + + int TrueBlockId() const; + std::vector FalseBlockOutputVarNames() const; + std::vector<::paddle::framework::OpDesc*> FalseBlockInitOps() const; + + int FalseBlockId() const; + private: bool Verify(const std::vector<::paddle::framework::OpDesc*>& op_list); @@ -68,9 +80,10 @@ class TranslationContext { using Key = std::string; using Value = VariableDefiningInfo; using ValueList = std::vector; - using Conatiner = std::unordered_map; + using Container = std::unordered_map; TranslationContext() {} + explicit TranslationContext(TranslationContext* parent) : parent_(parent) {} ~TranslationContext() {} const Value& operator[](const Key& key) const; @@ -78,15 +91,18 @@ class TranslationContext { size_t count(const Key& key) const; // Caution: not exactly same as count in stl library - void UpdateValue(const Key& key, const Value& value); void PushValue(const Key& key, const Value& value); void PopValue(const Key& key); + TranslationContext* CreateInnerContext(); - Conatiner::const_iterator begin() const { return container_.begin(); } - Conatiner::const_iterator end() const { return container_.end(); } + Container::const_iterator begin() const { return container_.begin(); } + Container::const_iterator end() const { return container_.end(); } private: - Conatiner container_; + Container container_; + TranslationContext* parent_ = nullptr; + std::vector> + sons_; // used to seperate different block }; class ProgramTranslator { @@ -101,7 +117,8 @@ class ProgramTranslator { void Translate(); - std::unordered_map> VarDesc2Value(); + std::unordered_map> + VarDesc2OpResult(); private: const ProgramDesc* legacy_program_; // not owned @@ -122,23 +139,31 @@ class ProgramTranslator { static const std::unordered_set unsupported_ops; - void TranslateBlock(const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - pir::Block* dest_block, - bool for_cond_block = false, - std::vector skip_cond_assign = {}); - void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block); + void TranslateBlock( + const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dst_block, + bool for_cond_block = false, + const std::vector& cond_sub_block_outputs = {}, + const std::vector<::paddle::framework::OpDesc*>& cond_init_ops = {}); + void TranslateGeneralOperation(const OpDesc* src_op, + TranslationContext* translation_ctx, + pir::Block* dst_block); void GetParameterForSingleBlock(const BlockDesc& block); - void InsertOperationToSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block); void SetIsPersisableAttributeForAllValue(const BlockDesc& block); /// Translate methods for control flow ops. pir::Operation* TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, pir::Block* dest_block); - void TranslateWhileOperation(const OpDesc* op, pir::Block* dest_block); + const ConditionBlockCombination& cond_ops, + TranslationContext* translation_ctx, + pir::Block* dst_block); + void TranslateWhileOperation(const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dst_block); }; } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/translate.cc b/paddle/fluid/ir_adaptor/translator/translate.cc index 0f98e557743fc..7a7081fe1acbf 100644 --- a/paddle/fluid/ir_adaptor/translator/translate.cc +++ b/paddle/fluid/ir_adaptor/translator/translate.cc @@ -34,8 +34,9 @@ std::unique_ptr TranslateLegacyProgramToProgram( auto program = std::make_unique(ctx); translator::ProgramTranslator program_translator(&legacy_program, program.get()); + VLOG(6) << "begin to translate"; program_translator.Translate(); - + VLOG(6) << "translate done"; return program; } diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc index e8102e4e686a2..7f50115c5c578 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.cc +++ b/paddle/fluid/ir_adaptor/translator/utils.cc @@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget( op_info); block->push_back(operation); pir::OpResult target_op_result = operation->result(0); - param_map->UpdateValue(arg_name, VariableDefiningInfo(target_op_result)); + param_map->PushValue(arg_name, VariableDefiningInfo(target_op_result)); return operation; } diff --git a/paddle/fluid/memory/stats.cc b/paddle/fluid/memory/stats.cc index 0289859dff30e..e18646f0e82bf 100644 --- a/paddle/fluid/memory/stats.cc +++ b/paddle/fluid/memory/stats.cc @@ -15,8 +15,13 @@ limitations under the License. */ #include "paddle/fluid/memory/stats.h" #include "paddle/fluid/memory/allocation/spin_lock.h" +#include "paddle/fluid/platform/flags.h" #include "paddle/phi/core/macros.h" +PADDLE_DEFINE_EXPORTED_bool( + log_memory_stats, + false, + "Log memory stats after each op runs, just used for debug."); namespace paddle { namespace memory { @@ -104,6 +109,28 @@ void HostMemoryStatUpdate(const std::string& stat_type, StatRegistry::GetInstance()->Update("Host" + stat_type, dev_id, increment); } +void LogDeviceMemoryStats(const platform::Place& place, + const std::string& op_name) { + if (FLAGS_log_memory_stats && platform::is_gpu_place(place)) { + VLOG(0) << "After launching op_name: " << op_name << ", " + << "memory_allocated: " + << static_cast(memory::DeviceMemoryStatCurrentValue( + "Allocated", place.device)) / + 1024 / 1024 + << " MB, " + << "max_memory_allocated: " + << static_cast(memory::DeviceMemoryStatPeakValue( + "Allocated", place.device)) / + 1024 / 1024 + << " MB, " + << "max_memory_reserved: " + << static_cast(memory::DeviceMemoryStatPeakValue( + "Reserved", place.device)) / + 1024 / 1024 + << " MB"; + } +} + #define DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, id) \ StatRegistry::GetInstance()->Register( \ "Device" #item, id, Stat::GetInstance()); diff --git a/paddle/fluid/memory/stats.h b/paddle/fluid/memory/stats.h index bd4761f41116e..d2c8b04bc70ab 100644 --- a/paddle/fluid/memory/stats.h +++ b/paddle/fluid/memory/stats.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/place.h" #include "paddle/phi/common/thread_data_registry.h" #include "paddle/utils/string/string_helper.h" @@ -122,6 +123,9 @@ void HostMemoryStatUpdate(const std::string& stat_type, int dev_id, int64_t increment); +void LogDeviceMemoryStats(const platform::Place& place, + const std::string& op_name); + #define DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, id) \ case id: \ stat = paddle::memory::Stat< \ diff --git a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu index 1a3823767ad63..b0bebf5202eee 100644 --- a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu +++ b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu @@ -412,7 +412,7 @@ TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) { CheckResult(); } -#ifdef PADDLE_WITH_CUDA +#if (defined(PADDLE_WITH_CUDA) && (CUDA_VERSION >= 11000)) TEST_F(StreamSafeCUDAAllocTest, CUDAGraphTest) { MultiStreamRun(); CUDAGraphRun(); diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 1d45cee715409..270e0debbdb1b 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -35,8 +35,6 @@ namespace operators { void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNorm"); - OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "BatchNorm"); - OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchNorm"); OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "BatchNorm"); OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance", "BatchNorm"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm"); @@ -118,48 +116,54 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ? x_dims[1] : x_dims[x_dims.size() - 1]); - auto scale_dim = ctx->GetInputDim("Scale"); - auto bias_dim = ctx->GetInputDim("Bias"); + if (ctx->HasInput("Scale")) { + auto scale_dim = ctx->GetInputDim("Scale"); + PADDLE_ENFORCE_EQ( + scale_dim.size(), + 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 1." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale_dim, + scale_dim.size())); + } - PADDLE_ENFORCE_EQ( - scale_dim.size(), - 1UL, - platform::errors::InvalidArgument( - "ShapeError: the dimension of scale must equal to 1." - "But received: the shape of scale is [%s], the dimension " - "of scale is [%d]", - scale_dim, - scale_dim.size())); - PADDLE_ENFORCE_EQ(bias_dim.size(), - 1UL, - platform::errors::InvalidArgument( - "ShapeError: the dimension of bias must equal to 1." - "But received: the shape of bias is [%s],the dimension " - "of bias is [%d]", - bias_dim, - bias_dim.size())); + if (ctx->HasInput("Bias")) { + auto bias_dim = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dim.size(), + 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias_dim, + bias_dim.size())); + } bool check = true; - if ((!ctx->IsRuntime()) && - (phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) { + if (!ctx->HasInput("Scale") || !ctx->HasInput("Bias") || + ((!ctx->IsRuntime()) && (phi::product(ctx->GetInputDim("Scale")) <= 0 || + phi::product(ctx->GetInputDim("Bias")) <= 0))) { check = false; } if (check) { - PADDLE_ENFORCE_EQ(scale_dim[0], + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C, platform::errors::InvalidArgument( "ShapeError: the shape of scale must equal to [%d]" "But received: the shape of scale is [%d]", C, - scale_dim[0])); - PADDLE_ENFORCE_EQ(bias_dim[0], + ctx->GetInputDim("Scale")[0])); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C, platform::errors::InvalidArgument( "ShapeError: the shape of bias must equal to [%d]" "But received: the shape of bias is [%d]", C, - bias_dim[0])); + ctx->GetInputDim("Bias")[0])); } ctx->SetOutputDim("Y", x_dims); ctx->ShareLoD("X", "Y"); @@ -185,16 +189,20 @@ phi::KernelKey BatchNormOp::GetExpectedKernelType( if (input_data_type == framework::proto::VarType::FP64) { bn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ( - bn_param_type, - framework::TransToProtoVarType( - ctx.Input("Scale")->dtype()), - platform::errors::InvalidArgument("Scale input should be of float type")); - PADDLE_ENFORCE_EQ( - bn_param_type, - framework::TransToProtoVarType( - ctx.Input("Bias")->dtype()), - platform::errors::InvalidArgument("Bias input should be of float type")); + if (ctx.HasInput("Scale")) { + PADDLE_ENFORCE_EQ(bn_param_type, + framework::TransToProtoVarType( + ctx.Input("Scale")->dtype()), + platform::errors::InvalidArgument( + "Scale input should be of float type")); + } + if (ctx.HasInput("Bias")) { + PADDLE_ENFORCE_EQ(bn_param_type, + framework::TransToProtoVarType( + ctx.Input("Bias")->dtype()), + platform::errors::InvalidArgument( + "Bias input should be of float type")); + } PADDLE_ENFORCE_EQ( bn_param_type, framework::TransToProtoVarType( @@ -205,7 +213,6 @@ phi::KernelKey BatchNormOp::GetExpectedKernelType( ctx.Input("Variance")->dtype()), platform::errors::InvalidArgument( "Variance input should be of float type")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); } @@ -257,10 +264,12 @@ void BatchNormOpMaker::Make() { AddInput("X", "The input tensor"); AddInput("Scale", "Scale is a 1-dimensional tensor of size C " - "that is applied to the output"); + "that is applied to the output") + .AsDispensable(); AddInput("Bias", "Bias is a 1-dimensional tensor of size C " - "that is applied to the output"); + "that is applied to the output") + .AsDispensable(); AddInput("Mean", "The global mean (for training) or " "estimated mean (for testing)"); diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc index bdb774c62bdc4..11b51602d4d75 100644 --- a/paddle/fluid/operators/collective/alltoall_op.cu.cc +++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/alltoall_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/utils.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" @@ -27,6 +27,8 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { +using phi::distributed::GetPartialTensor; + template class AllToAllOpCUDAKernel : public framework::OpKernel { public: @@ -103,9 +105,9 @@ class AllToAllOpCUDAKernel : public framework::OpKernel { if (comm_ctx) { comm_ctx->GroupStart(); for (auto i = 0; i < nranks; ++i) { - auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel); + auto send_buf = GetPartialTensor(*x, offset, send_numel); comm_ctx->Send(send_buf, send_numel, i, stream); - auto recv_buf = distributed::GetPartialTensor(*out, offset, send_numel); + auto recv_buf = GetPartialTensor(*out, offset, send_numel); comm_ctx->Recv(&recv_buf, send_numel, i, stream); offset += send_numel; } diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc index 06be523a50b27..bd105c35886cb 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_allgather_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index e6815115865aa..8f2ff85fed5aa 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -34,7 +34,7 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/store/store_utils.h" #include "paddle/phi/core/distributed/store/tcp_store.h" diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc index 37616be1128f7..d13179cbae48b 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cu.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/c_concat_op.h" #include -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/fluid/operators/math/concat_and_split.h" diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index 50988fd381483..737784d96c0ee 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc index 9da5d6ad1d840..cd1cf0c017636 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_reducescatter_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc index ea5a0dda1fd97..7f4b4f6734de0 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_scatter_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc index e296a4d218f1f..d95c194452174 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cu.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/global_gather_op.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" @@ -279,7 +279,7 @@ struct GlobalGatherProcessGroupFunctor { out->mutable_data(out_dims, place); for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + distributed::ProcessGroupNCCL::GroupStart(); for (auto j = 0; j < nranks; ++j) { int idx = i + j * n_expert; if (cpu_global_count_data[idx]) { @@ -299,7 +299,7 @@ struct GlobalGatherProcessGroupFunctor { /*sync_op*/ true); } } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + distributed::ProcessGroupNCCL::GroupEnd(); } #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc index 45d91dc724108..d8cd6d4be5f54 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/global_scatter_op.h" #include "paddle/phi/core/distributed/comm_context_manager.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" @@ -286,7 +286,7 @@ struct GlobalScatterProcessGroupFunctor { out->mutable_data(out_dims, place); for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + distributed::ProcessGroupNCCL::GroupStart(); for (auto j = 0; j < nranks; ++j) { int idx = i + j * n_expert; if (cpu_local_count_data[idx]) { @@ -306,7 +306,7 @@ struct GlobalScatterProcessGroupFunctor { recv_ptr += cpu_global_count_data[idx]; } } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + distributed::ProcessGroupNCCL::GroupEnd(); } #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc index cf353c12ffa49..b0cdabce48503 100644 --- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc @@ -23,7 +23,6 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc index 2a6aea1c7a13a..c8844058696e1 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc @@ -23,7 +23,6 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc index 67089a18c8e4f..39858b3ed37a2 100644 --- a/paddle/fluid/operators/collective/partial_send_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc @@ -23,8 +23,6 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" -#include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { diff --git a/paddle/fluid/operators/detection/mask_util.cc b/paddle/fluid/operators/detection/mask_util.cc index f3e5b166b43b8..5b4dc92f4f6af 100644 --- a/paddle/fluid/operators/detection/mask_util.cc +++ b/paddle/fluid/operators/detection/mask_util.cc @@ -194,13 +194,24 @@ void Polys2MaskWrtBox(const std::vector>& polygons, w = std::max(w, static_cast(1.)); h = std::max(h, static_cast(1.)); - uint8_t* msk = nullptr; + // short-circuit for case "polygons.size() == 1" if (polygons.size() == 1UL) { - msk = mask; - } else { - msk = reinterpret_cast( - malloc(M * M * polygons.size() * sizeof(uint8_t))); // NOLINT + int k = static_cast(polygons[0].size() / 2); + std::vector p; + for (int j = 0; j < k; ++j) { + float pw = (polygons[0][2 * j] - box[0]) * M / w; // NOLINT + float ph = (polygons[0][2 * j + 1] - box[1]) * M / h; // NOLINT + p.push_back(pw); + p.push_back(ph); + } + Poly2Mask(p.data(), k, M, M, mask); + + return; } + + uint8_t* msk = reinterpret_cast( + malloc(M * M * polygons.size() * sizeof(uint8_t))); // NOLINT + for (size_t i = 0; i < polygons.size(); ++i) { int k = static_cast(polygons[i].size() / 2); std::vector p; @@ -214,19 +225,17 @@ void Polys2MaskWrtBox(const std::vector>& polygons, Poly2Mask(p.data(), k, M, M, msk_i); } - if (polygons.size() > 1UL) { - for (size_t i = 0; i < polygons.size(); ++i) { - uint8_t* msk_i = msk + i * M * M; - for (int j = 0; j < M * M; ++j) { - if (i == 0) { - mask[j] = msk_i[j]; - } else { - mask[j] = (mask[j] + msk_i[j]) > 0 ? 1 : 0; - } + for (size_t i = 0; i < polygons.size(); ++i) { + uint8_t* msk_i = msk + i * M * M; + for (int j = 0; j < M * M; ++j) { + if (i == 0) { + mask[j] = msk_i[j]; + } else { + mask[j] = (mask[j] + msk_i[j]) > 0 ? 1 : 0; } } - free(msk); // NOLINT } + free(msk); // NOLINT } } // namespace operators diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 42c41effb80ed..b8bb9a123fba3 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -13,7 +13,6 @@ register_operators( yolo_box_head_op yolo_box_post_op fusion_group_op - fusion_gru_op fusion_lstm_op fused_bn_add_activation_op fused_attention_op @@ -27,8 +26,6 @@ register_operators( fused_gate_attention_op resnet_basic_block_op) -# fusion_gru_op does not have CUDA kernel -op_library(fusion_gru_op) op_library(fusion_lstm_op) if(WITH_AVX AND AVX512F_FOUND diff --git a/paddle/fluid/operators/fused/fused_attention_utils.h b/paddle/fluid/operators/fused/fused_attention_utils.h index c059a194d0ea5..7d17041133bcd 100644 --- a/paddle/fluid/operators/fused/fused_attention_utils.h +++ b/paddle/fluid/operators/fused/fused_attention_utils.h @@ -23,7 +23,6 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/errors.h" diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index ba12bdc8b9d7f..40717402846db 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -24,7 +24,6 @@ limitations under the License. */ #include -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/fused/attention_layer_norm.h" diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc deleted file mode 100644 index 541233949b5d2..0000000000000 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ /dev/null @@ -1,565 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_gru_op.h" - -#include // for memcpy -#include -#include - -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" -#include "paddle/phi/kernels/funcs/jit/kernels.h" -#include "paddle/phi/kernels/funcs/sequence2batch.h" - -namespace paddle { -namespace operators { - -void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru"); - auto x_dims = ctx->GetInputDim("X"); - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) - ? phi::flatten_to_2d(x_dims, 1) - : x_dims; - PADDLE_ENFORCE_EQ( - x_mat_dims.size(), - 2, - platform::errors::InvalidArgument("The size of input X dims should be 2, " - "or 3 with second dimension equal to " - "1, but now Input X dim is:[%s] ", - x_dims)); - - auto wx_dims = ctx->GetInputDim("WeightX"); - PADDLE_ENFORCE_EQ(wx_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(WeightX) should be 2, but received " - "WeightX dim size is:%d, WeightX dim is:[%s] ", - wx_dims.size(), - wx_dims)); - PADDLE_ENFORCE_EQ( - wx_dims[0], - x_mat_dims[1], - platform::errors::InvalidArgument( - "The first dimension of flattened WeightX" - "should equal to last dimension of flattened input X, but " - "received fattened WeightX dimension is:%d, flattened X dimension " - "is:%d", - wx_dims[0], - x_mat_dims[1])); - - int frame_size = static_cast(wx_dims[1] / 3); - auto wh_dims = ctx->GetInputDim("WeightH"); - - PADDLE_ENFORCE_EQ(wh_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(WeightH) should be 2, but received " - "WeightH dim size is:%d, WeightH dim is:[%s]", - wh_dims.size(), - wh_dims)); - PADDLE_ENFORCE_EQ(wh_dims[0], - frame_size, - platform::errors::InvalidArgument( - "The first dimension of WeightH " - "should equal to frame_size, but received WeightH's " - "first dimension is: " - "%d, frame size is:%d", - wh_dims[0], - frame_size)); - PADDLE_ENFORCE_EQ(wh_dims[1], - 3 * frame_size, - platform::errors::InvalidArgument( - "The second dimension of Input(WeightH) " - "should equal to 3 * frame_size, but received WeightH " - "is:%d, frame size is:%d", - wh_dims[1], - frame_size)); - - if (ctx->HasInput("H0")) { - auto h0_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE_EQ(h0_dims[1], - frame_size, - platform::errors::InvalidArgument( - "The width of H0 must be equal to frame_size, but " - "receiced the width of H0 is:%d, frame size is:%d", - h0_dims[1], - frame_size)); - } - if (ctx->HasInput("Bias")) { - auto b_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(b_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(Bias) should be 2, but received " - "Bias rank is:%d, Bias dim is:[%s]", - b_dims.size(), - b_dims)); - PADDLE_ENFORCE_EQ(b_dims[0], - 1, - platform::errors::InvalidArgument( - "The first dimension of Input(Bias) should be 1, but " - "received Bias first dim is:%d, Bias dim is:[%s]", - b_dims[0], - b_dims)); - PADDLE_ENFORCE_EQ(b_dims[1], - frame_size * 3, - platform::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but " - "received bias dim is:[%s], frame size is:%d", - b_dims, - frame_size)); - } - framework::DDim out_dims({x_mat_dims[0], frame_size}); - ctx->SetOutputDim("Hidden", out_dims); - ctx->ShareLoD("X", "Hidden"); - int xx_width = 0; - if (ctx->Attrs().Get("use_seq")) { - xx_width = static_cast(wx_dims[1]); - } else { - xx_width = static_cast(x_mat_dims[1] > wx_dims[1] ? wx_dims[1] - : x_mat_dims[1]); - OP_INOUT_CHECK( - ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", "fusion_gru"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchedInput"), "Output", "BatchedInput", "fusion_gru"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchedOut"), "Output", "BatchedOut", "fusion_gru"); - ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchedOut", out_dims); - } - ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width}); - ctx->ShareLoD("X", "XX"); -} - -phi::KernelKey FusionGRUOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(data_type, ctx.GetPlace()); -} - -void FusionGRUOpMaker::Make() { - AddInput( - "X", - "(phi::DenseTensor) the input is a LodTensor, which support " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X M), where T is the " - "total time steps in this mini-batch, M is the dim size of x."); - AddInput( - "H0", - "(phi::DenseTensor, optional) The initial hidden state is an optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size, D is the hidden size.") - .AsDispensable(); - AddInput("WeightX", - "(phi::DenseTensor) The FC weight with shape (M x 3D)," - "where M is the dim size of x, D is the hidden size. "); - AddInput( - "WeightH", - "(phi::DenseTensor) (D x 3D) Same as GRUOp, where D is the hidden size. " - "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}" - "Acutally they are D x 2D and D x D two part weights." - "{W_update, W_reset; W_state}" - "{D x (D + D); D x D}"); - AddInput("Bias", - "(phi::DenseTensor, optional) (1 x 3D)." - "Almost same as GRUOp." - "Note: if have FC bias it should be added on this bias.") - .AsDispensable(); - AddOutput("ReorderedH0", - "(phi::DenseTensor) (N x D), which N is the min-batch size.") - .AsIntermediate(); - AddOutput("XX", - "(phi::DenseTensor) the result after X * WeightX (size is T x 3D)" - " or batched_X (size is T x M), this will be automatically chosen," - " where T is the total time steps in this mini-batch," - " D is the hidden size, M is the dim size of x input.") - .AsIntermediate(); - AddOutput("BatchedInput", - "(phi::DenseTensor) This is the batched result of input X" - "or the batched result after fc, shape (T x 3D)") - .AsIntermediate(); - AddOutput("BatchedOut", "(phi::DenseTensor) (T X D) save batched hidden.") - .AsIntermediate(); - AddOutput("Hidden", "(phi::DenseTensor) (T x D) Same as GRUOp"); - AddAttr("activation", - "(string, default tanh) " - "The activation type used for output candidate {h}_t.") - .SetDefault("tanh"); - AddAttr( - "gate_activation", - "(string, default sigmoid) " - "The activation type used in update gate and reset gate.") - .SetDefault("sigmoid"); - AddAttr("is_reverse", - "(bool, default: False) " - "whether to compute reversed GRU.") - .SetDefault(false); - AddAttr("use_seq", - "(bool, default: True) " - "whether to use seq mode to compute GRU.") - .SetDefault(true); - AddAttr("origin_mode", - "bool" - "use origin mode in article https://arxiv.org/abs/1412.3555") - .SetDefault(false); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddAttr( - "mkldnn_data_type", - "(string, default \"float32\"). Data type of mkldnn kernel") - .SetDefault("float32") - .InEnum({"float32", "int8", "bfloat16"}); - AddAttr("Scale_data", - "Scale to be used for int8 input/output data." - "Only used with MKL-DNN INT8.") - .SetDefault(1.0f); - AddAttr("Shift_data", - "Shift to be used for int8 input/output data." - "Only used with MKL-DNN INT8.") - .SetDefault(0.0f); - AddAttr>("Scale_weights", - "Scale_weights to be used for int8 weights data." - "Only used with MKL-DNN INT8.") - .SetDefault({1.0f}); - AddAttr("force_fp32_output", - "(bool, default false) Force INT8 kernel output FP32, only " - "used in MKL-DNN INT8") - .SetDefault(false); - AddComment(R"DOC( -The Fusion complete GRU Operator. -This operator fuse the fully-connected operator into GRU, -more details can refer to GRU op. -)DOC"); -} - -template -class FusionGRUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - if (ctx.Attr("use_seq")) { - SeqCompute(ctx); - } else { - BatchCompute(ctx); - } - } - -#define INIT_BASE_DEFINES \ - auto* x = ctx.Input("X"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* xx = ctx.Output("XX"); \ - auto x_lod = x->lod(); \ - auto x_dims = x->dims(); /* T x M*/ \ - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \ - ? phi::flatten_to_2d(x_dims, 1) \ - : x_dims; \ - auto wh_dims = wh->dims(); /* D x 3D*/ \ - const int total_T = x_mat_dims[0]; \ - const int D3 = wh_dims[1] - -#define INIT_OTHER_DEFINES \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* bias = ctx.Input("Bias"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - const int M = x_mat_dims[1]; \ - const int D = wh_dims[0]; \ - const int D2 = D * 2; \ - const phi::jit::gru_attr_t attr( \ - D, \ - phi::jit::to_kerneltype(ctx.Attr("gate_activation")), \ - phi::jit::to_kerneltype(ctx.Attr("activation"))); \ - phi::jit::gru_t one_step; \ - auto ComputeH1 = phi::jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeHtPart1 = phi::jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeHtPart2 = phi::jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - auto place = ctx.GetPlace(); \ - T* xx_data = xx->mutable_data(place) - - void SeqCompute(const framework::ExecutionContext& ctx) const { - INIT_BASE_DEFINES; - INIT_OTHER_DEFINES; - const int N = static_cast(x_lod[0].size() - 1); - const T* h0_data = h0 ? h0->data() : nullptr; - const T* wh_state_data = wh_data + D * D2; - T* hidden_out_data = hidden_out->mutable_data(place); - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - phi::funcs::FCFunctor fc; - fc(dev_ctx, - total_T, - D3, - M, - x_data, - wx_data, - xx_data, - bias ? bias->data() : nullptr); - - int xx_offset = D3; - int gate_offset = D; - if (is_reverse) { - const int offset = (total_T - 1) * D; - xx_data = xx_data + offset * 3; - hidden_out_data = hidden_out_data + offset; - xx_offset = -D3; - gate_offset = -D; - } - auto move_step = [&]() { - xx_data = xx_data + xx_offset; - hidden_out_data = hidden_out_data + gate_offset; - }; - for (int i = 0; i < N; ++i) { - int bid = is_reverse ? N - 1 - i : i; - int seq_len = static_cast(x_lod[0][bid + 1] - x_lod[0][bid]); - const T* prev_hidden_data = nullptr; - int tstart = 0; - if (h0_data) { - prev_hidden_data = h0_data + bid * D; - } else { - one_step.gates = xx_data; - one_step.ht = hidden_out_data; - ComputeH1(&one_step, &attr); - prev_hidden_data = hidden_out_data; - tstart = 1; - move_step(); - } - for (int step = tstart; step < seq_len; ++step) { - // gemm prev * (Wu + Wr) - blas.GEMM(CblasNoTrans, - CblasNoTrans, - 1, - D2, - D, - static_cast(1), - prev_hidden_data, - D, - wh_data, - D2, - static_cast(1), - xx_data, - D3); - one_step.gates = xx_data; - one_step.ht_1 = prev_hidden_data; - one_step.ht = hidden_out_data; - ComputeHtPart1(&one_step, &attr); - // gemm rt * Ws - blas.GEMM(CblasNoTrans, - CblasNoTrans, - 1, - D, - D, - static_cast(1), - hidden_out_data, - D, - wh_state_data, - D, - static_cast(1), - xx_data + D2, - D3); - ComputeHtPart2(&one_step, &attr); - // save prev - prev_hidden_data = hidden_out_data; - move_step(); - } - } - } - - void BatchCompute(const framework::ExecutionContext& ctx) const { - INIT_BASE_DEFINES; - if (x_lod[0].size() == 2) { - xx->Resize({total_T, D3}); - SeqCompute(ctx); - return; - } - INIT_OTHER_DEFINES; - auto* reordered_h0 = ctx.Output("ReorderedH0"); - auto* batched_input = ctx.Output("BatchedInput"); - auto* batched_out = ctx.Output("BatchedOut"); - T* batched_input_data = batched_input->mutable_data(place); - T* batched_out_data = batched_out->mutable_data(place); - hidden_out->mutable_data(place); - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - phi::funcs::LoDTensor2BatchFunctor to_batch; - - phi::funcs::FCFunctor fc; - if (M > D3) { - fc(dev_ctx, - total_T, - D3, - M, - x_data, - wx_data, - xx_data, - bias ? bias->data() : nullptr); - to_batch(dev_ctx, *xx, batched_input, true, is_reverse); - } else { - to_batch(dev_ctx, *x, xx, true, is_reverse); - batched_input->set_lod(xx->lod()); - fc(dev_ctx, - total_T, - D3, - M, - xx_data, - wx_data, - batched_input_data, - bias ? bias->data() : nullptr); - } - - auto batched_lod = batched_input->lod(); - const auto& seq_order = batched_lod[2]; - const int max_bs = static_cast(seq_order.size()); - reordered_h0->Resize({max_bs, D}); - - int tstart = 0; - T* prev_hidden_data = nullptr; - if (h0) { - // reorder h0 - T* reordered_h0_data = reordered_h0->mutable_data(place); - const T* h0_data = h0->data(); - prev_hidden_data = reordered_h0_data; - size_t sz = sizeof(T) * D; - for (int i = 0; i < max_bs; ++i) { - std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); - reordered_h0_data += D; - } - } else { - // compute without h0 - T* cur_in_data = batched_input_data; - T* cur_out_data = batched_out_data; - // W: {W_update, W_reset; W_state} - for (int i = 0; i < max_bs; ++i) { - one_step.gates = cur_in_data; - one_step.ht = cur_out_data; - ComputeH1(&one_step, &attr); - // add offset - cur_in_data += D3; - cur_out_data += D; - } - tstart = 1; - prev_hidden_data = batched_out_data; - } - // Then start from next - const T* wh_state_data = wh_data + D * D2; - const auto& batch_starts = batched_lod[0]; - const int max_seq_len = static_cast(batch_starts.size() - 1); - batched_input_data = batched_input_data + tstart * max_bs * D3; - batched_out_data = batched_out_data + tstart * max_bs * D; - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = - static_cast(batch_starts[step + 1] - batch_starts[step]); - // gemm prev * (Wu + Wr) - blas.GEMM(CblasNoTrans, - CblasNoTrans, - cur_bs, - D2, - D, - static_cast(1), - prev_hidden_data, - D, - wh_data, - D2, - static_cast(1), - batched_input_data, - D3); - - T* cur_batched_data = batched_input_data; - T* cur_out_data = batched_out_data; - T* cur_prev_hidden_data = prev_hidden_data; - for (int i = 0; i < cur_bs; ++i) { - one_step.gates = cur_batched_data; - one_step.ht_1 = cur_prev_hidden_data; - one_step.ht = cur_out_data; - ComputeHtPart1(&one_step, &attr); - - cur_batched_data += D3; - cur_prev_hidden_data += D; - cur_out_data += D; - } - - cur_batched_data = batched_input_data; - cur_out_data = batched_out_data; - blas.GEMM(CblasNoTrans, - CblasNoTrans, - cur_bs, - D, - D, - static_cast(1), - cur_out_data, - D, - wh_state_data, - D, - static_cast(1), - cur_batched_data + D2, - D3); - - cur_prev_hidden_data = prev_hidden_data; - for (int i = 0; i < cur_bs; ++i) { - one_step.gates = cur_batched_data; - one_step.ht_1 = cur_prev_hidden_data; - one_step.ht = cur_out_data; - ComputeHtPart2(&one_step, &attr); - cur_batched_data += D3; - cur_prev_hidden_data += D; - cur_out_data += D; - } - prev_hidden_data = batched_out_data; - batched_out_data = cur_out_data; - batched_input_data = cur_batched_data; - } - - phi::funcs::Batch2LoDTensorFunctor to_seq; - batched_out->set_lod(batched_lod); - to_seq(dev_ctx, *batched_out, hidden_out); - } -#undef INIT_OTHER_DEFINES -#undef INIT_BASE_DEFINES -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - fusion_gru, CPU, ALL_LAYOUT, ops::FusionGRUKernel, float, double) {} - -/* ========================== register checkpoint ===========================*/ -REGISTER_OP_VERSION(fusion_gru) - .AddCheckpoint( - R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "Scale_weights", - "The added attribute 'Scale_weights' is not yet " - "registered.", - std::vector{1.0f})); diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc deleted file mode 100644 index de70a5b6b5cf5..0000000000000 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ /dev/null @@ -1,290 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h" - -#include // for min, max -#include - -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" - -namespace paddle { -namespace operators { - -void FusionSeqConvEltAddReluOp::InferShape( - framework::InferShapeContext* ctx) const { - OP_INOUT_CHECK( - ctx->HasInput("X"), "Input", "X", "fusion_seqconv_eltadd_relu"); - OP_INOUT_CHECK( - ctx->HasInput("Filter"), "Input", "Filter", "fusion_seqconv_eltadd_relu"); - OP_INOUT_CHECK( - ctx->HasInput("Bias"), "Input", "Bias", "fusion_seqconv_eltadd_relu"); - - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "fusion_seqconv_eltadd_relu"); - OP_INOUT_CHECK(ctx->HasOutput("ColMat"), - "Output", - "ColMat", - "fusion_seqconv_eltadd_relu"); - - auto x_dims = ctx->GetInputDim("X"); - auto w_dims = ctx->GetInputDim("Filter"); - int context_length = ctx->Attrs().Get("contextLength"); - PADDLE_ENFORCE_EQ(ctx->Attrs().Get("contextStride"), - 1, - platform::errors::InvalidArgument( - "Currently, FusionSeqConvEltAddReluOp only supports " - "contextStride=1, but received value is: %d.", - ctx->Attrs().Get("contextStride"))); - - PADDLE_ENFORCE_EQ( - x_dims.size(), - 2, - platform::errors::InvalidArgument( - "Input(X) should be 2-D tensor, but reveiced value is: %d.", - x_dims.size())); - - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2, - platform::errors::InvalidArgument( - "Filter should be 2-D tensor, but reveiced value is: %d.", - w_dims.size())); - - PADDLE_ENFORCE_EQ(w_dims[0], - context_length * x_dims[1], - platform::errors::InvalidArgument( - "Filter's height should be equal to context_length * " - "input_hidden_size, but received Filter height is: %d," - "context_length is: %d, input_hidden_size is: %d.", - w_dims[0], - context_length, - x_dims[1])); - - PADDLE_ENFORCE_GT( - context_length + ctx->Attrs().Get("contextStart"), - 0, - platform::errors::InvalidArgument( - "contextStart size should be smaller than contextLength, " - "but received context_length is: %d, contextStart is: " - "%d.", - context_length, - ctx->Attrs().Get("contextStart"))); - - ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]}); - ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]}); - ctx->ShareLoD("X", "Out"); -} - -phi::KernelKey FusionSeqConvEltAddReluOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); -} - -void FusionSeqConvEltAddReluOpMaker::Make() { - AddInput( - "X", - "(phi::DenseTensor) the input is a LodTensor, which support " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X M), where T is the " - "total time steps in this mini-batch, M is the dim size of x."); - // PaddingData only support false yet, should be ensured at pass. - AddInput( - "Filter", - "(phi::DenseTensor) same as the input(Filter) of sequence conv op is an " - "learnable parameter." - "This is a tensor with shape (K, N), where K is the " - "context_length * dim size of x, N is the output feature size."); - AddInput( - "Bias", - "(phi::DenseTensor) the learnable weights. shape (1, N), where N is the " - "output feature size"); - AddOutput( - "Out", - "(phi::DenseTensor) the output(Out) is a LodTensor, which support " - "variable-time length output sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T, N), where, T is the " - "total time steps in this mini-batch, N is the output feature size."); - AddOutput("ColMat", - "(phi::DenseTensor) (T, K), where T is where T is the " - "total time steps in this mini-batch, K is height of Filter") - .AsIntermediate(); - AddAttr("contextLength", - "(int) the contextLength of FusionSeqConvEltAddReluOp is the " - "height of the convolution kernel.") - .GreaterThan(0); - AddAttr("contextStart", - "(int, default:0) the contextStart of FusionSeqConvEltAddReluOp " - "represents the beginning of the convolution of the number of " - "rows of sequence, which can be negative. The negative number " - "means to pad contextStart time-steps of zeros or learnable " - "parameters at the beginning of each instance. The positive " - "number means to skip contextStart time-steps of each " - "instance.") - .SetDefault(0); - AddAttr( - "contextStride", - "(int, default:1) the contextStride of FusionSeqConvEltAddReluOp " - "represents the stride length of convolution kernel. " - "Currently, FusionSeqConvEltAddReluOp only supports" - "contextStride=1.") - .SetDefault(1) - .GreaterThan(0); - AddComment(R"DOC( -Fusion Sequence Conv and ElementwiseAdd Operator. -)DOC"); -} - -template -class FusionSeqConvEltAddReluKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* w = ctx.Input("Filter"); - auto* b = ctx.Input("Bias"); - auto* y = ctx.Output("Out"); - auto* col = ctx.Output("ColMat"); - - auto x_lod = x->lod(); - auto x_dims = phi::vectorize(x->dims()); - auto w_dims = phi::vectorize(w->dims()); - PADDLE_ENFORCE_EQ( - b->numel(), - w_dims[1], - platform::errors::InvalidArgument( - "bias size should be equal to weights feature size, but received " - "bias size is: %d, weights feature size is: %d.", - b->numel(), - w_dims[1])); - PADDLE_ENFORCE_EQ( - x_lod.size(), - 1UL, - platform::errors::InvalidArgument( - "Only support one level sequence now, but received value is: %d.", - x_lod.size())); - - const T* x_data = x->data(); - const T* w_data = w->data(); - const T* b_data = b->data(); - T* y_data = y->mutable_data(ctx.GetPlace()); - T* col_data = col->mutable_data(ctx.GetPlace()); - - int context_start = ctx.Attr("contextStart"); - int context_length = ctx.Attr("contextLength"); - int up_pad = std::max(0, -context_start); - int down_pad = std::max(0, context_start + context_length - 1); - // im2col - int src_mat_w = static_cast(x_dims[1]); - int src_mat_w_sz = src_mat_w * sizeof(T); - int col_mat_w = static_cast(w_dims[0]); - int col_mat_w_sz = col_mat_w * sizeof(T); - for (int i = 0; i < static_cast(x_lod[0].size()) - 1; ++i) { - int st = static_cast(x_lod[0][i]); - int ed = static_cast(x_lod[0][i + 1]); - const T* src_data = x_data + st * src_mat_w; - T* dst_data = col_data + st * col_mat_w; - int seq_len = ed - st; - if (seq_len > up_pad + down_pad) { - // zero all up_pad and fill data - std::memset(dst_data, 0, up_pad * col_mat_w_sz); - dst_data = dst_data + up_pad * src_mat_w; - int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz; - for (int j = 0; j < up_pad; ++j) { - // blas.VCOPY? - std::memcpy(dst_data, src_data, copy_size); - dst_data += (col_mat_w - src_mat_w); - copy_size += src_mat_w_sz; - } - // fill data - if (context_start > 0) { - src_data += context_start * src_mat_w; - } - for (int j = 0; j < seq_len - up_pad - down_pad; ++j) { - std::memcpy(dst_data, src_data, copy_size); - dst_data += col_mat_w; - src_data += src_mat_w; - } - // zero all down_pad and fill data - std::memset(dst_data, 0, down_pad * col_mat_w_sz); - copy_size -= src_mat_w_sz; - for (int j = 0; j < down_pad; ++j) { - if (copy_size < 0) { - copy_size = 0; - } - std::memcpy(dst_data, src_data, copy_size); - dst_data += col_mat_w; - src_data += src_mat_w; - copy_size -= src_mat_w_sz; - } - } else { - std::memset(dst_data, 0, seq_len * col_mat_w_sz); - dst_data = dst_data + up_pad * src_mat_w; - int zero_sz = up_pad * src_mat_w_sz; - int cur_src_sz = seq_len * src_mat_w_sz; - for (int j = 0; j < std::min(up_pad, seq_len); ++j) { - int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); - std::memcpy(dst_data, src_data, copy_size); - dst_data += (col_mat_w - src_mat_w); - zero_sz -= src_mat_w_sz; - } - // from bottom - dst_data = col_data + ed * col_mat_w; - src_data = x_data + st * src_mat_w; - if (context_start > 0) { - src_data += context_start * src_mat_w; - } - zero_sz = down_pad * src_mat_w_sz; - for (int j = 1; j <= std::min(down_pad, seq_len); ++j) { - int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); - if (copy_size < 0) { - copy_size = 0; - } - std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T), - src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w, - copy_size); - dst_data -= col_mat_w; - zero_sz -= src_mat_w_sz; - } - } - } - auto& dev_ctx = ctx.template device_context(); - phi::funcs::FCFunctor fc; - fc(dev_ctx, - x_dims[0], - w_dims[1], - w_dims[0], - col_data, - w_data, - y_data, - b_data, - true); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_seqconv_eltadd_relu, - ops::FusionSeqConvEltAddReluOp, - ops::FusionSeqConvEltAddReluOpMaker); - -PD_REGISTER_STRUCT_KERNEL(fusion_seqconv_eltadd_relu, - CPU, - ALL_LAYOUT, - ops::FusionSeqConvEltAddReluKernel, - float, - double) {} diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h deleted file mode 100644 index 42e0c57b1133a..0000000000000 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionSeqConvEltAddReluOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc deleted file mode 100644 index 03b5971b1482a..0000000000000 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h" - -#include - -#include "paddle/phi/backends/cpu/cpu_info.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/cpu_vec.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" - -namespace paddle { -namespace operators { - -void FusionSeqExpandConcatFCOp::InferShape( - framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), - 1UL, - platform::errors::InvalidArgument( - "Inputs(X) of FusionSeqExpandConcatFCOp should larger " - "than 1, but received value is: %d.", - ctx->Inputs("X").size())); - OP_INOUT_CHECK(ctx->HasInput("FCWeight"), - "Input", - "FCWeight", - "fusion_seqexpand_concat_fc"); - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "fusion_seqexpand_concat_fc"); - OP_INOUT_CHECK( - ctx->HasOutput("FCOut"), "Output", "FCOut", "fusion_seqexpand_concat_fc"); - - auto ins_dims = ctx->GetInputsDim("X"); - auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2, - platform::errors::InvalidArgument( - "Input(FCWeight)'s rank must be 2, but received value is: %d.", - w_dims.size())); - const int D = static_cast(w_dims[1]); - int sum = static_cast(ins_dims[0][1]); - for (size_t i = 1; i < ins_dims.size(); ++i) { - sum += static_cast(ins_dims[i][1]); - } - PADDLE_ENFORCE_EQ( - sum, - w_dims[0], - platform::errors::InvalidArgument("FC height should be sum of all inputs " - "width, but received FC height is: %d, " - "sum of all inputs width is: %d.", - w_dims[0], - sum)); - if (ctx->HasInput("FCBias")) { - auto b_dims = ctx->GetInputDim("FCBias"); - PADDLE_ENFORCE_EQ( - b_dims.size() == 1 || b_dims.size() == 2, - true, - platform::errors::InvalidArgument( - "FCBias dim should be 1 or 2, but received value is: %d.", - b_dims.size())); - if (b_dims.size() == 1) { - PADDLE_ENFORCE_EQ(b_dims[0], - D, - platform::errors::InvalidArgument( - "FCBias shapes must be %d when FCBias dim = 1, but " - "received value is: %d.", - D, - b_dims[0])); - } else { - PADDLE_ENFORCE_EQ(b_dims[0], - 1, - platform::errors::InvalidArgument( - "FCBias shapes must be 1x%d, when FCBias dim = 2, " - "but received dim[0] is: %d.", - D, - b_dims[0])); - PADDLE_ENFORCE_EQ(b_dims[1], - D, - platform::errors::InvalidArgument( - "FCBias shapes must be 1x%d, when FCBias dim = 2, " - "but received dim[1] is: %d.", - D, - b_dims[1])); - } - } - - ctx->SetOutputDim("Out", {ins_dims[0][0], D}); - // fcout should be reshape when run since can not get lod in infershape - // explicit share the ref lod - ctx->ShareLoD("X", "Out", 0); -} - -phi::KernelKey FusionSeqExpandConcatFCOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); -} - -void FusionSeqExpandConcatFCOpMaker::Make() { - AddInput("X", - "(phi::DenseTensor) input LodDTensors, the first one must be have " - "ref lod " - "for sequence expand, and the rest input should have same lod.") - .AsDuplicable(); - AddInput("FCWeight", "(phi::DenseTensor) the weights of fc."); - AddInput("FCBias", "(phi::DenseTensor, optional) the bias of fc.") - .AsDispensable(); - AddOutput("Out", "(phi::DenseTensor) Output LodTensor."); - AddOutput( - "FCOut", - "(phi::DenseTensor) the intermediate tensor to keep the result of fc." - "Shape is (N x D), where N is the batch size, D is the output dim of fc") - .AsIntermediate(); - AddAttr("fc_activation", - "(string, default: identity)" - "The activation for the result of fc." - "`identity` by default.") - .SetDefault("identity") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddComment(R"DOC( -Fusion Sequence expand + concat + fc Operator. - -All below conditions should be meet: - -The ref_level of seq_expand should be 0. - -The ref lod of seq_expand level is the first input of concat. - -The other inputs should have same lod and same batch size of ref lod. - -The seq len of other inputs should be 1. - -The concat axis should be 1. - -)DOC"); -} - -template -class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - auto* w = ctx.Input("FCWeight"); - auto* b = ctx.Input("FCBias"); - auto* out = ctx.Output("Out"); - auto* fc_out = ctx.Output("FCOut"); - - auto* ref_in = ins[0]; - auto ref_lod = ref_in->lod(); - auto in1_lod = ins[1]->lod(); - auto ref_dims = ref_in->dims(); // T x M0 - auto in1_dims = ins[1]->dims(); // N x M1 - auto w_dims = w->dims(); - const int N = static_cast(ref_lod[0].size() - 1); - const int total_T = static_cast(ref_dims[0]); - const int M0 = static_cast(ref_dims[1]); - const int M1 = static_cast(in1_dims[1]); - const int D = static_cast(w_dims[1]); - - // some check and fcout should be reshape here - // since infershape can not get lod info - PADDLE_ENFORCE_EQ( - ref_lod.size(), - 1UL, - platform::errors::InvalidArgument( - "Only support input lod size is 1, but received value is: %d.", - ref_lod.size())); - PADDLE_ENFORCE_EQ( - in1_lod.size(), - 1UL, - platform::errors::InvalidArgument( - "Only support input lod size is 1, but received value is: %d.", - in1_lod.size())); - PADDLE_ENFORCE_EQ(static_cast(in1_lod[0].size() - 1), - N, - platform::errors::InvalidArgument( - "Batch size of all inputs should be equal to %d, but " - "received value is: %d.", - N, - static_cast(in1_lod[0].size() - 1))); - PADDLE_ENFORCE_EQ( - static_cast(in1_lod[0][N]), - N, - platform::errors::InvalidArgument("Seq_length of other inputs should " - "be %d, but received value is: %d.", - N, - static_cast(in1_lod[0][N]))); - PADDLE_ENFORCE_EQ( - in1_dims[0], - N, - platform::errors::InvalidArgument( - "input height should be batch size: %d, but received value is %d.", - N, - in1_dims[0])); - for (size_t i = 2; i < ins.size(); ++i) { - PADDLE_ENFORCE_EQ(ins[i]->dims()[0], - N, - platform::errors::InvalidArgument( - "All other inputs height should be equal to %d, " - "but received value is: %d.", - N, - ins[i]->dims()[0])); - PADDLE_ENFORCE_EQ(ins[i]->lod(), - in1_lod, - platform::errors::InvalidArgument( - "All other inputs should have same lod: %d, but " - "received value is: %d.", - in1_lod, - ins[i]->lod())); - } - fc_out->Resize({N, D}); - - std::function fc_act; - auto& fc_act_str = ctx.Attr("fc_activation"); - if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { - phi::funcs::VecActivations act_functor; - fc_act = act_functor(fc_act_str); - } else { - phi::funcs::VecActivations act_functor; - fc_act = act_functor(fc_act_str); - } - - const T* ref_in_data = ref_in->data(); - const T* in1_data = ins[1]->data(); - const T* w_data = w->data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - phi::funcs::FCFunctor fc; - fc(dev_ctx, - total_T, - D, - M0, - ref_in_data, - w_data, - out_data, - b ? b->data() : NULL); - w_data = w_data + M0 * D; - // first write on - blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); - w_data = w_data + M1 * D; - for (size_t i = 2; i < ins.size(); ++i) { - // add on - const T* in_data = ins[i]->data(); - const int K = static_cast(ins[i]->dims()[1]); - blas.GEMM(CblasNoTrans, - CblasNoTrans, - N, - D, - K, - static_cast(1), - in_data, - K, - w_data, - D, - static_cast(1), - fc_out_data, - D); - w_data = w_data + K * D; - } - T* cur_out_data = out_data; - for (int i = 0; i < N; ++i) { - int seq_len = static_cast(ref_lod[0][i + 1] - ref_lod[0][i]); - T* src = fc_out_data + i * D; - for (int step = 0; step < seq_len; ++step) { - blas.VADD(D, cur_out_data, src, cur_out_data); - cur_out_data = cur_out_data + D; - } - } - fc_act(total_T * D, out_data, out_data); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_seqexpand_concat_fc, - ops::FusionSeqExpandConcatFCOp, - ops::FusionSeqExpandConcatFCOpMaker); - -PD_REGISTER_STRUCT_KERNEL(fusion_seqexpand_concat_fc, - CPU, - ALL_LAYOUT, - ops::FusionSeqExpandConcatFCOpKernel, - float, - double) {} diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h deleted file mode 100644 index 7438b6c717487..0000000000000 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionSeqExpandConcatFCOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc deleted file mode 100644 index 5ec5e8081bb6f..0000000000000 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ /dev/null @@ -1,387 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/operators/fused/fusion_gru_op.h" -#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h" -#include "paddle/phi/backends/onednn/onednn_reuse.h" -#include "paddle/phi/core/expect.h" - -namespace paddle { -namespace operators { - -using phi::OneDNNContext; -using phi::funcs::OneDNNGetDataType; -using phi::funcs::OneDNNMemDesc; -using phi::funcs::RNNReorderType; -using OneDNNMemoryFormat = dnnl::memory::format_tag; - -template -class GRUMKLDNNHandler : public RNNMKLDNNHandler { - public: - GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, - const OneDNNContext& dev_ctx, - const dnnl::engine onednn_engine, - platform::Place cpu_place UNUSED, - const phi::DenseTensor* input, - const phi::DenseTensor* weight_h, - const phi::DenseTensor* h0, - const bool is_reverse, - const int64_t N, - const int64_t Ti, - const int64_t IC, - const int64_t OC, - const std::string& unique_name UNUSED) - : RNNMKLDNNHandler( - ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - input, - weight_h, - h0, - is_reverse, - N, - Ti, - IC, - OC, - 3, - ctx.InputName("X") + ctx.InputName("WeightH")) { - const bool is_INT8 = std::is_same::value; - - if (unlikely(!this->isCached())) { - // oneDNN kernel has hardcoded activation functions - PADDLE_ENFORCE_EQ( - ctx.Attr("gate_activation"), - "sigmoid", - platform::errors::Unimplemented( - "oneDNN fusion_gru supports only sigmoid as a gate activation.")); - PADDLE_ENFORCE_EQ( - ctx.Attr("activation"), - "tanh", - platform::errors::Unimplemented( - "oneDNN fusion_gru supports only tanh as an activation.")); - - // Weights for int8 kernel are of a type s8 - const auto weights_dt = - is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType(); - - // oneDNN RNN dimensions - const int64_t D = 1; // Directions - const int64_t L = 1; // Layers (PP supports only 1 stacked layer) - const int64_t G = 3; // Number of Gates, 3 for GRU - - // Create memory descriptors - auto input_md = OneDNNMemDesc( - {Ti, N, IC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); - auto weight_x_md = - OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any); - auto weight_h_md = - OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any); - auto bias_md = OneDNNMemDesc( - {L, D, G, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldgo); - auto hidden_md = OneDNNMemDesc( - {Ti, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); - auto h0_md = OneDNNMemDesc( - {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc); - - // Create GRU oneDNN primitive - const auto direction = - is_reverse ? dnnl::rnn_direction::unidirectional_right2left - : dnnl::rnn_direction::unidirectional_left2right; - - this->AcquireForwardPrimitiveDescriptor( - this->attr_, - dnnl::prop_kind::forward_inference, - direction, - input_md, - h0_md, - weight_x_md, - weight_h_md, - bias_md, - hidden_md, - dnnl::memory::desc()); - } - } - - template - std::shared_ptr AcquireWeightXMemory( - const phi::DenseTensor* weight_x, const bool origin_mode) { - const std::string wx_key = this->memory_key_ + "@weight_x"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldigo); - auto user_memory = dnnl::memory(user_md, this->engine_); - - auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); - memcpy(weight_x_data, - weight_x->data(), - sizeof(U) * this->IC * this->G * this->OC); - - if (origin_mode == false) { - for (int64_t i = 0; i < this->IC; ++i) { - for (int64_t j = 0; j < this->OC; ++j) { - U minus_one(-1.0f); - weight_x_data[j] = minus_one * weight_x_data[j]; - } - weight_x_data += 3 * this->OC; - } - } - - memory_p = std::make_shared( - this->fwd_pd_->weights_layer_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, this->attr_) - .execute(astream, user_memory, *memory_p); - - this->dev_ctx_.SetBlob(wx_key, memory_p); - } - return memory_p; - } - - template - std::shared_ptr AcquireWeightHMemory( - const phi::DenseTensor* weight_h, const bool origin_mode) { - const std::string wh_key = this->memory_key_ + "@weight_h"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldigo); - auto user_memory = dnnl::memory(user_md, this->engine_); - - // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to - // oneDNN format [OC, 3OC] - auto* weight_h_data = reinterpret_cast(user_memory.get_data_handle()); - auto* user_weight_h_data = weight_h->data(); - - auto src1_iter = user_weight_h_data; - auto src2_iter = user_weight_h_data + 2 * this->OC * this->OC; - - for (int64_t c = 0; c < this->OC; ++c) { - memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(U)); - memcpy(weight_h_data + 2 * this->OC, src2_iter, this->OC * sizeof(U)); - - src1_iter += 2 * this->OC; - src2_iter += this->OC; - weight_h_data += 3 * this->OC; - } - - weight_h_data = reinterpret_cast(user_memory.get_data_handle()); - - if (origin_mode == false) { - for (int64_t i = 0; i < this->OC; ++i) { - for (int64_t j = 0; j < this->OC; ++j) { - U minus_one(-1.0f); - weight_h_data[j] = minus_one * weight_h_data[j]; - } - weight_h_data += 3 * this->OC; - } - } - - memory_p = std::make_shared( - this->fwd_pd_->weights_iter_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, this->attr_) - .execute(astream, user_memory, *memory_p); - - this->dev_ctx_.SetBlob(wh_key, memory_p); - } - return memory_p; - } - - std::shared_ptr AcquireBiasMemory(const phi::DenseTensor* bias, - const bool origin_mode) { - const std::string bias_key = this->memory_key_ + "@bias"; - auto memory_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(bias_key)); - - if (!memory_p) { - memory_p = std::make_shared(this->fwd_pd_->bias_desc(), - this->engine_); - auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); - if (bias) { - const float* user_bias_data = - bias->data(); // Bias in oneDNN is always float - memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); - } else { - // oneDNN always need bias memory, if it's not provided in PP, let - // oneDNN allocate memory and set it to 0 - memset(bias_data, 0, sizeof(float) * this->G * this->OC); - } - - if (origin_mode == false && bias) { - for (int64_t i = 0; i < this->OC; ++i) { - bias_data[i] *= -1; - } - } - this->dev_ctx_.SetBlob(bias_key, memory_p); - } - return memory_p; - } -}; - -template -class FusionGRUMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const bool is_bf16 = std::is_same::value; - const bool force_fp32_output = ctx.Attr("force_fp32_output"); - - // BF16 does not support force output - if (!is_bf16 && force_fp32_output) { // NOLINT - RunKernel(ctx); - } else { - RunKernel(ctx); - } - } - - template - void RunKernel(const framework::ExecutionContext& ctx) const { - auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - // Get Tensors - const auto* input = ctx.Input("X"); - const auto* h0 = ctx.Input("H0"); - const auto* weight_x = ctx.Input("WeightX"); - const auto* weight_h = ctx.Input("WeightH"); - const auto* bias = ctx.Input("Bias"); - auto* hidden = ctx.Output("Hidden"); - auto x_dims = input->dims(); - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) - ? phi::flatten_to_2d(x_dims, 1) - : x_dims; - // Get attributes - const bool is_reverse = ctx.Attr("is_reverse"); - const bool origin_mode = ctx.Attr("origin_mode"); - - // Get tensor dimensions - const auto x_mat_dims_vec = phi::vectorize(x_mat_dims); - const auto weight_h_dims = phi::vectorize(weight_h->dims()); - const auto& input_lod = input->lod()[0]; - - // Calculate RNN dimensions - const int64_t N = input_lod.size() - 1; // Number of sentences (batches) - const int64_t Ti = // Max length of the sentence in a batch - [&input_lod]() { - size_t res = 0; - for (size_t i = 0; i < (input_lod.size() - 1); ++i) { - res = std::max(res, input_lod[i + 1] - input_lod[i]); - } - return res; - }(); - const int64_t IC = x_mat_dims_vec[1]; // Input channels - const int64_t OC = weight_h_dims[0]; // Output channels - - GRUMKLDNNHandler handler( - ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - input, - weight_h, - h0, - is_reverse, - N, - Ti, - IC, - OC, - ctx.InputName("X") + ctx.InputName("WeightH")); - - auto input_memory_p = - handler.AcquireInputMemoryWithReorder(input, is_reverse); - - std::shared_ptr h0_memory_p, weight_h_memory_p, - weight_x_memory_p; - - if (framework::TransToProtoVarType(weight_h->dtype()) == - paddle::framework::proto::VarType_Type_FP32) { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x, origin_mode); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h, origin_mode); - } else if (framework::TransToProtoVarType(weight_h->dtype()) == - paddle::framework::proto::VarType_Type_BF16) { - h0_memory_p = - handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory( - weight_x, origin_mode); - weight_h_memory_p = - handler.template AcquireWeightHMemory( - weight_h, origin_mode); - } else { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x, origin_mode); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h, origin_mode); - } - - auto bias_memory_p = handler.AcquireBiasMemory(bias, origin_mode); - auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); - - std::unordered_map gru_args = { - {DNNL_ARG_SRC_LAYER, *input_memory_p}, - {DNNL_ARG_SRC_ITER, *h0_memory_p}, - {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, - {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, - {DNNL_ARG_BIAS, *bias_memory_p}, - {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; - - auto gru_forward_p = handler.AcquireForwardPrimitive(); - - auto& astream = OneDNNContext::tls().get_stream(); - gru_forward_p->execute(astream, gru_args); - astream.wait(); - - auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); - auto* hidden_data = - phi::funcs::to_void_cast(hidden->mutable_data(ctx.GetPlace())); - if (handler.is_NTC()) { - handler.reorderRNNdata(hidden_onednn_data, - hidden_data, - input_lod, - is_reverse, - RNNReorderType::NTC_PP); - } else { - handler.reorderRNNdata(hidden_onednn_data, - hidden_data, - input_lod, - is_reverse, - RNNReorderType::TNC_PP); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(fusion_gru, - MKLDNN, - phi::CPUPlace, - ops::FusionGRUMKLDNNKernel, - ops::FusionGRUMKLDNNKernel, - ops::FusionGRUMKLDNNKernel); diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index 777a9d19ba950..2cb485c2fc176 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -92,6 +92,7 @@ def process_scalar(op_item, scalar_configs): scalar_map = { 'Scalar': 'float', 'Scalar(float)': 'float', + 'Scalar(double)': 'double', 'Scalar(int)': 'int', 'Scalar(int64_t)': 'int64_t', } @@ -115,6 +116,14 @@ def process_scalar(op_item, scalar_configs): if 'data_type' in scalar_config else scalar_map[attr_type] ) + if ( + attr_type == 'Scalar(double)' + and attr_item['data_type'] == 'std::string' + and 'default_value' in attr_item + ): + attr_item['default_value'] = ( + '"' + attr_item['default_value'] + '"' + ) if attr_item['is_support_tensor'] is False: attr_item['tensor_name'] = scalar_config['tensor_name'] diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 3a2429f534573..4832073ca9c9f 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -365,6 +365,7 @@ def check_op_config(op_entry, op_name): 'data_transform', 'composite', 'support_dygraph_mode', + 'support_tensor', ) infer_meta_key_set = ('func', 'param', 'spmd_rule') kernel_key_set = ( @@ -508,6 +509,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): else: data_trans = None + if "support_tensor" in op_entry.keys(): + support_tensor = op_entry["support_tensor"] + else: + support_tensor = [] + op = { "name": op_name, "inputs": inputs, @@ -515,6 +521,7 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): "outputs": outputs, "no_need_buffer": no_buffer_args, "data_transform": data_trans, + "support_tensor": support_tensor, } # op should be is_base_op or is_invoke_op or is_only_composite_op diff --git a/paddle/fluid/operators/generator/type_mapping.py b/paddle/fluid/operators/generator/type_mapping.py index 56e01a997e61b..2ba6501194be4 100644 --- a/paddle/fluid/operators/generator/type_mapping.py +++ b/paddle/fluid/operators/generator/type_mapping.py @@ -31,6 +31,7 @@ 'Scalar(int)': 'const Scalar&', 'Scalar(int64_t)': 'const Scalar&', 'Scalar(float)': 'const Scalar&', + 'Scalar(double)': 'const Scalar&', 'Scalar[]': 'const std::vector&', 'Place': 'Place', 'DataLayout': 'DataLayout', @@ -59,6 +60,7 @@ 'Scalar(int)': 'int', 'Scalar(int64_t)': 'int64_t', 'Scalar(float)': 'float', + 'Scalar(double)': 'double', 'Scalar[]': 'std::vector', 'Place': 'int', 'DataLayout': 'int', diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index e06f7d443b84d..89b9d9d00e871 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -57,7 +57,7 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { flattened matrix is equal to the product of $X$'s first `x_num_col_dims` dimensions' sizes, and width of the flattened matrix is equal to the product of $X$'s last `rank(x) - num_col_dims` - dimensions' size. For example, suppose $X$ is a 6-dimensional + dimensions' size. For example, suppose $X$ is a 5-dimensional tensor with the shape [2, 3, 4, 5, 6], and `x_num_col_dims` = 3. Thus, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index a672f5ac99aa8..6b0a36fc56472 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -22,6 +22,7 @@ #include "paddle/phi/core/cuda_stream.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" @@ -29,8 +30,6 @@ #include "paddle/phi/kernels/funcs/tensor_to_string.h" #include "paddle/utils/optional.h" -#include "paddle/fluid/distributed/collective/utils.h" - #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/flags.h" @@ -2404,9 +2403,9 @@ void DistributedFusedLambKernel( if (num_devices > 1) { // ncclAllGather if (local_comm_ctx) { - auto send_buf = paddle::distributed::GetPartialTensor( + auto send_buf = distributed::GetPartialTensor( *fp32_param_out, fp32_offset, fp32_numel_each_device); - auto recv_buf = paddle::distributed::GetPartialTensor( + auto recv_buf = distributed::GetPartialTensor( *fp32_param_out, 0, fp32_numel_each_device); local_comm_ctx->AllGather(&recv_buf, send_buf, stream); } else { @@ -2442,9 +2441,9 @@ void DistributedFusedLambKernel( if (num_devices > 1) { // ncclAllGather if (local_comm_ctx) { - auto send_buf = paddle::distributed::GetPartialTensor( + auto send_buf = distributed::GetPartialTensor( *fp16_param_out, fp16_offset, fp16_numel_each_device); - auto recv_buf = paddle::distributed::GetPartialTensor( + auto recv_buf = distributed::GetPartialTensor( *fp16_param_out, 0, fp16_numel_each_device); local_comm_ctx->AllGather(&recv_buf, send_buf, stream); } else { diff --git a/paddle/fluid/operators/optimizers/sparse_momentum_op.h b/paddle/fluid/operators/optimizers/sparse_momentum_op.h index f1b162be46610..d29b4b8fb2e5a 100644 --- a/paddle/fluid/operators/optimizers/sparse_momentum_op.h +++ b/paddle/fluid/operators/optimizers/sparse_momentum_op.h @@ -366,8 +366,7 @@ class SparseMomentumOpKernel : public framework::OpKernel { MT mu = static_cast(ctx.Attr("mu")); MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); - int axis = ctx.Attr("axis"); - // get axis from tensor + int axis = 0; if (ctx.HasInput("Axis")) { phi::DenseTensor cpu_axis; const phi::DenseTensor* axis_tensor = ctx.Input("Axis"); @@ -379,6 +378,8 @@ class SparseMomentumOpKernel : public framework::OpKernel { } else if (axis_type == framework::proto::VarType::INT64) { axis = static_cast(cpu_axis.data()[0]); } + } else { + axis = ctx.Attr("axis"); } PADDLE_ENFORCE_EQ( axis == 0 || axis == 1, diff --git a/paddle/fluid/pir/CMakeLists.txt b/paddle/fluid/pir/CMakeLists.txt index 1ff77c6d7187e..24f5e2892de8e 100644 --- a/paddle/fluid/pir/CMakeLists.txt +++ b/paddle/fluid/pir/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(dialect) add_subdirectory(transforms) +add_subdirectory(drr) diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt index 17a73237c5fdb..aae8db4ba641c 100644 --- a/paddle/fluid/pir/dialect/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/CMakeLists.txt @@ -1,2 +1,229 @@ -add_subdirectory(operator) -add_subdirectory(kernel) +set(PD_DIALECT_BINARY_DIR + "${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir") + +# Generate pd_op_dialect files defining op using op_gen_file +set(op_gen_parsed_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py) + +set(op_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py) +set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) +set(op_forward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml +) +set(op_forward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml +) +set(op_backward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml +) +set(op_backward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml +) +set(fused_op_forward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml +) +set(fused_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml +) + +set(pd_op_forward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops.yaml) + +set(pd_op_forward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml) + +set(pd_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml) + +set(parsed_op_dir + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) + +set(op_yaml_file3 ${parsed_op_dir}/ops.parsed.yaml) +set(op_yaml_file4 ${parsed_op_dir}/ops_backward.parsed.yaml) +set(op_yaml_file5 ${parsed_op_dir}/update_ops.parsed.yaml) + +set(op_yaml_files + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3},${op_yaml_file4},${op_yaml_file5} +) +set(op_namespace paddle,dialect) +set(dialect_name pd_op) +set(op_header_file ${PD_DIALECT_BINARY_DIR}/pd_op.h) +set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc) +set(op_header_file_tmp ${op_header_file}.tmp) +set(op_source_file_tmp ${op_source_file}.tmp) + +set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc) +set(op_decomp_source_file ${PD_DIALECT_BINARY_DIR}/op_decomp.cc) +set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp) + +execute_process( + COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir} + COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path + ${pd_op_forward_yaml_file1} --output_path ${op_yaml_file3} + COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path + ${pd_op_forward_yaml_file2} --output_path ${op_yaml_file5} + COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path + ${pd_op_backward_yaml_file} --output_path ${op_yaml_file4} --backward) + +add_custom_command( + OUTPUT ${op_header_file} ${op_source_file} ${op_vjp_source_file} + COMMAND + ${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} + --dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp} + --op_def_cc_file ${op_source_file_tmp} --op_vjp_cc_file + ${op_vjp_source_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_header_file_tmp} + ${op_header_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_source_file_tmp} + ${op_source_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_vjp_source_file_tmp} + ${op_vjp_source_file} + COMMENT + "copy_if_different ${op_header_file} ${op_source_file} ${op_vjp_source_file}" + DEPENDS ${op_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${op_yaml_file3} + ${op_yaml_file4} + ${op_yaml_file5} + VERBATIM) +add_custom_target( + op_header_and_source_gen ALL DEPENDS ${op_header_file} ${op_source_file} + ${op_vjp_source_file}) +set(api_gen_yaml_files + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3},${op_yaml_file4},${op_yaml_file5} +) +set(api_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/api_gen.py) +set(api_header_file ${PD_DIALECT_BINARY_DIR}/pd_api.h) +set(api_source_file ${PD_DIALECT_BINARY_DIR}/pd_api.cc) +set(api_header_file_tmp ${api_header_file}.tmp) +set(api_source_file_tmp ${api_source_file}.tmp) + +add_custom_command( + OUTPUT ${api_header_file} ${api_source_file} + COMMAND + ${PYTHON_EXECUTABLE} ${api_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} + --api_def_h_file ${api_header_file_tmp} --api_def_cc_file + ${api_source_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} + ${api_header_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} + ${api_source_file} + COMMENT "copy_if_different ${api_header_file} ${api_source_file}" + DEPENDS ${api_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${op_yaml_file3} + ${op_yaml_file4} + ${op_yaml_file5} + VERBATIM) + +add_custom_target(api_header_and_source_gen ALL DEPENDS ${api_header_file} + ${api_source_file}) +set(python_c_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/python_c_gen.py) +set(python_c_header_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.h) +set(python_c_source_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.cc) +set(python_c_header_file_tmp ${python_c_header_file}.tmp) +set(python_c_source_file_tmp ${python_c_source_file}.tmp) + +add_custom_command( + OUTPUT ${python_c_header_file} ${python_c_source_file} + COMMAND + ${PYTHON_EXECUTABLE} ${python_c_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces "paddle,pybind" + --python_c_def_h_file ${python_c_header_file_tmp} --python_c_def_cc_file + ${python_c_source_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_header_file_tmp} + ${python_c_header_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_source_file_tmp} + ${python_c_source_file} + COMMENT "copy_if_different ${python_c_header_file} ${python_c_source_file}" + DEPENDS ${python_c_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${op_yaml_file3} + ${op_yaml_file4} + ${op_yaml_file5} + VERBATIM) + +add_custom_target(static_op_function_gen ALL DEPENDS ${python_c_header_file} + ${python_c_source_file}) + +set(ops_api_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py) +set(ops_api_source_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/ops_api.cc) +set(ops_api_source_file_tmp ${ops_api_source_file}.tmp) + +add_custom_command( + OUTPUT ${ops_api_source_file} + COMMAND + ${PYTHON_EXECUTABLE} ${ops_api_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces "paddle,pybind" + --ops_api_file ${ops_api_source_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ops_api_source_file_tmp} + ${ops_api_source_file} + COMMENT "copy_if_different ${ops_api_source_file}" + DEPENDS ${ops_api_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${python_c_header_file} + ${python_c_source_file} + VERBATIM) + +add_custom_target(ops_api_gen ALL DEPENDS ${ops_api_source_file}) + +#Note(risemeup1):compile some *.cc files which do not depend on primitive_vjp_experimental into op_dialect.a/lib +file(GLOB_RECURSE op_dialect_srcs "*.cc") +list( + REMOVE_ITEM + op_dialect_srcs + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp.cc + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc) + +set(op_dialect_srcs ${op_dialect_srcs} ${op_source_file} ${api_source_file}) + +set(op_dialect_deps phi phi_utils pir type_info) + +cc_library( + op_dialect + SRCS ${op_dialect_srcs} + DEPS ${op_dialect_deps}) + +#Note(risemeup1):compile some *.cc files which depend on primitive_vjp_experimental into op_dialect_vjp.a/lib +set(op_dialect_vjp_srcs + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp.cc + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc + ${op_decomp_source_file} + ${op_vjp_source_file}) +set(op_dialect_vjp_deps primitive_vjp_experimental op_dialect) + +cc_library( + op_dialect_vjp + SRCS ${op_dialect_vjp_srcs} + DEPS ${op_dialect_vjp_deps}) + +if((CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) + set_target_properties(op_dialect PROPERTIES COMPILE_FLAGS + "-Wno-maybe-uninitialized") +endif() diff --git a/paddle/fluid/pir/dialect/kernel/CMakeLists.txt b/paddle/fluid/pir/dialect/kernel/CMakeLists.txt deleted file mode 100644 index dd1b708ce9fe4..0000000000000 --- a/paddle/fluid/pir/dialect/kernel/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(ir) diff --git a/paddle/fluid/pir/dialect/kernel/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/kernel/ir/CMakeLists.txt deleted file mode 100644 index bdfdb75410524..0000000000000 --- a/paddle/fluid/pir/dialect/kernel/ir/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -file(GLOB PADDLE_KERNEL_DIALECT_SRCS "*.cc") -cc_library( - pd_kernel_dialect - SRCS ${PADDLE_KERNEL_DIALECT_SRCS} - DEPS pd_op_dialect_core) diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc index f120ebc969e5c..80515c6e9f5e9 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc @@ -29,7 +29,7 @@ const phi::DDim& AllocatedDenseTensorType::dims() const { return storage()->dense_tensor_type_.dims(); } -const phi::DataLayout& AllocatedDenseTensorType::data_layout() const { +phi::DataLayout AllocatedDenseTensorType::data_layout() const { return storage()->dense_tensor_type_.data_layout(); } @@ -37,7 +37,7 @@ const phi::LoD& AllocatedDenseTensorType::lod() const { return storage()->dense_tensor_type_.lod(); } -const size_t& AllocatedDenseTensorType::offset() const { +size_t AllocatedDenseTensorType::offset() const { return storage()->dense_tensor_type_.offset(); } @@ -45,7 +45,7 @@ const phi::Place& AllocatedSelectedRowsType::place() const { return storage()->place_; } -const pir::Type& AllocatedSelectedRowsType::dtype() const { +pir::Type AllocatedSelectedRowsType::dtype() const { return storage()->selected_rows_type_.dtype(); } @@ -53,7 +53,7 @@ const phi::DDim& AllocatedSelectedRowsType::dims() const { return storage()->selected_rows_type_.dims(); } -const phi::DataLayout& AllocatedSelectedRowsType::data_layout() const { +phi::DataLayout AllocatedSelectedRowsType::data_layout() const { return storage()->selected_rows_type_.data_layout(); } @@ -61,7 +61,7 @@ const phi::LoD& AllocatedSelectedRowsType::lod() const { return storage()->selected_rows_type_.lod(); } -const size_t& AllocatedSelectedRowsType::offset() const { +size_t AllocatedSelectedRowsType::offset() const { return storage()->selected_rows_type_.offset(); } diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h index 2c43cfb9ec384..51c5cd2286338 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h @@ -55,11 +55,11 @@ class AllocatedDenseTensorType const phi::DDim &dims() const; - const phi::DataLayout &data_layout() const; + phi::DataLayout data_layout() const; const phi::LoD &lod() const; - const size_t &offset() const; + size_t offset() const; }; class AllocatedSelectedRowsType @@ -92,15 +92,15 @@ class AllocatedSelectedRowsType const phi::Place &place() const; - const pir::Type &dtype() const; + pir::Type dtype() const; const phi::DDim &dims() const; - const phi::DataLayout &data_layout() const; + phi::DataLayout data_layout() const; const phi::LoD &lod() const; - const size_t &offset() const; + size_t offset() const; }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index c336dc7b61be1..b004e706d5bf5 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -21,7 +21,9 @@ PD_MANUAL_OP_LIST, OpCompatParser, OpInfoParser, + check_need_update_ops, to_pascal_case, + update_ops, ) H_FILE_TEMPLATE = """ @@ -154,12 +156,18 @@ def __init__(self) -> None: def _parse_yaml(self, op_yaml_files, op_compat_yaml_file): op_compat_parser = OpCompatParser(op_compat_yaml_file) + need_update_ops, update_yaml_file = check_need_update_ops(op_yaml_files) op_yaml_items = [] for yaml_file in op_yaml_files: + if update_yaml_file == yaml_file: + continue with open(yaml_file, "r") as f: ops = yaml.safe_load(f) op_yaml_items = op_yaml_items + ops + # replace old ir ops with pir ops + if need_update_ops: + update_ops(op_yaml_items, update_yaml_file) op_info_items = [] for op in op_yaml_items: op_compat_item = op_compat_parser.get_compat(op['name']) @@ -178,6 +186,13 @@ def _parse_yaml(self, op_yaml_files, op_compat_yaml_file): and 'scalar' in op_compat_item ): op_compat_item = op_compat_item.pop('scalar') + if 'support_tensor' in op.keys() and op['support_tensor']: + ( + scalar_item, + int_array_item, + ) = op_compat_parser.parse_support_tensor(op) + op_compat_item['scalar'] = scalar_item + op_compat_item['int_array'] = int_array_item op_info_items.append(OpInfoParser(op, op_compat_item)) return op_info_items diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py new file mode 100644 index 0000000000000..5d464f27a0b6c --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ===================================== +# DecompInterface gen op list +# ===================================== + + +# come into effect in generated file pd_op.h +# manual decomp interface declare are located in manual_op.h +decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n", "relu"] + +# come into effect in generated file op_decomp.cc +# manual decomp interface implementation are located in manual_op_decomp.cc +decomp_interface_implementation_gen_op_list = [ + "mean", + "squeeze", + "add_n", + "relu", +] diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index ba78e7d7dc722..dcc2123c23686 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -21,6 +21,8 @@ 'ReduceIntArrayAxisInferMeta', 'ReshapeWithXShapeInferMeta', 'SliceRawInferMeta', + 'StackInferMeta', + 'Conv2dTransposeInferMeta', } _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE = {'FrobeniusNormOp'} @@ -355,14 +357,37 @@ def GenBuildOutputs( meta_{name}.push_back(&vec_meta_{name}[i]); }} """ + + CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_ir_meta_tensor_{name}; + if ({name}_.impl() != nullptr) {{ + pir::VectorType {name} = {name}_.type().dyn_cast(); + for (size_t i=0; i < static_cast({name}.size()); i++) {{ + vec_ir_meta_tensor_{name}.push_back(paddle::dialect::IrMetaTensor(paddle::dialect::TransToPhiDataType({name}[i].dyn_cast().dtype()), + {name}[i].dyn_cast().dims(), + {name}[i].dyn_cast().data_layout(), + {name}[i].dyn_cast().lod(), + {name}[i].dyn_cast().offset())); + }} + }} + + std::vector vec_meta_{name}; + for (size_t i=0; i < vec_ir_meta_tensor_{name}.size(); i++) {{ + vec_meta_{name}.push_back(phi::MetaTensor(&vec_ir_meta_tensor_{name}[i])); + }} + + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} + +""" + CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; if ({name}_.dyn_cast().owner()->isa()) {{ - {name} = std::move(phi::IntArray({name}_.dyn_cast().owner() + {name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + {name}_.dyn_cast().owner() ->dyn_cast() - .attribute("value") - .dyn_cast() - .data() - .GetData())); + .attribute("value")))); }} else if ({name}_.type().isa()) {{ size_t {name}_size = {name}_.type().dyn_cast().size(); {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); @@ -377,12 +402,10 @@ def GenBuildOutputs( CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector {name}; if ({name}_.dyn_cast().owner()->isa()) {{ - {name} = {name}_.dyn_cast().owner() + {name} = paddle::dialect::GetInt64Vector( + {name}_.dyn_cast().owner() ->dyn_cast() - .attribute("value") - .dyn_cast() - .data() - .GetData(); + .attribute("value")); }} else if ({name}_.type().isa()) {{ size_t {name}_size = {name}_.type().dyn_cast().size(); {name} = std::vector({name}_size, -1); @@ -424,9 +447,10 @@ def GenBuildOutputs( for idx in range(len(op_input_name_list)): # is a vector if 'pir::VectorType' in op_input_type_list[idx]: - build_output_str += " pir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( - name=op_input_name_list[idx] - ) + if op_input_optional_list[idx] == 'false': + build_output_str += " pir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) # is a Tensor else: if op_input_optional_list[idx] == 'false': @@ -481,11 +505,19 @@ def GenBuildOutputs( ) ] ): - build_output_str += ( - CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( + input_index = op_input_name_list.index( + op_infer_meta_map['param'][idx] + ) + if op_input_optional_list[input_index] == 'true': + build_output_str += CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE.format( name=op_infer_meta_map['param'][idx] ) - ) + else: + build_output_str += ( + CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) # is a Tensor else: input_index = op_input_name_list.index( @@ -695,41 +727,36 @@ def gen_build_func_str( ) GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); """ GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); """ GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name}; for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); }} """ GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); """ GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); """ diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py new file mode 100644 index 0000000000000..a6661a80a9a29 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import yaml +from op_gen import ( + OpCompatParser, + OpInfoParser, + check_need_update_ops, + to_pascal_case, + update_ops, +) + +CPP_FILE_TEMPLATE = """ +#include "paddle/fluid/pir/drr/ir_operation_factory.h" + +{op_header} +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" + +namespace pir {{ +namespace drr {{ + +void OperationFactory::Register{dialect}GeneratedOpCreator() {{ +{body} +}} + +}} // namespace drr +}} // namespace pir + +""" + +NORMAL_FUNCTION_TEMPLATE = """ + RegisterOperationCreator( + "{op_name}", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) {{ + return rewriter.Build<{namespace}::{op_class_name}>( + {params_code}); + }}); +""" + +MUTABLE_ATTR_FUNCTION_TEMPLATE = """ + RegisterOperationCreator( + "{op_name}", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) {{ + // mutable_attr is tensor + if (inputs.size() > {inputs_num}) {{ + return rewriter.Build( + {params_code_with_mutable_attr}); + }} else {{ + return rewriter.Build( + {params_code_no_mutable_attr}); + }} + }}); +""" + +Dialect2NameSpaceMap = {"pd_op": "paddle::dialect", "cinn_op": "cinn::dialect"} +Dialect2OpHeaderMap = { + "pd_op": "#include \"paddle/fluid/pir/dialect/operator/ir/pd_op.h\"", + "cinn_op": "#include \"paddle/cinn/hlir/dialect/operator/ir/cinn_op.h\"", +} + + +class OpCreatorCodeGen: + def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name): + self.op_info_items = self.parse_yaml(op_yaml_files, op_compat_yaml_file) + self.dialect_name = dialect_name + + def parse_yaml(self, op_yaml_files, op_compat_yaml_file): + op_compat_parser = OpCompatParser(op_compat_yaml_file) + need_update_ops, update_yaml_file = check_need_update_ops(op_yaml_files) + + op_yaml_items = [] + for yaml_file in op_yaml_files: + if update_yaml_file == yaml_file: + continue + with open(yaml_file, "r") as f: + ops = yaml.safe_load(f) + op_yaml_items = op_yaml_items + ops + # replace old ir ops with pir ops + if need_update_ops: + update_ops(op_yaml_items, update_yaml_file) + + op_info_items = [] + for op in op_yaml_items: + op_compat_item = op_compat_parser.get_compat(op['name']) + if ( + op_compat_item is not None + and op_compat_item['op'] == "pow" + and 'scalar' in op_compat_item + ): + op_compat_item = op_compat_item.pop('scalar') + op_info_items.append(OpInfoParser(op, op_compat_item)) + return op_info_items + + def gen_cpp_file_code(self, cpp_file_path): + body_code = "" + for op_info_item in self.op_info_items: + if op_info_item.infer_meta_map is None: + continue + for phi_op_name in op_info_item.op_phi_name: + ir_op_name = self.dialect_name + "." + phi_op_name + params_no_mutable_attr = [] + for i in range(len(op_info_item.input_name_list)): + params_no_mutable_attr.append( + f"inputs[{i}].dyn_cast()" + ) + if len(op_info_item.attribute_name_list) > 0: + params_no_mutable_attr.append("attrs") + + if len(op_info_item.mutable_attribute_name_list) == 0: + body_code += NORMAL_FUNCTION_TEMPLATE.format( + op_name=ir_op_name, + namespace=Dialect2NameSpaceMap[self.dialect_name], + op_class_name=(to_pascal_case(phi_op_name) + "Op"), + params_code=", ".join(params_no_mutable_attr), + ) + else: + params_with_mutable_attr = [] + for i in range( + len(op_info_item.input_name_list) + + len(op_info_item.mutable_attribute_name_list) + ): + params_with_mutable_attr.append( + f"inputs[{i}].dyn_cast()" + ) + if len(op_info_item.attribute_name_list) > len( + op_info_item.mutable_attribute_name_list + ): + # TODO(zyfncg): Currently Op::Build Interface doesn't support this case. + continue + # params_with_mutable_attr.append("attrs") + + body_code += MUTABLE_ATTR_FUNCTION_TEMPLATE.format( + op_name=ir_op_name, + inputs_num=len(op_info_item.input_name_list), + op_class_name=(to_pascal_case(phi_op_name) + "Op"), + params_code_with_mutable_attr=",".join( + params_with_mutable_attr + ), + params_code_no_mutable_attr=", ".join( + params_no_mutable_attr + ), + ) + + with open(cpp_file_path, 'w') as f: + f.write( + CPP_FILE_TEMPLATE.format( + dialect=to_pascal_case(self.dialect_name), + op_header=Dialect2OpHeaderMap[self.dialect_name], + body=body_code, + ) + ) + + +def ParseArguments(): + parser = argparse.ArgumentParser( + description='Generate Op Creator Files By Yaml' + ) + parser.add_argument('--op_yaml_files', type=str) + parser.add_argument('--op_compat_yaml_file', type=str) + parser.add_argument('--dialect_name', type=str) + parser.add_argument('--op_creator_file', type=str) + return parser.parse_args() + + +if __name__ == '__main__': + args = ParseArguments() + op_yaml_files = args.op_yaml_files.split(",") + op_compat_yaml_file = args.op_compat_yaml_file + op_creator_file = args.op_creator_file + dialect_name = args.dialect_name + + code_gen = OpCreatorCodeGen( + op_yaml_files, op_compat_yaml_file, dialect_name + ) + code_gen.gen_cpp_file_code(op_creator_file) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 64caafc544892..e4435940e609f 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -19,18 +19,17 @@ import sys import yaml +from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke from op_interface_gen import ( gen_exclusive_interface_str, gen_op_infer_meta_str, gen_op_vjp_str, ) +from op_kerneltype_gen import gen_kernel_type_for_var_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str -from vjp_interface_gen_op_list import ( - vjp_interface_declare_gen_op_list, - vjp_interface_implementation_gen_op_list, -) +from vjp_interface_black_list import vjp_interface_black_list # import from paddle/fluid/primitive/code_gen/gen.py sys.path.append( @@ -61,12 +60,14 @@ #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" +{only_pd_op_header_files} {op_to_multi_kernels_map} @@ -100,6 +101,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {build_attr_num_over_1} {build_mutable_attr_is_input_attr_num_over_1} void VerifySig(); +{get_kernel_type_for_var_declare} {get_inputs_and_outputs} {exclusive_interface} }}; @@ -111,6 +113,13 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ "static const char *attributes_name[{attribute_num}];" ) +get_kernel_type_for_var_declare_template = """ + static phi::DataType GetKernelTypeForVar( + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::DataType& expected_kernel_dtype); +""" + # ===================================== # String Template for cc file code gen # ===================================== @@ -207,6 +216,55 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ } +attr_types_map = { + 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], + 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], + 'Scalar(int)': ['pir::Int32Attribute', 'int'], + 'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'], + 'Scalar(float)': ['pir::FloatAttribute', 'float'], + 'Scalar(double)': ['pir::DoubleAttribute', 'double'], + 'Scalar[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'int': ['pir::Int32Attribute', 'int'], + 'int32_t': ['pir::Int32Attribute', 'int32_t'], + 'int64_t': ['pir::Int64Attribute', 'int64_t'], + 'long': ['pir::LongAttribute', 'long'], + 'size_t': ['pir::Size_tAttribute', 'size_t'], + 'float': ['pir::FloatAttribute', 'float'], + 'float[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'double': ['pir::DoubleAttribute', 'double'], + 'bool': ['pir::BoolAttribute', 'bool'], + 'bool[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'str': ['pir::StrAttribute', 'const std::string&'], + 'str[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'], + 'DataLayout': [ + 'paddle::dialect::DataLayoutAttribute', + 'DataLayout', + ], + 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], + 'int64_t[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'int[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], +} + + def to_phi_and_fluid_op_name(op_item): # Templat: - op : phi_name (fluid_name) names = op_item.split('(') @@ -252,6 +310,23 @@ def get_compat(self, op_name): return compat return None + def parse_support_tensor(self, op): + scalar_item = {} + int_array_item = {} + for support_tensor_attr in op['support_tensor']: + for attr in op['attrs']: + if ( + attr['typename'] == 'Scalar' + and attr['name'] == support_tensor_attr + ): + scalar_item[support_tensor_attr] = {"support_tensor": True} + if ( + attr['typename'] == 'IntArray' + and attr['name'] == support_tensor_attr + ): + scalar_item[support_tensor_attr] = {"support_tensor": True} + return scalar_item, int_array_item + # ===================================== # Parse Op Information From Yaml @@ -288,53 +363,7 @@ def __init__(self, op_yaml_item, op_compat_item): ) # parse attributes - self.attr_types_map = { - 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], - 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], - 'Scalar(int)': ['pir::Int32Attribute', 'int'], - 'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'], - 'Scalar(float)': ['pir::FloatAttribute', 'float'], - 'Scalar(dobule)': ['pir::DoubleAttribute', 'dobule'], - 'Scalar[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'int': ['pir::Int32Attribute', 'int'], - 'int32_t': ['pir::Int32Attribute', 'int32_t'], - 'int64_t': ['pir::Int64Attribute', 'int64_t'], - 'long': ['pir::LongAttribute', 'long'], - 'size_t': ['pir::Size_tAttribute', 'size_t'], - 'float': ['pir::FloatAttribute', 'float'], - 'float[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'double': ['pir::DoubleAttribute', 'double'], - 'bool': ['pir::BoolAttribute', 'bool'], - 'bool[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'str': ['pir::StrAttribute', 'const std::string&'], - 'str[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'], - 'DataLayout': [ - 'paddle::dialect::DataLayoutAttribute', - 'DataLayout', - ], - 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], - 'int64_t[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'int[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - } + self.attr_types_map = attr_types_map self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() self.attribute_build_arg_type_list = ( @@ -378,6 +407,9 @@ def __init__(self, op_yaml_item, op_compat_item): self.inplace_map = self.parse_op_inplace_info() self.view_map = self.parse_op_view_info() + # parse data_transform + self.data_transform_map = self.parse_data_transform_info() + # parse has_custom_verify self.custom_verify = self.parse_custom_verify() @@ -717,8 +749,8 @@ def parse_output_size_list(self): def parse_output_optional_list(self): optional_list = [] for output_info in self.op_yaml_item['outputs']: - if 'optional' in output_info: - if output_info['optional']: + if 'optional' in output_info or 'intermediate' in output_info: + if output_info['optional'] or output_info['intermediate']: optional_list.append("true") else: optional_list.append("false") @@ -834,6 +866,15 @@ def parse_invoke_map(self): else: return None + def parse_data_transform_info(self): + if ( + 'data_transform' in self.op_yaml_item + and self.op_yaml_item['data_transform'] + ): + data_trans_item = self.op_yaml_item['data_transform'] + return data_trans_item + return None + def parse_backward_name(self): if 'backward' in self.op_yaml_item: return self.op_yaml_item['backward'] @@ -931,6 +972,27 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items): return mutable_attribute_grad_semantics +def check_need_update_ops(op_yaml_files): + need_update_ops = False + update_yaml_file = None + for yaml_file in op_yaml_files: + if yaml_file.find("update_ops.parsed.yaml") != -1: + need_update_ops = True + update_yaml_file = yaml_file + break + return need_update_ops, update_yaml_file + + +def update_ops(op_yaml_items, update_yaml_file): + with open(update_yaml_file, "r") as f: + update_ops = yaml.safe_load(f) + for i in range(len(op_yaml_items)): + for update_op in update_ops: + if op_yaml_items[i]['name'] == update_op['name']: + op_yaml_items[i] = update_op + break + + def OpGenerator( op_yaml_files, op_compat_yaml_file, @@ -948,12 +1010,19 @@ def OpGenerator( # (2) Prepare: Get all op item in all op_yaml_files op_compat_parser = OpCompatParser(op_compat_yaml_file) + need_update_ops, update_yaml_file = check_need_update_ops(op_yaml_files) op_yaml_items = [] for yaml_file in op_yaml_files: + if update_yaml_file == yaml_file: + continue with open(yaml_file, "r") as f: ops = yaml.safe_load(f) op_yaml_items = op_yaml_items + ops + # replace old ir ops with pir ops + if need_update_ops: + update_ops(op_yaml_items, update_yaml_file) + op_info_items = {} for op in op_yaml_items: op_compat_item = op_compat_parser.get_compat(op['name']) @@ -971,6 +1040,13 @@ def OpGenerator( ): op_compat_item = op_compat_item.pop('scalar') + if 'support_tensor' in op.keys() and op['support_tensor']: + scalar_item, int_array_item = op_compat_parser.parse_support_tensor( + op + ) + op_compat_item['scalar'] = scalar_item + op_compat_item['int_array'] = int_array_item + op_info_items[op['name']] = OpInfoParser(op, op_compat_item) # (3) CodeGen: Traverse op_info_items and generate ops_name_list = [] # all op class name store in this list @@ -1025,6 +1101,7 @@ def OpGenerator( op_invoke_map = op_info.invoke_map op_inplace_map = op_info.inplace_map op_view_map = op_info.view_map + op_data_transform_map = op_info.data_transform_map op_interfaces = ["paddle::dialect::OpYamlInfoInterface"] op_traits = [] @@ -1036,13 +1113,16 @@ def OpGenerator( if ( op_info.backward_name - and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list + and op_info.op_phi_name[0] not in vjp_interface_black_list ): op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) + if dialect_name == "pd_op": + op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"] + # if op has custom vjp rule, then append a CustomVjpTrait to it if op_info.op_phi_name[0] in custom_vjp_op_name_list: op_traits += ["paddle::dialect::CustomVjpTrait"] @@ -1052,9 +1132,19 @@ def OpGenerator( mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic( op_info, op_info_items ) + op_interfaces_tmp = op_interfaces + exclusive_interface_str_tmp = exclusive_interface_str # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: + if op_name in decomp_interface_declare_gen_op_list: + op_interfaces = op_interfaces + [ + "paddle::dialect::DecompInterface" + ] + exclusive_interface_str += "\n static std::vector> Decomp(pir::Operation* op);" + else: + op_interfaces = op_interfaces_tmp + exclusive_interface_str = exclusive_interface_str_tmp if op_name in PD_MANUAL_OP_LIST: continue if op_kernel_map is None: @@ -1123,6 +1213,12 @@ def OpGenerator( build_func_with_attr_is_map = "" build_func_with_muta_attr_is_input = "" + get_kernel_type_for_var_declare_str = "" + if dialect_name == "pd_op": + get_kernel_type_for_var_declare_str = ( + get_kernel_type_for_var_declare_template + ) + if op_infer_meta_map is not None: ( build_args_with_muta_attr_not_input_for_declare, @@ -1255,6 +1351,7 @@ def OpGenerator( build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, + get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, ) op_defined_str = "" else: @@ -1275,6 +1372,7 @@ def OpGenerator( build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, + get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, ) attribute_names_str = ( '"' @@ -1371,11 +1469,31 @@ def OpGenerator( 'data_type' in op_kernel_map and op_kernel_map['data_type'] ): - kernel_key_dtype = '", "'.join( - op_kernel_map['data_type']['candidates'] - ) + for idx in range( + len(op_kernel_map['data_type']['candidates']) + ): + if ( + 'to_complex_flag' in op_kernel_map['data_type'] + and op_kernel_map['data_type'][ + 'to_complex_flag' + ][idx] + ): + kernel_key_dtype += ( + 'complex:' + + op_kernel_map['data_type']['candidates'][ + idx + ] + + '", "' + ) + else: + kernel_key_dtype += ( + op_kernel_map['data_type']['candidates'][ + idx + ] + + '", "' + ) if kernel_key_dtype != "": - kernel_key_dtype = '"' + kernel_key_dtype + '"' + kernel_key_dtype = '"' + kernel_key_dtype[:-3] if 'backend' in op_kernel_map and op_kernel_map['backend']: kernel_key_backend = '", "'.join( op_kernel_map['backend']['candidates'] @@ -1426,6 +1544,18 @@ def OpGenerator( op_output_optional_list, ) + # generate op GetKernelKeyForVar function str + op_get_kernel_type_for_var_str = '' + if dialect_name == "pd_op": + op_get_kernel_type_for_var_str = ( + gen_kernel_type_for_var_str( + op_class_name, + op_data_transform_map, + op_kernel_map, + op_info.op_compat_item, + ) + ) + op_infer_meta_str = gen_op_infer_meta_str( op_info, op_class_name, op_info_items ) @@ -1444,7 +1574,7 @@ def OpGenerator( if ( op_info.backward_name and op_info.op_phi_name[0] - in vjp_interface_implementation_gen_op_list + not in vjp_interface_black_list ): op_vjp_str = gen_op_vjp_str( op_class_name, @@ -1470,6 +1600,8 @@ def OpGenerator( ops_defined_list.append(op_verify_str) ops_defined_list.append(op_infer_meta_str) + ops_defined_list.append(op_get_kernel_type_for_var_str) + # NOTE(chenxi67)skip if dialect_name==cinn if dialect_name == "cinn": pass @@ -1520,12 +1652,18 @@ def OpGenerator( head_file_str = "" head_file_str += "".join(ops_declare_list) # Add op class + only_pd_op_header_files_str = "" + if dialect_name == "pd_op": op_to_multi_kernels_map = OP_TO_MULTI_KERNELS_MAP_H for name in reversed(namespaces): op_to_multi_kernels_map = NAMESPACE_GARD_TEMPLATE.format( namespace=name, input=op_to_multi_kernels_map ) # Add namespaces + only_pd_op_header_files_str = """ +#include \"paddle/phi/common/data_type.h\" +#include \"paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h\" + """ else: op_to_multi_kernels_map = "" @@ -1538,6 +1676,7 @@ def OpGenerator( op_to_multi_kernels_map=op_to_multi_kernels_map, input=head_file_str, declare_type_id=declare_type_id_str, + only_pd_op_header_files=only_pd_op_header_files_str, ) # Add head # (5) Generate source file str diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 9c8ff889f2b21..6d7c5224e3803 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -13,7 +13,7 @@ # limitations under the License. # generator interfaces -from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list +from vjp_interface_black_list import vjp_interface_black_list OP_INFER_SHAPE_TEMPLATE = """ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ @@ -21,48 +21,44 @@ fn(infer_meta); }} """ +CHECK_INPUT_TEMPLATE = """ + PADDLE_ENFORCE_EQ( + inputs_.size(), + {inputs_size}, + platform::errors::InvalidArgument("{op_name} op's inputs size should be {inputs_size}, but now is %d.", inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + {outputs_size}, + platform::errors::InvalidArgument("{op_name} op's outputs size should be {outputs_size}, but now is %d.", outputs.size())); +""" OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ - {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" + {input_type} {input_name}(std::make_shared({vjp_param_name}[{input_idx}][0]));""" OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """ - pir::CombineOp combine_op_obj = - op_obj.{input_name}().dyn_cast().owner()->dyn_cast(); std::vector {input_name}; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{ + for (size_t idx = 0; idx < {vjp_param_name}[{input_idx}].size(); idx++) {{ {input_name}.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared({vjp_param_name}[{input_idx}][idx])); }}""" OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE = """ paddle::optional {input_name}; - if (!IsEmptyValue(op_obj.{input_name}())){{ - {input_name} = paddle::make_optional(Tensor(std::make_shared(op_obj.{input_name}()))); + if (!IsEmptyValue({vjp_param_name}[{input_idx}][0])){{ + {input_name} = paddle::make_optional(Tensor(std::make_shared({vjp_param_name}[{input_idx}][0]))); }}""" OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE = """ paddle::optional> {input_name}; - if (!IsEmptyValue(op_obj.{input_name}())){{ - pir::CombineOp combine_op_obj = - op_obj.{input_name}().dyn_cast().owner()->dyn_cast(); - std::vector optional_{input_name}; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{ + std::vector optional_{input_name}; + if (!IsEmptyValue({vjp_param_name}[{input_idx}][0])){{ + for (size_t idx = 0; idx < {vjp_param_name}[{input_idx}].size(); idx++) {{ optional_{input_name}.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared({vjp_param_name}[{input_idx}][idx])); }} {input_name} = paddle::make_optional>(optional_{input_name}); }}""" -OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ - Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" - -OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ - std::vector {output_grad_name}; - for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{ - {output_grad_name}.emplace_back( - std::make_shared(out_grads[{index}][idx])); - }}""" - OP_VJP_ATTRIBUTE_TEMPLATE = """ {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();""" @@ -92,12 +88,10 @@ }""" OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ - {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); (void)op_obj; - +std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients){{ +{check_param} VLOG(6) << "Prepare inputs of {op_grad_name}"; -{forward_input_output_code} -{forward_output_grad_code} +{backward_input_code} VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; {attribute_code} @@ -125,62 +119,66 @@ def gen_op_vjp_str( op_grad_info, ): bw_input_list = op_grad_info.input_name_list - forward_input_output_code = '' - forward_output_grad_code = '' + fwd_input_and_mutable_attr_name_list = ( + op_info.input_name_list + op_info.mutable_attribute_name_list + ) + + backward_input_code = '' build_args_str = '' grad_idx = -1 for idx in range(len(bw_input_list)): - build_args_str += bw_input_list[idx] + ", " + bw_input_name = bw_input_list[idx] + build_args_str += bw_input_name + ", " + input_type = input_types_map[op_grad_info.input_type_list[idx]] + + vjp_param_name = '' + index_0 = -1 + if bw_input_name in fwd_input_and_mutable_attr_name_list: + vjp_param_name = 'inputs_' + index_0 = fwd_input_and_mutable_attr_name_list.index(bw_input_name) + elif bw_input_name in op_info.output_name_list: + vjp_param_name = 'outputs' + index_0 = op_info.output_name_list.index(bw_input_name) + else: + vjp_param_name = 'out_grads' + grad_idx += 1 + index_0 = grad_idx if op_grad_info.input_optional_list[idx] == 'true': - input_type = input_types_map[op_grad_info.input_type_list[idx]] if input_type == 'Tensor': - forward_input_output_code += ( + backward_input_code += ( OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], + vjp_param_name=vjp_param_name, + input_name=bw_input_name, + input_idx=index_0, ) ) else: - forward_input_output_code += ( + backward_input_code += ( OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], + vjp_param_name=vjp_param_name, + input_name=bw_input_name, + input_idx=index_0, ) ) else: - if ( - bw_input_list[idx] in op_info.input_name_list - or bw_input_list[idx] in op_info.output_name_list - ): - input_type = input_types_map[op_grad_info.input_type_list[idx]] - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( - input_type=input_type, - input_name=bw_input_list[idx], - ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], - ) + if input_type == 'Tensor': + backward_input_code += ( + OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + vjp_param_name=vjp_param_name, + input_type=input_type, + input_name=bw_input_name, + input_idx=index_0, ) + ) else: - grad_idx += 1 - input_type = input_types_map[op_grad_info.input_type_list[idx]] - if input_type == 'Tensor': - forward_output_grad_code += ( - OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( - output_grad_name=bw_input_list[idx], - idx1=grad_idx, - idx2=0, - ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE.format( - output_grad_name=bw_input_list[idx], index=grad_idx - ) + backward_input_code += ( + OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format( + vjp_param_name=vjp_param_name, + input_name=bw_input_name, + input_idx=index_0, ) + ) + op_attribute_list = op_grad_info.attribute_name_list attribute_code = '' build_attr_str = '' @@ -190,8 +188,12 @@ def gen_op_vjp_str( if op_attribute_list[idx] in op_info.mutable_attribute_name_list: attribute_code += ( OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + vjp_param_name='inputs_', input_type="Tensor", input_name=op_attribute_list[idx], + input_idx=fwd_input_and_mutable_attr_name_list.index( + op_attribute_list[idx] + ), ) ) build_args_str += op_attribute_list[idx] + ", " @@ -241,14 +243,19 @@ def gen_op_vjp_str( inputs_list=build_args_str, ) stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE + check_param = CHECK_INPUT_TEMPLATE.format( + op_name=op_phi_name_format, + inputs_size=len(fwd_input_and_mutable_attr_name_list), + outputs_size=len(op_info.output_name_list), + out_grads_size=grad_idx + 1, + ) str = OP_VJP_DEFINE_TEMPLATE.format( + check_param=check_param, op_class_name=op_class_name, op_grad_name=op_grad_name, op_phi_name=op_phi_name, - res_size=len(op_info.input_name_list), - forward_input_output_code=forward_input_output_code, - forward_output_grad_code=forward_output_grad_code, + backward_input_code=backward_input_code, attribute_code=attribute_code, call_vjp_code=call_vjp_code, stop_gradient_input_grad_code=stop_gradient_input_grad_code, @@ -285,6 +292,6 @@ def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str += ( " static void InferMeta( phi::InferMetaContext *infer_meta );" ) - if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: - exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" + if op_info.op_phi_name[0] not in vjp_interface_black_list: + exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py b/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py new file mode 100644 index 0000000000000..06250e40a0283 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ +phi::DataType {op_name}::GetKernelTypeForVar( + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::DataType& expected_kernel_dtype) {{ + VLOG(4) << "Get KernelType for Var of op: {op_name}"; + {data_transform_check}{complex_promote_check} + return expected_kernel_dtype; +}} +""" + +OP_DATA_TRANSFORM_CHECK_TEMPLATE = """ +{skip_trans}{support_trans} +""" + +OP_SKIP_TRANSFORM_CHECK_TEMPLATE = """ + // deal skip data transform + if ({skip_transform_check}){{ + return expected_kernel_dtype; + }} +""" + +OP_SUPPORT_TRANSFORM_CHECK_TEMPLATE = """ + // deal support data transform + VLOG(8) << "SUPPORT_TRANSFORM: " << \"{support_dtype_name};"; + return tensor_dtype; +""" + +OP_COMPLEX_PROMOTE_CHECK_TEMPLATE = """ + // deal complex_promote + if (framework::IsComplexType(expected_kernel_dtype)) {{ + // only promote inputs’s types when contains complex input + return tensor_dtype; + }} +""" + + +def get_data_transform_check_str(op_data_transform_map): + skip_trans_str = "" + support_trans_str = "" + if op_data_transform_map is not None: + args = None + if "skip_transform" in op_data_transform_map: + args = op_data_transform_map["skip_transform"] + if args is not None: + if_cond_args = [] + for skip_arg in args: + if_cond_args.append("var_name == \"" + skip_arg + "\"") + skip_trans_str = OP_SKIP_TRANSFORM_CHECK_TEMPLATE.format( + skip_transform_check=' || '.join(if_cond_args) + ) + if "support_trans_dtype" in op_data_transform_map: + args = op_data_transform_map["support_trans_dtype"] + # TODO:(chenxi) comlete SUPPORT logic + if args is not None: + support_trans_str = OP_SUPPORT_TRANSFORM_CHECK_TEMPLATE.format( + support_dtype_name=args + ) + + return OP_DATA_TRANSFORM_CHECK_TEMPLATE.format( + skip_trans=skip_trans_str, + support_trans=support_trans_str, + ) + + +def get_complex_promote_check_str(op_compat_item): + complex_promote_check_str = "" + if ( + op_compat_item is not None + and "complex_promote" in op_compat_item + and op_compat_item["complex_promote"] is not None + ): + complex_promote_check_str = OP_COMPLEX_PROMOTE_CHECK_TEMPLATE + return complex_promote_check_str + + +def gen_kernel_type_for_var_str( + op_class_name, op_data_transform_map, op_kernel_map, op_compat_item +): + complex_promote_check_str = get_complex_promote_check_str(op_compat_item) + data_transform_check_str = get_data_transform_check_str( + op_data_transform_map + ) + + return OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE.format( + op_name=op_class_name, + data_transform_check=data_transform_check_str, + complex_promote_check=complex_promote_check_str, + ) diff --git a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py index 3a2515f278915..f42a73347d13a 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py @@ -19,8 +19,8 @@ VLOG(4) << "Verifying inputs:"; {{ auto input_size = num_operands(); - PADDLE_ENFORCE_EQ(input_size, {inputs_size}u, - phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check} + IR_ENFORCE(input_size == {inputs_size}u, + "The size %d of inputs must be equal to {inputs_size}.", input_size);{inputs_type_check} }} VLOG(4) << "Verifying attributes:"; {{{attributes_check} @@ -28,8 +28,8 @@ VLOG(4) << "Verifying outputs:"; {{ auto output_size = num_results(); - PADDLE_ENFORCE_EQ(output_size, {outputs_size}u, - phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check} + IR_ENFORCE(output_size == {outputs_size}u, + "The size %d of outputs must be equal to {outputs_size}.", output_size);{outputs_type_check} }} VLOG(4) << "End Verifying for: {op_name}."; }} @@ -40,83 +40,83 @@ """ INPUT_TYPE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));""" + IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), + "Type validation failed for the {index}th input.");""" INPUT_VECTORTYPE_CHECK_TEMPLATE = """ if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th input."); }} }} else {{ - PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), + "Type validation failed for the {index}th input."); }}""" INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ - PADDLE_ENFORCE(val.type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(val.type().isa<{standard}>(), + "Type validation failed for the {index}th input."); }}""" INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ if (auto vec_type = val.type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th input."); }} }} else {{ - PADDLE_ENFORCE(val.type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(val.type().isa<{standard}>(), + "Type validation failed for the {index}th input."); }} }}""" ATTRIBUTE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE(attributes.count("{attribute_name}")>0, - phi::errors::PreconditionNotMet("{attribute_name} does not exist.")); - PADDLE_ENFORCE(attributes.at("{attribute_name}").isa<{standard}>(), - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not {standard}.")); + IR_ENFORCE(attributes.count("{attribute_name}")>0, + "{attribute_name} does not exist."); + IR_ENFORCE(attributes.at("{attribute_name}").isa<{standard}>(), + "Type of attribute: {attribute_name} is not {standard}."); """ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """ - PADDLE_ENFORCE(attributes.count("{attribute_name}")>0, - phi::errors::PreconditionNotMet("{attribute_name} does not exist.")); - PADDLE_ENFORCE(attributes.at("{attribute_name}").isa(), - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not pir::ArrayAttribute.")); + IR_ENFORCE(attributes.count("{attribute_name}")>0, + "{attribute_name} does not exist."); + IR_ENFORCE(attributes.at("{attribute_name}").isa(), + "Type of attribute: {attribute_name} is not pir::ArrayAttribute."); for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ - PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); + IR_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), + "Type of attribute: {attribute_name} is not right."); }}""" OUTPUT_TYPE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE((*this)->result({index}).type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));""" + IR_ENFORCE((*this)->result({index}).type().isa<{standard}>(), + "Type validation failed for the {index}th output.");""" OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """ auto output_{index}_type = (*this)->result({index}).type(); if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th output."); }} }} else {{ - PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(output_{index}_type.isa<{standard}>(), + "Type validation failed for the {index}th output."); }}""" OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ - PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(output_{index}_type.isa<{standard}>(), + "Type validation failed for the {index}th output."); }}""" OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th output."); }} }} else {{ - PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(output_{index}_type.isa<{standard}>(), + "Type validation failed for the {index}th output."); }} }}""" diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index e2d17e7f11802..c2076b25cd514 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -68,7 +68,25 @@ OPS_API_TEMPLATE = """ {{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},""" -NEED_GEN_STATIC_ONLY_APIS = ['fetch'] +NEED_GEN_STATIC_ONLY_APIS = [ + 'fetch', + 'fused_embedding_eltwise_layernorm', + 'fused_fc_elementwise_layernorm', + 'fused_multi_transformer_xpu', + 'fused_scale_bias_relu_conv_bnstats', + 'fusion_transpose_flatten_concat', + 'generate_sequence_xpu', + 'layer_norm_act_xpu', + 'multi_encoder_xpu', + 'multihead_matmul', + 'squeeze_excitation_block', + 'yolo_box_xpu', + 'fusion_gru', + 'fusion_seqconv_eltadd_relu', + 'fusion_seqexpand_concat_fc', + 'fused_attention', + 'fused_feedforward', +] NO_NEED_GEN_STATIC_ONLY_APIS = [ 'add_n_', @@ -83,11 +101,10 @@ 'c_reduce_sum', 'dpsgd', 'embedding_grad_sparse', - 'fused_attention', 'fused_batch_norm_act_', 'fused_bn_add_activation_', - 'fused_feedforward', 'fused_scale_bias_relu_conv_bnstats', + 'memcpy', 'print', 'recv_v2', 'rnn_', @@ -98,6 +115,7 @@ 'set_value_with_tensor', 'set_value_with_tensor_', 'shadow_feed', + 'sparse_momentum', ] diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py new file mode 100644 index 0000000000000..a7841e4d6d8af --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ===================================== +# VjpInterface gen op list +# ===================================== +# we don't support vjp function code +# gen now, so we use a whitelist to +# control the generation of Vjp methods. +# TODO(wanghao107) +# remove this file and support Vjp methods +# code gen. + + +vjp_interface_black_list = [ + 'silu_grad', + 'fused_dropout_add', + 'fused_rotary_position_embedding', +] diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py deleted file mode 100644 index 58abcbf1143b9..0000000000000 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ===================================== -# VjpInterface gen op list -# ===================================== -# we don't support vjp function code -# gen now, so we use a whitelist to -# control the generation of Vjp methods. -# TODO(wanghao107) -# remove this file and support Vjp methods -# code gen. - - -vjp_interface_declare_gen_op_list = [ - 'where', - "tanh", - "mean", - "divide", - "sum", - "add", - "concat", - "split", - "split_with_num", - "gelu", - "matmul", - "erf", - "multiply", - "pow", - "rsqrt", - "subtract", - "square", - "dropout", - 'exp', - 'expm1', - 'expand', - 'layer_norm', - 'reshape', - 'cast', - "scale", - 'softmax', - 'silu', - 'elementwise_pow', - 'embedding', - 'fused_softmax_mask_upper_triangle', - 'slice', - 'transpose', - 'slice_grad', - 'gather_nd', - 'stack', - 'poisson', - 'gumbel_softmax', - 'pad', - 'pad3d', - 'squeeze', - 'unsqueeze', - 'tril', - 'triu', - 'squeeze', - 'unsqueeze', - 'conv2d', - 'depthwise_conv2d', - 'sqrt', - 'flatten', - 'relu', - 'abs', - 'log', - 'clip', - 'ceil', - 'p_norm', - 'maximum', - 'argsort', - 'min', - 'max', - 'batch_norm', - 'max_pool2d_with_index', - 'pool2d', - 'minimum', - 'prod', - 'round', - 'sin', - 'cos', - 'dot', - 'floor', - 'topk', - 'square', - 'gather', - 'label_smooth', - 'cross_entropy_with_softmax', - 'mean_all', - 'cumsum', - 'linear_interp', - 'bilinear_interp', - 'trilinear_interp', - 'nearest_interp', - 'bicubic_interp', - 'assign', - 'assign_out_', - 'real', - 'flip', - 'softmax', - 'expand', - 'conv2d_transpose', - 'depthwise_conv2d_transpose', - 'sigmoid', - 'pad', - 'pad3d', - 'einsum', - 'leaky_relu', - 'log10', - 'conv3d', - 'solve', - 'diag', - 'trace', - 'tile', -] -vjp_interface_implementation_gen_op_list = [ - 'where', - "tanh", - "mean", - "divide", - "sum", - "add", - "concat", - "split", - "split_with_num", - "gelu", - "matmul", - "erf", - "multiply", - "subtract", - "pow", - "rsqrt", - "square", - "dropout", - 'exp', - 'expm1', - 'expand', - 'layer_norm', - 'reshape', - 'cast', - "scale", - 'softmax', - 'silu', - 'elementwise_pow', - 'embedding', - 'fused_softmax_mask_upper_triangle', - 'slice', - 'transpose', - 'slice_grad', - 'gather_nd', - 'stack', - 'poisson', - 'gumbel_softmax', - 'pad', - 'pad3d', - 'squeeze', - 'unsqueeze', - 'tril', - 'triu', - 'squeeze', - 'unsqueeze', - 'conv2d', - 'depthwise_conv2d', - 'sqrt', - 'flatten', - 'relu', - 'abs', - 'log', - 'clip', - 'ceil', - 'p_norm', - 'maximum', - 'argsort', - 'min', - 'max', - 'batch_norm', - 'max_pool2d_with_index', - 'pool2d', - 'minimum', - 'prod', - 'round', - 'sin', - 'cos', - 'dot', - 'floor', - 'topk', - 'square', - 'gather', - 'label_smooth', - 'cross_entropy_with_softmax', - 'mean_all', - 'cumsum', - 'linear_interp', - 'bilinear_interp', - 'trilinear_interp', - 'nearest_interp', - 'bicubic_interp', - 'assign', - 'assign_out_', - 'real', - 'flip', - 'softmax', - 'expand', - 'conv2d_transpose', - 'depthwise_conv2d_transpose', - 'sigmoid', - 'pad', - 'pad3d', - 'einsum', - 'leaky_relu', - 'log10', - 'conv3d', - 'solve', - 'diag', - 'trace', - 'tile', -] diff --git a/paddle/fluid/pir/dialect/operator/interface/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/interface/CMakeLists.txt deleted file mode 100644 index a6496585e7790..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# All source files of pd_op_dialect, except for the source file of op, which is generated in the compilation directory. -file(GLOB PD_INTERFACE_SRCS "*.cc") - -cc_library( - pd_interface - SRCS ${PD_INTERFACE_SRCS} - DEPS pir_core phi_utils) diff --git a/paddle/fluid/pir/dialect/operator/interface/decomp.h b/paddle/fluid/pir/dialect/operator/interface/decomp.h new file mode 100644 index 0000000000000..10a6e51e7db3c --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/decomp.h @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +class DecompInterface : public pir::OpInterfaceBase { + public: + struct Concept { + explicit Concept( + std::vector> (*decomp)(pir::Operation* op)) + : decomp_(decomp) {} + std::vector> (*decomp_)(pir::Operation* op); + }; + + template + struct Model : public Concept { + static std::vector> Decomp(pir::Operation* op) { + return ConcreteOp::Decomp(op); + } + Model() : Concept(Decomp) {} + }; + + /// Constructor + DecompInterface(pir::Operation* op, Concept* impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + std::vector> Decomp(pir::Operation* op) { + return impl_->decomp_(op); + } + + private: + Concept* impl_; +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h b/paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h new file mode 100644 index 0000000000000..7f9795acff8e1 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h @@ -0,0 +1,69 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/common/data_type.h" +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +class GetKernelTypeForVarInterface + : public pir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(phi::DataType (*get_kernel_type_for_var)( + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::DataType& expected_kernel_dtype)) + : get_kernel_type_for_var_(get_kernel_type_for_var) {} + + phi::DataType (*get_kernel_type_for_var_)( + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::DataType& expected_kernel_dtype); + }; + + template + struct Model : public Concept { + static phi::DataType GetKernelTypeForVar( + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::DataType& expected_kernel_dtype) { + return ConcreteOp::GetKernelTypeForVar( + var_name, tensor_dtype, expected_kernel_dtype); + } + + Model() : Concept(GetKernelTypeForVar) {} + }; + + /// Constructor + GetKernelTypeForVarInterface(pir::Operation* op, Concept* impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + phi::DataType GetKernelTypeForVar( + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::DataType& expected_kernel_dtype) { + return impl_->get_kernel_type_for_var_( + var_name, tensor_dtype, expected_kernel_dtype); + } + + private: + Concept* impl_; +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::GetKernelTypeForVarInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h index 958d2df369ed9..fe0f50a456008 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infermeta.h +++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #pragma once #include "paddle/phi/core/infermeta_utils.h" diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc index ce8bdb6c6829f..01d8045425bea 100644 --- a/paddle/fluid/pir/dialect/operator/interface/interface.cc +++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" +#include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" @@ -19,6 +21,8 @@ namespace paddle { namespace dialect { std::vector> VjpInterface::Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { std::vector> out_grads_value; @@ -29,7 +33,7 @@ std::vector> VjpInterface::Vjp( } out_grads_value.emplace_back(std::move(grad_value)); } - return impl_->vjp_(op, out_grads_value, stop_gradients); + return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients); } } // namespace dialect } // namespace paddle @@ -37,3 +41,5 @@ std::vector> VjpInterface::Vjp( IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::GetKernelTypeForVarInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/vjp.h b/paddle/fluid/pir/dialect/operator/interface/vjp.h index 4f2292c7b6c02..44d1731359beb 100644 --- a/paddle/fluid/pir/dialect/operator/interface/vjp.h +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.h @@ -22,11 +22,15 @@ class VjpInterface : public pir::OpInterfaceBase { struct Concept { explicit Concept(std::vector> (*vjp)( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients)) : vjp_(vjp) {} std::vector> (*vjp_)( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients); }; @@ -35,9 +39,11 @@ class VjpInterface : public pir::OpInterfaceBase { struct Model : public Concept { static std::vector> Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { - return ConcreteOp::Vjp(op, out_grads, stop_gradients); + return ConcreteOp::Vjp(op, inputs, outputs, out_grads, stop_gradients); } Model() : Concept(Vjp) {} @@ -49,13 +55,17 @@ class VjpInterface : public pir::OpInterfaceBase { std::vector> Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { - return impl_->vjp_(op, out_grads, stop_gradients); + return impl_->vjp_(op, inputs, outputs, out_grads, stop_gradients); } std::vector> Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients); diff --git a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt deleted file mode 100644 index 7954e000baf51..0000000000000 --- a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt +++ /dev/null @@ -1,214 +0,0 @@ -set(PD_DIALECT_BINARY_DIR - "${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir") - -# Generate pd_op_dialect files defining op using op_gen_file -set(op_gen_parsed_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py) - -set(op_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py) -set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) -set(op_forward_yaml_file1 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml -) -set(op_forward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml -) -set(op_backward_yaml_file1 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml -) -set(op_backward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml -) -set(fused_op_forward_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml -) -set(fused_op_backward_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml -) - -set(pd_op_forward_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops.yaml) - -set(pd_op_backward_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml) - -set(parsed_op_dir - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) - -set(op_yaml_file3 ${parsed_op_dir}/ops.parsed.yaml) -set(op_yaml_file4 ${parsed_op_dir}/ops_backward.parsed.yaml) - -set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3},${op_yaml_file4} -) -set(op_namespace paddle,dialect) -set(dialect_name pd_op) -set(op_header_file ${PD_DIALECT_BINARY_DIR}/pd_op.h) -set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc) -set(op_header_file_tmp ${op_header_file}.tmp) -set(op_source_file_tmp ${op_source_file}.tmp) - -set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc) -set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp) - -add_custom_command( - OUTPUT ${op_yaml_file3} ${op_yaml_file4} - COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir} - COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path - ${pd_op_forward_yaml_file} --output_path ${op_yaml_file3} - COMMENT "Generate pd_ops.parsed.yaml" - COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path - ${pd_op_backward_yaml_file} --output_path ${op_yaml_file4} --backward - COMMENT "Generate pd_ops_backward.parsed.yaml" - DEPENDS ${op_gen_parsed_yaml_file} ${pd_op_forward_yaml_file} - ${pd_op_backward_yaml_file} - VERBATIM) - -add_custom_command( - OUTPUT ${op_header_file} ${op_source_file} ${op_vjp_source_file} - COMMAND - ${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files} - --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} - --dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp} - --op_def_cc_file ${op_source_file_tmp} --op_vjp_cc_file - ${op_vjp_source_file_tmp} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_header_file_tmp} - ${op_header_file} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_source_file_tmp} - ${op_source_file} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_vjp_source_file_tmp} - ${op_vjp_source_file} - COMMENT - "copy_if_different ${op_header_file} ${op_source_file} ${op_vjp_source_file}" - DEPENDS ${op_gen_file} - ${op_forward_yaml_file1} - ${op_forward_yaml_file2} - ${op_backward_yaml_file1} - ${op_backward_yaml_file2} - ${op_compat_yaml_file} - ${op_yaml_file3} - ${op_yaml_file4} - VERBATIM) - -set(api_gen_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3},${op_yaml_file4} -) -set(api_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/api_gen.py) -set(api_header_file ${PD_DIALECT_BINARY_DIR}/pd_api.h) -set(api_source_file ${PD_DIALECT_BINARY_DIR}/pd_api.cc) -set(api_header_file_tmp ${api_header_file}.tmp) -set(api_source_file_tmp ${api_source_file}.tmp) - -add_custom_command( - OUTPUT ${api_header_file} ${api_source_file} - COMMAND - ${PYTHON_EXECUTABLE} ${api_gen_file} --op_yaml_files ${api_gen_yaml_files} - --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} - --api_def_h_file ${api_header_file_tmp} --api_def_cc_file - ${api_source_file_tmp} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} - ${api_header_file} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} - ${api_source_file} - COMMENT "copy_if_different ${api_header_file} ${api_source_file}" - DEPENDS ${api_gen_file} - ${op_forward_yaml_file1} - ${op_forward_yaml_file2} - ${op_backward_yaml_file1} - ${op_backward_yaml_file2} - ${op_compat_yaml_file} - ${op_yaml_file3} - ${op_yaml_file4} - VERBATIM) - -set(python_c_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/python_c_gen.py) -set(python_c_header_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.h) -set(python_c_source_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.cc) -set(python_c_header_file_tmp ${python_c_header_file}.tmp) -set(python_c_source_file_tmp ${python_c_source_file}.tmp) - -add_custom_command( - OUTPUT ${python_c_header_file} ${python_c_source_file} - COMMAND - ${PYTHON_EXECUTABLE} ${python_c_gen_file} --op_yaml_files - ${api_gen_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file} - --namespaces "paddle,pybind" --python_c_def_h_file - ${python_c_header_file_tmp} --python_c_def_cc_file - ${python_c_source_file_tmp} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_header_file_tmp} - ${python_c_header_file} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_source_file_tmp} - ${python_c_source_file} - COMMENT "copy_if_different ${python_c_header_file} ${python_c_source_file}" - DEPENDS ${python_c_gen_file} - ${op_forward_yaml_file1} - ${op_forward_yaml_file2} - ${op_backward_yaml_file1} - ${op_backward_yaml_file2} - ${op_compat_yaml_file} - ${op_yaml_file3} - ${op_yaml_file4} - VERBATIM) - -add_custom_target(static_op_function_gen ALL DEPENDS ${python_c_header_file} - ${python_c_source_file}) - -set(ops_api_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py) -set(ops_api_source_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/ops_api.cc) -set(ops_api_source_file_tmp ${ops_api_source_file}.tmp) - -add_custom_command( - OUTPUT ${ops_api_source_file} - COMMAND - ${PYTHON_EXECUTABLE} ${ops_api_gen_file} --op_yaml_files - ${api_gen_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file} - --namespaces "paddle,pybind" --ops_api_file ${ops_api_source_file_tmp} - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ops_api_source_file_tmp} - ${ops_api_source_file} - COMMENT "copy_if_different ${ops_api_source_file}" - DEPENDS ${ops_api_gen_file} - ${op_forward_yaml_file1} - ${op_forward_yaml_file2} - ${op_backward_yaml_file1} - ${op_backward_yaml_file2} - ${op_compat_yaml_file} - ${python_c_header_file} - ${python_c_source_file} - VERBATIM) - -add_custom_target(ops_api_gen ALL DEPENDS ${ops_api_source_file}) - -cc_library( - pd_op_dialect_core - SRCS op_attribute.cc op_type.cc meta_tensor.cc - DEPS phi pd_interface pd_trait type_info) -cc_library( - pd_op_dialect_op - SRCS ${op_source_file} manual_op.cc control_flow_op.cc - DEPS pd_op_dialect_core pir_control_flow) -cc_library( - api_builder - SRCS api_builder.cc - DEPS pir_core) -cc_library( - pd_op_dialect_api - SRCS ${api_source_file} manual_api.cc - DEPS api_builder pd_op_dialect_op pd_op_dialect_utils) -if((CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) - set_target_properties(pd_op_dialect_api PROPERTIES COMPILE_FLAGS - "-Wno-maybe-uninitialized") -endif() - -target_include_directories(pd_op_dialect_api INTERFACE ${PD_DIALECT_BINARY_DIR}) - -cc_library( - pd_op_dialect - SRCS op_dialect.cc manual_op_vjp.cc ${op_vjp_source_file} - DEPS pd_op_dialect_api param_to_variable primitive_vjp_experimental - pd_op_dialect_utils op_yaml_info_parser) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index c235799633896..28ca22d30de2f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -23,7 +23,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/operation_utils.h" #include "paddle/pir/core/utils.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" namespace paddle { namespace dialect { diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 0f636e01e19a3..00ba7da80aa25 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -429,9 +429,9 @@ OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() { paddle::dialect::OpRunTimeInfo run_time_info( "FusedGemmEpilogueInferMeta", {"x", "y", "bias", "trans_x", "trans_y", "activation"}, - "", - {""}, - {""}, + {"fused_gemm_epilogue"}, + {"x", "y", "bias", "trans_x", "trans_y", "activation"}, + {}, {}, {}, {}); @@ -674,9 +674,15 @@ OpInfoTuple FusedGemmEpilogueGradOp::GetOpInfo() { "trans_x", "trans_y", "activation_grad"}, - "", - {""}, - {""}, + {"fused_gemm_epilogue_grad"}, + {"x", + "y", + "reserve_space", + "out_grad", + "trans_x", + "trans_y", + "activation_grad"}, + {}, {}, {}, {}); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 317ce64feea08..cda6eb596d21e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" @@ -34,7 +35,8 @@ namespace dialect { class AddNOp : public pir::Op { + paddle::dialect::VjpInterface, + paddle::dialect::DecompInterface> { public: using Op::Op; static const char *name() { return "pd_op.add_n"; } @@ -51,8 +53,11 @@ class AddNOp : public pir::Op> Vjp( pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); + static std::vector> Decomp(pir::Operation *op); }; class AddN_Op : public pir::Op> AddNOp::Vjp( pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { - AddNOp op_obj = op->dyn_cast(); - VLOG(6) << "Prepare inputs of add_n_grad"; + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1u, + platform::errors::InvalidArgument( + "addn op's inputs size should be 1 but now is %d", inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1u, + platform::errors::InvalidArgument( + "addn op's outputs size should be 1 but now is %d", outputs.size())); PADDLE_ENFORCE( - op_obj.inputs() != nullptr, - paddle::platform::errors::Fatal("addn op's inputs can't be null")); - pir::CombineOp combine_op_obj = op_obj.inputs() - .dyn_cast() - .owner() - ->dyn_cast(); + inputs_[0].size() != 0, + paddle::platform::errors::Fatal("addn op's inputs[0] can't be null")); std::vector inputs; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { + for (size_t idx = 0; idx < inputs_[0].size(); idx++) { inputs.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared(inputs_[0][idx])); } Tensor out_grad(std::make_shared(out_grads[0][0])); diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc index f10db043d1523..3134214cf9029 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc @@ -35,6 +35,8 @@ phi::Scalar ScalarAttribute::data() { return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().data()); } else if (isa()) { diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h index 6b9edf98cb56a..0b0973a5205c8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h @@ -48,6 +48,7 @@ class ScalarAttribute : public pir::Attribute { (val.type_id() == pir::FloatAttribute::type_id()) || (val.type_id() == pir::DoubleAttribute::type_id()) || (val.type_id() == pir::Int32Attribute::type_id()) || + (val.type_id() == pir::IndexAttribute::type_id()) || (val.type_id() == pir::Int64Attribute::type_id()) || (val.type_id() == pir::StrAttribute::type_id()); } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 5a2da284142ad..d73558cf81fb9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -7,7 +7,6 @@ kernel: func: add_n param: [inputs] - backward: add_n_grad - op : add_n_with_kernel args : (Tensor[] inputs) @@ -18,7 +17,6 @@ kernel: func: add_n param: [inputs] - backward: add_n_grad - op : assert args : (Tensor cond, Tensor[] data, int64_t summarize = -1) @@ -83,6 +81,16 @@ args : (Tensor[] x) output : Tensor(out) +- op : memcpy + args : (Tensor x, int dst_place_type) + output : Tensor(out) + infer_meta: + func: UnchangedInferMeta + param: [x] + kernel: + func : memcpy + param: [x, dst_place_type] + - op : print args : (Tensor in, int first_n, str message, int summarize, bool print_tensor_name = true, bool print_tensor_type = true, bool print_tensor_shape = true, bool print_tensor_layout = true, bool print_tensor_lod = true, str print_phase = "BOTH", bool is_forward = true) output : Tensor(out) @@ -132,7 +140,7 @@ param : [x, ring_id, dynamic_shape, peer, use_calc_stream] - op : set_value - args : (Tensor x, int64_t[] starts, int64_t[] ends, int64_t[] steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) + args : (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) output : Tensor(out) infer_meta: func: SetValueInferMeta @@ -144,7 +152,7 @@ backward: set_value_grad - op : set_value_with_tensor - args : (Tensor x, Tensor values, int64_t[] starts, int64_t[] ends, int64_t[] steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) + args : (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) output : Tensor(out) infer_meta: func: SetValueInferMeta @@ -175,7 +183,6 @@ - op : write_to_array args : (Tensor i, Tensor x) output : Tensor[](out) - backward: write_to_array_grad - op: dpsgd args: (Tensor param, Tensor grad, Tensor learning_rate, float clip = 10.0f, float batch_size = 16.0f, float sigma = 1.0f, int seed = 0) @@ -207,3 +214,14 @@ func: FusedFeedForwardInferMeta optional: dropout1_seed, dropout2_seed, linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, ln2_mean, ln2_variance, ln1_mean, ln1_variance, ln1_out backward: fused_feedforward_grad + +- op: sparse_momentum + args: (Tensor param, Tensor grad, Tensor velocity, Tensor index, Tensor learning_rate, Tensor master_param,float mu, Scalar axis=0, bool use_nesterov=false,str regularization_method="", float regularization_coeff=0.0f, bool multi_precision=false, float rescale_grad=1.0f) + output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) + infer_meta: + func: SparseMomentumInferMeta + param: [param, learning_rate, velocity] + kernel: + func: sparse_momentum + data_type: param + optional: master_param, master_param_out diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index 95e3d99bd573b..81213383e3fcf 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -19,7 +19,7 @@ optional: linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln1_out, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, dropout2_out, ln1_scale_grad, ln1_bias_grad, ln2_scale_grad, ln2_bias_grad, linear2_bias_grad - backward_op : set_value_grad - args : (Tensor out_grad, Tensor values, int64_t[] starts, int64_t[] ends, int64_t[] steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) + args : (Tensor out_grad, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) output : Tensor(x_grad), Tensor(values_grad) infer_meta: func: SetValueGradInferMeta diff --git a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml new file mode 100644 index 0000000000000..de542e68f30b9 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml @@ -0,0 +1,14 @@ +# Ops in this file is only used for pir currently and will replace ops of legacy_ops.yaml/ops.yaml of PHI in future. + +- op : arange + args : (Scalar start, Scalar end, Scalar step, DataType dtype=DataType::FLOAT64, Place place=CPUPlace()) + output : Tensor(out) + infer_meta : + func : ArangeInferMeta + param : [start, end, step, dtype] + kernel : + func : arange + param : [start, end, step] + data_type : dtype + backend : place + support_tensor : [start, end, step] diff --git a/paddle/fluid/pir/dialect/operator/trait/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/trait/CMakeLists.txt deleted file mode 100644 index 0689edb35655e..0000000000000 --- a/paddle/fluid/pir/dialect/operator/trait/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -file(GLOB PD_INTERFACE_SRCS "*.cc") - -cc_library( - pd_trait - SRCS ${PD_INTERFACE_SRCS} - DEPS pir_core) diff --git a/paddle/fluid/pir/dialect/operator/transforms/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/transforms/CMakeLists.txt deleted file mode 100644 index 7116a12be50ef..0000000000000 --- a/paddle/fluid/pir/dialect/operator/transforms/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -cc_library( - param_to_variable - SRCS param_to_variable.cc - DEPS pd_op_dialect_core) diff --git a/paddle/fluid/pir/dialect/operator/utils/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/utils/CMakeLists.txt deleted file mode 100644 index 58eafb2cc3921..0000000000000 --- a/paddle/fluid/pir/dialect/operator/utils/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -cc_library(op_yaml_info_parser SRCS op_yaml_info_parser.cc) -cc_library( - pd_op_dialect_utils - SRCS utils.cc - DEPS pd_op_dialect_core) diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc index 5452cd6f47f30..bf752a089b4f6 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc @@ -17,8 +17,9 @@ namespace paddle { namespace dialect { -OpYamlInfoParser::OpYamlInfoParser(const OpInfoTuple& op_info_tuple) - : op_info_tuple_(op_info_tuple) { +OpYamlInfoParser::OpYamlInfoParser(const OpInfoTuple& op_info_tuple, + bool is_legacy_op) + : op_info_tuple_(op_info_tuple), is_legacy_op_(is_legacy_op) { parse(); } @@ -210,7 +211,9 @@ void OpYamlInfoParser::parse() { } for (auto& name : runtime_info.kernel_param) { - if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) { + if ((input_name2id_.count(name) && + (!input_info_[name].is_mutable_attribute)) || + (is_legacy_op_ && input_info_[name].is_mutable_attribute)) { kernel_fn_tensor_params_.push_back(name); } else { kernel_fn_attr_params_.push_back(name); diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h index 0a972ced0ef41..fd4004730c906 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h @@ -23,7 +23,8 @@ class OpYamlInfoParser { public: OpYamlInfoParser() = delete; - explicit OpYamlInfoParser(const OpInfoTuple& op_info_tuple); + explicit OpYamlInfoParser(const OpInfoTuple& op_info_tuple, + bool is_legacy_op = false); bool IsTensorAttribute(size_t index) const; size_t InputTensorNumber() const; @@ -74,6 +75,7 @@ class OpYamlInfoParser { } OpInfoTuple op_info_tuple_; + bool is_legacy_op_; // input info std::map input_name2id_; diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index d2f6fc56d2a17..8961c70569c8b 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -25,6 +25,8 @@ const std::unordered_set LegacyOpList = { "pd_op.c_broadcast_", "pd_op.c_sync_calc_stream_", "pd_op.c_sync_comm_stream_", + "pd_op.fused_gemm_epilogue", + "pd_op.fused_gemm_epilogue_grad", "pd_op.dpsgd", "pd_op.send_v2", "pd_op.recv_v2", @@ -35,7 +37,8 @@ const std::unordered_set LegacyOpList = { "pd_op.c_allreduce_max_", "pd_op.c_allgather", "pd_op.seed", - "pd_op.share_data"}; + "pd_op.share_data", + "pd_op.sparse_momentum"}; enum class AttrType { UNDEFINED = 0, @@ -201,5 +204,24 @@ bool IsEmptyValue(const pir::Value& value) { return !value.impl() || !value.type(); } +std::vector GetInt64Vector(const pir::Attribute& attr) { + PADDLE_ENFORCE_EQ(attr.isa(), + true, + phi::errors::PreconditionNotMet( + "attribute MUST be a pir::ArrayAttribute")); + auto attr_vec = attr.dyn_cast().AsVector(); + + std::vector vec_int64; + for (auto vec_element : attr_vec) { + PADDLE_ENFORCE_EQ( + vec_element.isa(), + true, + phi::errors::PreconditionNotMet("element MUST be a Int64Attribute")); + vec_int64.push_back(vec_element.dyn_cast().data()); + } + + return vec_int64; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 1c228e7e85083..e35d7fa74cc64 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -133,5 +133,7 @@ bool IsLegacyOp(const std::string& name); bool IsEmptyValue(const pir::Value& value); +std::vector GetInt64Vector(const pir::Attribute& attr); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt new file mode 100644 index 0000000000000..6643f303926eb --- /dev/null +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -0,0 +1,102 @@ +file(GLOB DRR_SRCS "*.cc" "api/*.cc") + +set(op_creator_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +) +set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) +set(op_forward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml +) + +set(op_forward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml +) +set(op_forward_yaml_file3 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/update_ops.parsed.yaml +) +set(op_backward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml +) +set(op_backward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml +) +set(fused_op_forward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml +) +set(fused_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml +) + +set(cinn_op_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/generated/ops.parsed.yaml) + +set(cinn_op_yaml_source_file + ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml) + +set(parsed_op_dir + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) + +set(op_yaml_file3 ${parsed_op_dir}/ops.parsed.yaml) +set(op_yaml_file4 ${parsed_op_dir}/ops_backward.parsed.yaml) +set(op_yaml_file5 ${parsed_op_dir}/update_ops.parsed.yaml) + +set(op_yaml_files + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3},${op_yaml_file4},${op_yaml_file5} +) + +set(op_creator_file + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/ir_op_factory_generated.cc) +set(op_creator_file_tmp ${op_creator_file}.tmp) + +set(dialect_name pd_op) + +set(cinn_op_creator_file + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/cinn_op_factory_generated.cc) +set(cinn_op_creator_file_tmp ${cinn_op_creator_file}.tmp) + +set(cinn_dialect_name cinn_op) + +add_custom_command( + OUTPUT ${op_creator_file} + COMMAND + ${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --dialect_name ${dialect_name} + --op_creator_file ${op_creator_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_creator_file_tmp} + ${op_creator_file} + COMMENT "copy_if_different ${op_creator_file}" + DEPENDS ${op_creator_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${op_yaml_file3} + ${op_yaml_file4} + op_dialect + VERBATIM) + +if(WITH_CINN AND NOT CINN_ONLY) + add_custom_command( + OUTPUT ${cinn_op_creator_file} + COMMAND + ${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files + ${cinn_op_yaml_file} --op_compat_yaml_file ${op_compat_yaml_file} + --dialect_name ${cinn_dialect_name} --op_creator_file + ${cinn_op_creator_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${cinn_op_creator_file_tmp} + ${cinn_op_creator_file} + COMMENT "copy_if_different ${cinn_op_creator_file}" + DEPENDS ${op_creator_gen_file} ${op_compat_yaml_file} + ${cinn_op_yaml_source_file} op_dialect cinn_op_dialect + VERBATIM) + set(CINN_SOURCE_FILE ${cinn_op_creator_file}) + + set(CINN_DEPS cinn_op_dialect) + +endif() + +cc_library( + drr + SRCS ${DRR_SRCS} ${op_creator_file} ${CINN_SOURCE_FILE} + DEPS op_dialect_vjp ${CINN_DEPS} pir) diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md new file mode 100644 index 0000000000000..4abdbb1b64717 --- /dev/null +++ b/paddle/fluid/pir/drr/README.md @@ -0,0 +1,230 @@ +# DRR (Declarative Rewrite Rule) Tool User Manual +--- +## 1. Related Background + +PASS is a crucial component for optimizing intermediate representations (IR), and the transformation of DAG-to-DAG (Replace a subgraph of the directed acyclic graph (DAG) type in the original graph with another subgraph) is the most common type of Pass. The transformation of DAG-to-DAG can be divided into two steps: matching and rewriting. Matching refers to the complete matching of a known subgraph to the corresponding target subgraph in the Program, while rewriting refers to replacing the matched graph with a new subgraph. + +DRR can reduce the development cost of PASS, allowing developers to focus on processing optimization logic without caring about the data structure of the underlying IR. After the developer declares the pattern of the target subgraph and the new subgraph to be replaced through a set of simple and easy-to-use interfaces, DRR can automatically match the original subgraph in the Program and replace it with the new subgraph. + +Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows: +~~~ c++ +// 1. Inherit specialized template class from DrPatternBase +class RemoveRedundentCastPattern + : public pir::drr::DrrPatternBase { + // 2. Overload operator() + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute + auto pat = ctx->SourcePattern(); + + pat.Tensor("tmp") = // CastOp output Tensor named "tmp" + pat.Op(paddle::dialect::CastOp::name(), // Pass in the name of the CastOp + {{"dtype", pat.Attr("dtype1")}}) // The corresponding globally unique ID of the "dtype" attribute of CastOp is "dtype1" + (pat.Tensor("arg0")); // The input Tensor of CastOp is "arg0" + pat.Tensor("ret") = + pat.Op(paddle::dialect::CastOp::name(), + {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp")); + // 4. Define Constrain + pat.RequireEqual(pat("tmp").dtype(), pat.Tensor("ret").dtype()); + + // 5. Define ResultPattern + auto res = pat.ResultPattern(); + res.Tensor("ret") = + res.Op(paddle::dialect::CastOp::name(), + {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); + } +}; +~~~ + +DRR PASS contains the following three parts: ++ `Source Pattern`:used to describe the target subgraph to be matched in Program ++ `Constrains`:used to specify constraints for SourcePattern matching(nonessential) ++ `Result Pattern`:Used to describe the subgraph that needs to be replaced by +Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern` to implement a complete PASS. + +**Note:** +1. **DRR only supports matching and replacing the closed SourcePattern and ResultPattern (except for the Pattern input and output Tensor, all internal Tensors cannot be used by the Pattern external Op). If the defined Pattern is not closed in the Program, the matching will fail.** +2. **The input and output of ResultPattern need to be a subset of the input and output of SourcePattern.** +## 2. Interface List + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Class Function Function Description Parameter Interpretation
DrrPatternBase
 virtual void operator()(
+        pir::drr::DrrPatternContext* ctx) const 
Implement the entry function of DRR PASS ctx: Context parameters required to create Patten
SourcePattern
 const drr::Op& Op(
+    const std::string& op_type,
+    const std::unordered_map<std::string, Attribute>& attributes)
Define an Op in the SourcePattern op_type: The defined Op name. Can be obtained through paddle::dialect::xxOp::name() interface
attributes : Attribute information of the created Op
 const drr::Tensor& Tensor(
+        const std::string& tensor_name) 
Define a tensor named tensor_name in SourcePattern tensor_name: The name of the defined Tensor must be unique within the SourcePattern
 Attribute Attr(
+        const std::string& attr_name) const 
Define an attribute named attr_name in SourcePattern attr_name: The name of the attribute, which needs to be unique within SourcePattern
 void RequireEqual(
+        const TensorShape& first,
+        const TensorShape& second)
Requires the TensorShape of the two Tensors in SourcePattern to be the same first: first TensorShape
second : second TensorShape
 void RequireEqual(
+        const TensorDataType& first,
+        const TensorDataType& second)
The data types of the two Tensors in SourcePattern are required to be the same first: DataType of the first Tensor
second : DataType of the second Tensor
void RequireNativeCall(
+        const std::function<bool(const MatchContext&)>& custom_fn)
Define a constraint in SourcePattern. You can use this interface and lambda expressions to implement custom constraints on SourcePattern. custom_fn: Customized constraint functions
ResultPattern
 const drr::Op& Op(
+    const std::string& op_type,
+    const std::unordered_map<std::string, Attribute>&  attributes) 
Define an Op in ResultPattern op_type: The defined Op name. Can be obtained through paddle::dialect::xxOp::name() interface
attributes : Attribute information of the created Op
const drr::Tensor& Tensor(
+        const std::string& tensor_name)
Define a tensor named tensor_name in ResultPattern tensor_name: The name of the defined Tensor must be unique within the ResultPattern
Attribute Attr(
+        const std::string& attr_name) const 
Define an attribute named attr_name in ResultPattern attr_name: The name of the attribute must be unique within the ResultPattern
using AttrComputeFunc = std::function<std::any(const MatchContext&)>;
+Attribute Attr(const AttrComputeFunc& attr_compute_func) const
Create an Attribute through a custom calculation logic AttrComputeFuncattr_compute_func: Customized calculation logic
drr::Tensor& NoneTensor()
When the input Tensor of an Op is optional and not needed, NoneTensor needs to be used to occupy the place. /
TensorShape
explicit TensorShape(
+        const std::string& tensor_name) 
Abstract the class that describes the shape of Tensor tensor_name: The name of the Tensor being described
 const std::string& tensor_name() const
Obtain the name of Tensor /
TensorDataType
explicit TensorDataType(
+        const std::string& tensor_name)
An abstract class that describes the data types of elements in Tensor tensor_name: The name of the Tensor being described
 const std::string& tensor_name() const
Obtain the name of Tensor /
DrrPatternContext
drr::SourcePattern DrrPatternContext::SourcePattern()
Create a SourcePattern object and return /
+ +## 3 Example +Example 1: Matmul + Add -> FusedGemmEpilogue +~~~ c++ +class FusedLinearPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Define SourcePattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + + // Define ResultPattern + pir::drr::ResultPattern res = pat.ResultPattern(); + // Define Constrain + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + } +}; +~~~ + +Example 2: Full + Expand -> Full +~~~ c++ +class FoldExpandToConstantPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Define SourcePattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &full1 = pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + const auto &full_int_array1 = + pat.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pat.Attr("expand_shape_value")}, + {"dtype", pat.Attr("dtype_2")}, + {"place", pat.Attr("place_2")}}); + const auto &expand = pat.Op(paddle::dialect::ExpandOp::name()); + pat.Tensor("ret") = expand(full1(), full_int_array1()); + + // Define ResultPattern + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &full2 = res.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("expand_shape_value")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + res.Tensor("ret") = full2(); + } +}; +~~~ diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md new file mode 100644 index 0000000000000..456bf7921414b --- /dev/null +++ b/paddle/fluid/pir/drr/README_cn.md @@ -0,0 +1,233 @@ +# DRR( Declarative Rewrite Rule) PASS用户使用手册 +--- +## 1. 相关背景 + +PASS 是对 IR 进行优化的关键组件,而 DAG-to-DAG 的变换(将原图中的一个 DAG 子图替换成另一个 DAG 子图)是最常见的Pass类型。DAG-to-DAG 的变换可以划分为匹配和重写两个步骤:匹配是根据已知子图在 Program 中完全匹配到对应的目标子图,重写是将匹配到的图结构替换为新的子图。 + +DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 PASS 组件。DRR 能降低 PASS 的开发成本,让开发者集中在对优化逻辑的处理上,而不需要关心底层 IR 的数据结构。开发者通过一套简洁易用的接口对目标子图和需要替换成的新子图进行模式声明后,DRR 就能自动的在 Program 中对原图进行匹配,并替换成新子图。 + +以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下: +~~~ c++ +// 1. 继承 DrrPatternBase 的特化模板类 +class RemoveRedundentCastPattern + : public pir::drr::DrrPatternBase { + // 2. 重载 operator() + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern + auto pat = ctx->SourcePattern(); + + pat.Tensor("tmp") = // CastOp 输出 Tensor 命名为"tmp" + pat.Op(paddle::dialect::CastOp::name(), // 传入 CastOp 的 name + {{"dtype", pat.Attr("dtype1")}}) // CastOp 的"dtype"属性的对应的全局唯一ID为"dtype1" + (pat.Tensor("arg0")); // CastOp 输入 Tensor 为"arg0" + pat.Tensor("ret") = + pat.Op(paddle::dialect::CastOp::name(), + {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp")); + // 4. 定义 Constrain + pat.RequireEqual(pat("tmp").dtype(), pat.Tensor("ret").dtype()); + + // 5. 定义 ResultPattern + auto res = pat.ResultPattern(); + res.Tensor("ret") = + res.Op(paddle::dialect::CastOp::name(), + {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); + } +}; +~~~ + +DRR PASS 包含以下三个部分: ++ `SourcePattern`:用于描述在 Program 中待匹配的目标子图 ++ `Constrains`:用于指定`SourcePattern`匹配的限制条件(非必需) ++ `ResultPattern`:用于描述需要替换为的模式子图 +开发者只需要定义出`SourcePattern`, `Constrains`和`ResultPattern`即可实现一个完整的 PASS。 + +**注意:** +1. **DRR 仅支持对闭包(除 Pattern 输入输出 Tensor 以外,所有的内部 Tensor 不能被 Pattern 外部 Op 使用)的 SourcePattern 和 ResultPattern 进行匹配替换,若定义的 Pattern 在 Program 中不闭包则匹配失败** +2. **ResultPattern 的输入输出需要满足是 SourcePattern 的输入输出的子集** +## 2. 接口列表 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
函数 功能描述 参数解释
DrrPatternBase
 virtual void operator()(
+        pir::drr::DrrPatternContext* ctx) const 
实现 DRR PASS 的入口函数 ctx: 创建 Patten 所需要的 Context 参数
SourcePattern
 const drr::Op& Op(
+    const std::string& op_type,
+    const std::unordered_map<std::string, Attribute>& attributes)
在 SourcePattern 中定义一个 Op op_type: 定义的 Op 名称,可以通过 paddle::dialect::xxOp + ::name() 接口获取
attributes : 所创建的 Op 的属性信息
 const drr::Tensor& Tensor(
+        const std::string& tensor_name) 
在 SourcePattern 中定义一个名为 tensor_name 的 tensor tensor_name: 定义的 Tensor 的名称,需要满足 SourcePattern 内唯一
 Attribute Attr(
+        const std::string& attr_name) const 
在 SourcePattern 中定义一个名为 attr_name 的属性 attr_name: 属性的名称,需要满足 SourcePattern 内唯一
 void RequireEqual(
+        const TensorShape& first,
+        const TensorShape& second)
要求 SourcePattern 中两个 Tensor 的 TensorShape 相同 first: 第一个 TensorShape
second : 第二个 TensorShape
 void RequireEqual(
+        const TensorDataType& first,
+        const TensorDataType& second)
要求 SourcePattern 中两个 Tensor 的数据类型相同 first: 第一个 Tensor 的 DataType
second : 第二个 Tensor 的 DataType
void RequireNativeCall(
+        const std::function<bool(const MatchContext&)>& custom_fn)
在 SourcePattern 中定义一个约束,可以利用此接口和 lamda 表达式实现对 SourcePattern 的自定义约束 custom_fn: 自定义的约束函数
ResultPattern
 const drr::Op& Op(
+    const std::string& op_type,
+    const std::unordered_map<std::string, Attribute>&  attributes) 
在ResultPattern中定义一个Op op_type: 定义的 Op 名称,可以通过 paddle::dialect::xxOp + ::name() 接口获取
attributes : 所创建的 Op 的属性信息
const drr::Tensor& Tensor(
+        const std::string& tensor_name)
在 ResultPattern 中定义一个名为 tensor_name 的 tensor tensor_name: 定义的 Tensor 的名称,需要满足 ResultPattern 内唯一
Attribute Attr(
+        const std::string& attr_name) const 
在 ResultPattern 中定义一个名为 attr_name 的属性 attr_name: 属性的名称,需要满足 ResultPattern 内唯一
using AttrComputeFunc = std::function<std::any(const MatchContext&)>;
+Attribute Attr(const AttrComputeFunc& attr_compute_func) const
通过自定义的计算逻辑 AttrComputeFunc,创建出一个 Attributeattr_compute_func: 自定义的计算逻辑
drr::Tensor& NoneTensor()
当一个 Op 的输入 Tensor 是一个可选项并且不需要时,需要使用 NoneTensor 来占位 /
TensorShape
explicit TensorShape(
+        const std::string& tensor_name) 
抽象出来描述 Tensor 的 shape 的类 tensor_name: 被描述的 Tensor 的 name
 const std::string& tensor_name() const
获取 tensor 的 name /
TensorDataType
explicit TensorDataType(
+        const std::string& tensor_name)
抽象出来的描述 Tensor 中元素数据类型的类 tensor_name: 被描述的 Tensor 的 name
 const std::string& tensor_name() const
获取 Tensor 的 name /
DrrPatternContext
drr::SourcePattern DrrPatternContext::SourcePattern()
创建一个 SourcePattern 对象,并返回 /
+ +## 3 使用示例 +Example 1: Matmul + Add -> FusedGemmEpilogue +~~~ c++ +class FusedLinearPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // 定义 Source Pattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + + // 定义 Result Pattern + pir::drr::ResultPattern res = pat.ResultPattern(); + // 定义 Constrain + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + } +}; +~~~ + +Example 2: Full + Expand -> Full +~~~ c++ +class FoldExpandToConstantPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // 定义 Source Pattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &full1 = pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + const auto &full_int_array1 = + pat.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pat.Attr("expand_shape_value")}, + {"dtype", pat.Attr("dtype_2")}, + {"place", pat.Attr("place_2")}}); + const auto &expand = pat.Op(paddle::dialect::ExpandOp::name()); + pat.Tensor("ret") = expand(full1(), full_int_array1()); + + // 定义 Result Pattern Constrains: 本 Pass 无额外约束规则 + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &full2 = res.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("expand_shape_value")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + res.Tensor("ret") = full2(); + } +}; +~~~ diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h new file mode 100644 index 0000000000000..1a84c42800373 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" + +namespace pir { +namespace drr { + +template +class DrrPatternBase { + public: + virtual ~DrrPatternBase() = default; + + // Define the Drr Pattern. + virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0; + + std::unique_ptr Build( + pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { + DrrPatternContext drr_context; + this->operator()(&drr_context); + std::string pattern_name = pir::get_type_name(); + return std::make_unique( + pattern_name, drr_context, ir_context, benefit); + } +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/api/drr_pattern_context.cc new file mode 100644 index 0000000000000..a0c9987e2e9b9 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" + +#include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/phi/core/enforce.h" + +namespace pir { +namespace drr { + +DrrPatternContext::DrrPatternContext() { + source_pattern_graph_ = std::make_shared(); + result_pattern_graph_ = std::make_shared(); +} + +drr::SourcePattern DrrPatternContext::SourcePattern() { + return drr::SourcePattern(this); +} +const Op& DrrPatternContext::SourceOpPattern( + const std::string& op_type, + const std::unordered_map& attributes) { + owned_ops_.push_back(std::shared_ptr( + new drr::Op(op_type, attributes, source_pattern_graph_.get()))); + return *owned_ops_.back(); +} + +const drr::Tensor& DrrPatternContext::SourceTensorPattern( + const std::string& name) { + return source_pattern_graph_->AddTensor(std::shared_ptr( + new drr::Tensor(name, source_pattern_graph_.get()))); +} + +const Op& DrrPatternContext::ResultOpPattern( + const std::string& op_type, + const std::unordered_map& attributes) { + owned_ops_.push_back(std::shared_ptr( + new drr::Op(op_type, attributes, result_pattern_graph_.get()))); + return *owned_ops_.back(); +} + +drr::Tensor& DrrPatternContext::ResultTensorPattern(const std::string& name) { + return result_pattern_graph_->AddTensor(std::shared_ptr( + new drr::Tensor(name, result_pattern_graph_.get()))); +} + +std::vector DrrPatternContext::constraints() const { + return constraints_; +} + +// void DrrPatternContext::RequireEqual(const Attribute& first, const Attribute& +// second) { +// auto constrain_fn = [&](const MatchContext& match_context) { +// return match_context.Attr(first.id()) == match_context.Attr(second.id()); +// }; +// constraints_.emplace_back(constrain_fn); +// } + +void DrrPatternContext::RequireEqual(const TensorShape& first, + const TensorShape& second) { + // Note: we capture the datas by value for constrain_fn + // because the datas are destructed before running constrain_fn. + auto constrain_fn = [=](const MatchContext& match_context) { + return match_context.Tensor(first.tensor_name()).Shape() == + match_context.Tensor(second.tensor_name()).Shape(); + }; + constraints_.emplace_back(constrain_fn); +} + +void DrrPatternContext::RequireEqual(const TensorDataType& first, + const TensorDataType& second) { + // Note: we capture the datas by value for constrain_fn + // because the datas are destructed before running constrain_fn. + auto constrain_fn = [=](const MatchContext& match_context) { + return match_context.Tensor(first.tensor_name()).Dtype() == + match_context.Tensor(second.tensor_name()).Dtype(); + }; + constraints_.emplace_back(constrain_fn); +} + +void DrrPatternContext::RequireNativeCall( + const std::function& custom_fn) { + constraints_.emplace_back(custom_fn); +} + +void Op::operator()(const Tensor& arg, const Tensor* out) const { + std::vector inputs{&arg}; + std::vector outputs{out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); +} + +void Op::operator()(const std::vector& args, + const std::vector& outputs) const { + pattern_graph_->AddOpCall(std::make_shared(this, args, outputs)); +} + +Tensor& Op::operator()(const Tensor& arg) const { + std::vector inputs{&arg}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const { + std::vector inputs{&arg1, &arg2}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +Tensor& Op::operator()() const { + std::vector inputs{}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +thread_local int64_t Op::count = 0; +const char* Op::prefix = "@drr_temp@_"; + +const char Tensor::NONE_TENSOR_NAME[] = "__@none_tensor@__"; + +void Tensor::Assign(const Tensor& other) { + dynamic_cast(pattern_graph_)->AssignTensor(*this, other); +} + +void Tensor::operator=(const Tensor& other) const { // NOLINT + // The two tensor must be in the same pattern graph. + PADDLE_ENFORCE_EQ( + this->pattern_graph_, + other.pattern_graph_, + phi::errors::InvalidArgument("Matching failed." + "Two Tensors must be in the same pattern " + "graph to make the '=' judgment.")); + if (other.name_.find(Op::prefix) == 0 && + name_.find(Op::prefix) == std::string::npos) { + other.pattern_graph_->UpdateTmpTensor(other.name_, this->name_); + } +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/api/drr_pattern_context.h new file mode 100644 index 0000000000000..b4156bd54bf41 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.h @@ -0,0 +1,334 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/drr/api/match_context.h" + +namespace pir { +namespace drr { + +class Op; +class Tensor; +class OpCall; +class SourcePattern; +class ResultPattern; +class PatternGraph; +class SourcePatternGraph; +class ResultPatternGraph; + +class NormalAttribute { + public: + explicit NormalAttribute(const std::string& name) : attr_name_(name) {} + + const std::string& name() const { return attr_name_; } + + private: + std::string attr_name_; +}; + +using AttrComputeFunc = std::function; + +class ComputeAttribute { + public: + explicit ComputeAttribute(const AttrComputeFunc& attr_compute_func) + : attr_compute_func_(attr_compute_func) {} + + const AttrComputeFunc& attr_compute_func() const { + return attr_compute_func_; + } + + private: + AttrComputeFunc attr_compute_func_; +}; + +using Attribute = std::variant; + +class TensorShape { + public: + explicit TensorShape(const std::string& tensor_name) + : tensor_name_(tensor_name) {} + + const std::string& tensor_name() const { return tensor_name_; } + + private: + std::string tensor_name_; +}; + +class TensorDataType { + public: + explicit TensorDataType(const std::string& tensor_name) + : tensor_name_(tensor_name) {} + + const std::string& tensor_name() const { return tensor_name_; } + + private: + std::string tensor_name_; +}; + +class Constraint { + public: + explicit Constraint( + const std::function& constrain_fn) + : IsContextMatchConstraint_(constrain_fn) {} + bool operator()(const MatchContext& match_context) const { + return IsContextMatchConstraint_(match_context); + } + + private: + std::function IsContextMatchConstraint_; +}; + +class DrrPatternContext { + public: + DrrPatternContext(); + ~DrrPatternContext() = default; + + drr::SourcePattern SourcePattern(); + + std::shared_ptr source_pattern_graph() const { + return source_pattern_graph_; + } + + std::vector constraints() const; + + std::shared_ptr result_pattern_graph() const { + return result_pattern_graph_; + } + + private: + friend class drr::SourcePattern; + friend class drr::ResultPattern; + + const Op& SourceOpPattern( + const std::string& op_type, + const std::unordered_map& attributes = {}); + const drr::Tensor& SourceTensorPattern(const std::string& name); + + const Op& ResultOpPattern( + const std::string& op_type, + const std::unordered_map& attributes = {}); + drr::Tensor& ResultTensorPattern(const std::string& name); + + // void RequireEqual(const Attribute& first, const Attribute& second); + void RequireEqual(const TensorShape& first, const TensorShape& second); + void RequireEqual(const TensorDataType& first, const TensorDataType& second); + void RequireNativeCall( + const std::function& custom_fn); + + std::shared_ptr source_pattern_graph_; + std::vector constraints_; + std::shared_ptr result_pattern_graph_; + + std::vector> owned_ops_; +}; + +class Op { + public: + const std::string& name() const { return op_type_name_; } + + void operator()(const Tensor& arg, const Tensor* out) const; + + Tensor& operator()() const; + + Tensor& operator()(const Tensor& arg) const; + Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; + void operator()(const std::vector& args, + const std::vector& outputs) const; + // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const + // Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const + // Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor& + // operator()(const Tensor& arg0, const Tensor& arg1, const Tensor& arg2, + // const Tensor& arg3, const Tensor& arg4) const; + + static const char* prefix; + + private: + friend class DrrPatternContext; + friend class OpCall; + + Op(const std::string& op_type_name, + const std::unordered_map& attributes, + PatternGraph* pattern_graph) + : op_type_name_(op_type_name), + attributes_(attributes), + pattern_graph_(pattern_graph) {} + + const std::unordered_map& attributes() const { + return attributes_; + } + + thread_local static int64_t count; + + std::string op_type_name_; + std::unordered_map attributes_; + PatternGraph* pattern_graph_{nullptr}; +}; + +class Tensor { + public: + static const char NONE_TENSOR_NAME[]; + + const std::string& DebugName() const; + + TensorShape shape() const { return TensorShape(name()); } + + TensorDataType dtype() const { return TensorDataType(name()); } + + bool is_none() const { return name_ == NONE_TENSOR_NAME; } + + void Assign(const Tensor& other); + + void operator=(const Tensor& other) const; // NOLINT + + const std::string& name() const { return name_; } + + void set_name(const std::string& name) { name_ = name; } + + OpCall* producer() const { return producer_; } + + void set_producer(OpCall* producer) { producer_ = producer; } + + const std::vector& consumers() const { return consumers_; } + + void set_consumables(const std::vector& consumers) { + consumers_ = consumers; + } + + void AddConsumer(const OpCall* consumer) { consumers_.push_back(consumer); } + + private: + friend class DrrPatternContext; + friend class Op; + + Tensor(const std::string& name, PatternGraph* pattern_graph) + : name_(name), pattern_graph_(pattern_graph) {} + + std::string name_; + OpCall* producer_{nullptr}; + std::vector consumers_; + PatternGraph* pattern_graph_{nullptr}; +}; + +class OpCall { + public: + OpCall(const Op* op, + const std::vector& inputs, + const std::vector& outputs) + : op_name_(op->op_type_name_), + inputs_(inputs), + outputs_(outputs), + attributes_(op->attributes_) {} + + const std::string& name() const { return op_name_; } + + const std::vector& inputs() const { return inputs_; } + + const std::vector& outputs() const { return outputs_; } + + const std::unordered_map& attributes() const { + return attributes_; + } + + private: + std::string op_name_; + std::vector inputs_; + std::vector outputs_; + std::unordered_map attributes_; +}; + +class ResultPattern { + public: + const drr::Op& Op( + const std::string& op_type, + const std::unordered_map& attributes = {}) { + return ctx_->ResultOpPattern(op_type, attributes); + } + + drr::Tensor& Tensor(const std::string& name) { + return ctx_->ResultTensorPattern(name); + } + + // Represent the input tensor which is none. + // Example: + // instance_norm has follow input tensor : (x, scale, bias), scale and + // bias are optional(means it may be none). + // When scale is onoe, we can write a instance_norm op in drr as follow: + // res.Op("instance_norm")(res.Tensor("x"), res.NoneTensor, + // res.Tensor("bias")); + drr::Tensor& NoneTensor() { + return ctx_->ResultTensorPattern(Tensor::NONE_TENSOR_NAME); + } + + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + Attribute Attr(const AttrComputeFunc& attr_compute_func) const { + return ComputeAttribute(attr_compute_func); + } + + private: + friend class SourcePattern; + + explicit ResultPattern(DrrPatternContext* ctx) : ctx_(ctx) {} + + DrrPatternContext* ctx_{nullptr}; +}; + +class SourcePattern { + public: + drr::ResultPattern ResultPattern() const { return drr::ResultPattern(ctx_); } + + const drr::Op& Op( + const std::string& op_type, + const std::unordered_map& attributes = {}) { + return ctx_->SourceOpPattern(op_type, attributes); + } + + const drr::Tensor& Tensor(const std::string& name) { + return ctx_->SourceTensorPattern(name); + } + + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + + void RequireEqual(const TensorShape& first, const TensorShape& second) { + ctx_->RequireEqual(first, second); + } + void RequireEqual(const TensorDataType& first, const TensorDataType& second) { + ctx_->RequireEqual(first, second); + } + + void RequireNativeCall( + const std::function& custom_fn) { + ctx_->RequireNativeCall(custom_fn); + } + + private: + friend class DrrPatternContext; + explicit SourcePattern(DrrPatternContext* ctx) : ctx_(ctx) {} + DrrPatternContext* ctx_{nullptr}; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/api/match_context.cc new file mode 100644 index 0000000000000..35b28db13254e --- /dev/null +++ b/paddle/fluid/pir/drr/api/match_context.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/match_context.h" + +#include + +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" + +namespace pir { +namespace drr { + +MatchContext::MatchContext(std::shared_ptr impl) + : impl_(impl) {} + +const TensorInterface& MatchContext::Tensor( + const std::string& tensor_name) const { + return impl_->Tensor(tensor_name); +} + +template +T MatchContext::Attr(const std::string& attr_name) const { + return impl_->Attr(attr_name); +} + +template bool MatchContext::Attr(const std::string&) const; +template int32_t MatchContext::Attr(const std::string&) const; +template int64_t MatchContext::Attr(const std::string&) const; +template float MatchContext::Attr(const std::string&) const; +template std::string MatchContext::Attr(const std::string&) const; +template std::vector MatchContext::Attr>( + const std::string&) const; +template std::vector MatchContext::Attr>( + const std::string&) const; + +} // namespace drr +} // namespace pir diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.h b/paddle/fluid/pir/drr/api/match_context.h similarity index 55% rename from paddle/pir/dialect/control_flow/ir/cf_ops.h rename to paddle/fluid/pir/drr/api/match_context.h index 7d669c0b648ea..a1699ccb5bddf 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.h +++ b/paddle/fluid/pir/drr/api/match_context.h @@ -13,23 +13,31 @@ // limitations under the License. #pragma once -#include -#include "paddle/pir/core/builder.h" -#include "paddle/pir/core/op_base.h" + +#include +#include + +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/ir_operation.h" namespace pir { -class IR_API YieldOp : public Op { +namespace drr { + +class TensorInterface; +class MatchContextImpl; + +class MatchContext final { public: - using Op::Op; - static const char *name() { return "cf.yield"; } - static constexpr uint32_t attributes_num = 0; - static constexpr const char **attributes_name = nullptr; - - static void Build(Builder &builder, // NOLINT - OperationArgument &argument, // NOLINT - const std::vector &Value); - void VerifySig() {} + MatchContext(std::shared_ptr impl); + + const TensorInterface& Tensor(const std::string& tensor_name) const; + + template + T Attr(const std::string& attr_name) const; + + private: + std::shared_ptr impl_; }; -} // namespace pir -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp); +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc new file mode 100644 index 0000000000000..1b81b3a567211 --- /dev/null +++ b/paddle/fluid/pir/drr/api/tensor_interface.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/ir_value.h" + +namespace pir { +namespace drr { + +bool ShapeInterface::operator==(const ShapeInterface& other) const { + return *shape_ == *other.shape_; +} + +int ShapeInterface::size() const { return shape_->size(); } + +int64_t ShapeInterface::at(int idx) const { return shape_->at(idx); } + +bool DtypeInterface::operator==(const DtypeInterface& other) const { + return *dtype_ == *other.dtype_; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h new file mode 100644 index 0000000000000..7629857591bf3 --- /dev/null +++ b/paddle/fluid/pir/drr/api/tensor_interface.h @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace pir { +namespace drr { + +class IrValue; +class IrShape; +class IrDtype; + +class ShapeInterface final { + public: + bool operator==(const ShapeInterface& other) const; + + int size() const; + + int64_t at(int idx) const; + + private: + explicit ShapeInterface(const IrShape* shape) : shape_(shape) {} + + friend class IrValue; + + const IrShape* shape_; +}; + +class DtypeInterface final { + public: + bool operator==(const DtypeInterface& other) const; + + private: + explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {} + + friend class IrValue; + + const IrDtype* dtype_; +}; + +class TensorInterface { + public: + virtual ShapeInterface Shape() const = 0; + virtual DtypeInterface Dtype() const = 0; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h new file mode 100644 index 0000000000000..28b26ba26a2a1 --- /dev/null +++ b/paddle/fluid/pir/drr/attr_type_uilts.h @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace pir { +namespace drr { + +template +struct CppTypeToIrAttribute; + +#define PD_SPECIALIZE_CppTypeToIrAttribute(cpp_type, ir_attr_type) \ + template <> \ + struct CppTypeToIrAttribute< \ + std::remove_const_t>> { \ + using type = ir_attr_type; \ + }; + +PD_SPECIALIZE_CppTypeToIrAttribute(bool, BoolAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::string, StrAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, + paddle::dialect::DataTypeAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, pir::ArrayAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, + paddle::dialect::IntArrayAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray, + paddle::dialect::IntArrayAttribute); + +template +struct IrAttrbuteCreator { + typename CppTypeToIrAttribute::type operator()(T obj) const { + return CppTypeToIrAttribute::type::template get( + pir::IrContext::Instance(), obj); + } +}; + +template <> +struct IrAttrbuteCreator> { + pir::ArrayAttribute operator()(std::vector obj) const { + std::vector attr_vec; + attr_vec.reserve(obj.size()); + for (int32_t x : obj) { + attr_vec.push_back(Int32Attribute::get(pir::IrContext::Instance(), x)); + } + return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); + } +}; + +template +struct IrAttrTypeCast { + static T To(const pir::Attribute& attr) { + return attr.dyn_cast::type>().data(); + } +}; + +template <> +struct IrAttrTypeCast { + static std::string To(const pir::Attribute& attr) { + return attr.dyn_cast::type>() + .AsString(); + } +}; + +template <> +struct IrAttrTypeCast> { + static std::vector To(const pir::Attribute& attr) { + std::vector result; + auto array_attr = attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back(array_attr.at(i).dyn_cast().data()); + } + return result; + } +}; + +template <> +struct IrAttrTypeCast> { + static std::vector To(const pir::Attribute& attr) { + std::vector result; + if (attr.dyn_cast()) { + auto array_attr = attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back( + array_attr.at(i).dyn_cast().data()); + } + } else if (attr.dyn_cast()) { + result = + attr.dyn_cast().data().GetData(); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "Dynamic cast failed for IR attribute vector")); + } + return result; + } +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc new file mode 100644 index 0000000000000..e1b94bb77a082 --- /dev/null +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc @@ -0,0 +1,546 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" + +namespace pir { +namespace drr { + +bool DrrRewritePattern::MatchAndRewrite( + pir::Operation* op, + PatternRewriter& rewriter) const { // NOLINT + std::shared_ptr src_match_ctx = + std::make_shared(); + if (PatternGraphMatch(op, src_match_ctx.get())) { + VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program."; + PatternGraphRewrite(*src_match_ctx, rewriter); + return true; + } + return false; +} + +bool DrrRewritePattern::PatternGraphMatch( + pir::Operation* op, MatchContextImpl* source_pattern_match_ctx) const { + VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")"; + const OpCall* anchor = source_pattern_graph_->AnchorNode(); + std::unordered_map> + bind_map = + FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get())); + if (bind_map.empty()) { + return false; + } + std::vector drr_output_sequence; + std::vector ir_output_sequence; + std::unordered_map output_op_map; + for (auto pair : bind_map) { + drr_output_sequence.push_back(pair.first); + } + // using dfs to obtain the arrangement of all candidate ir ops + auto permute = [&](auto&& permute, size_t index) -> bool { + if (index == drr_output_sequence.size()) { + // avoiding duplicate binding of ir op + std::unordered_set ir_output_set; + for (Operation* op : ir_output_sequence) { + auto pr = ir_output_set.insert(op); + if (pr.second == false) { + return false; + } + } + // new match_ctx + std::shared_ptr match_ctx = + std::make_shared(); + std::transform(drr_output_sequence.begin(), + drr_output_sequence.end(), + ir_output_sequence.begin(), + std::inserter(output_op_map, output_op_map.end()), + [](const OpCall* drr_op, Operation* ir_op) { + return std::make_pair(drr_op, ir_op); + }); + if (MatchFromOutputToInput( + output_op_map, *(source_pattern_graph_.get()), match_ctx)) { + *source_pattern_match_ctx = *match_ctx; + return true; + } + return false; + } + for (auto* ir_op : bind_map[drr_output_sequence[index]]) { + ir_output_sequence.push_back(ir_op); + if (permute(permute, index + 1)) { + return true; + } + ir_output_sequence.pop_back(); + } + return false; + }; + + return permute(permute, 0); +} + +std::unordered_map> +DrrRewritePattern::FindCandidateIrOutputOp( + pir::Operation* op, + const OpCall* anchor, + const SourcePatternGraph& source_pattern_graph) const { + // get source pattern output op + std::unordered_set drr_output_op_set = + source_pattern_graph.OutputNodes(); + std::unordered_map> + output_op_bind_map{{anchor, {op}}}; + if (drr_output_op_set.size() == 1) { + return output_op_bind_map; + } + std::unordered_set drr_visited_ops{anchor}; + DfsVisitor( + anchor, op, drr_output_op_set, &drr_visited_ops, &output_op_bind_map); + if (output_op_bind_map.size() != drr_output_op_set.size()) { + return {}; + } + return output_op_bind_map; +} + +void DrrRewritePattern::DfsVisitor( + const OpCall* drr_op, + pir::Operation* ir_op, + const std::unordered_set& drr_output_op_set, + std::unordered_set* drr_visited_ops, + std::unordered_map>* + output_op_bind_map) const { + VLOG(6) << "DfsVisitor Start: drr op(" << drr_op->name() << ")" + << "ir op(" << ir_op->name() << ")"; + if (drr_op->name() != ir_op->name()) { + return; + } + // check op input's size + const auto& drr_op_input_tensors = drr_op->inputs(); + auto ir_op_input_value_size = ir_op->num_operands(); + if (drr_op_input_tensors.size() != ir_op_input_value_size) { + return; + } + // check op output's size + const auto& drr_op_output_tensors = drr_op->outputs(); + auto ir_op_output_value_size = ir_op->num_results(); + if (drr_op_output_tensors.size() != ir_op_output_value_size) { + return; + } + // check producer op + for (size_t i = 0; i < drr_op_input_tensors.size(); ++i) { + // case 1: drr_op_input_tensor is the input tensor of source pattern + if (drr_op_input_tensors[i]->producer() == nullptr) { + // dfs source pattern input tensor other child op + auto ir_input_tensor = ir_op->operand(i).source(); + for (auto drr_bro_op : drr_op_input_tensors[i]->consumers()) { + if (drr_visited_ops->count(drr_bro_op)) { + continue; + } + for (auto it = ir_input_tensor.use_begin(); + it != ir_input_tensor.use_end(); + ++it) { + auto* ir_bro_op = it.owner(); + if (drr_bro_op->name() == ir_bro_op->name()) { + drr_visited_ops->insert(drr_bro_op); + DfsVisitor(drr_bro_op, + ir_bro_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_bro_op); + } + } + } + continue; + } + // case 2: have producer op + const auto& drr_producer_op = drr_op_input_tensors[i]->producer(); + if (drr_visited_ops->count(drr_producer_op)) { + continue; + } + auto ir_operand_value = ir_op->operand(i).source(); + if (drr_op_input_tensors[i]->consumers().size() != + ir_operand_value.use_count()) { + return; + } + auto* ir_producer_op = ir_operand_value.dyn_cast().owner(); + drr_visited_ops->insert(drr_producer_op); + DfsVisitor(drr_producer_op, + ir_producer_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_producer_op); + } + if (drr_output_op_set.count(drr_op)) { + (*output_op_bind_map)[drr_op].insert(ir_op); + return; + } + // check child ops + for (size_t i = 0; i < drr_op_output_tensors.size(); ++i) { + const auto& drr_child_ops = drr_op_output_tensors[i]->consumers(); + auto ir_output_value = ir_op->result(i); + if (drr_child_ops.size() != ir_output_value.use_count()) { + return; + } + for (auto* drr_child_op : drr_child_ops) { + for (auto it = ir_output_value.use_begin(); + it != ir_output_value.use_end(); + ++it) { + auto* ir_child_op = it.owner(); + if (drr_child_op->name() == ir_child_op->name()) { + if (drr_visited_ops->count(drr_child_op)) { + continue; + } + drr_visited_ops->insert(drr_child_op); + DfsVisitor(drr_child_op, + ir_child_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_child_op); + } + } + } + } // check child ops + return; +} + +bool DrrRewritePattern::MatchFromOutputToInput( + std::unordered_map output_op_map, + const SourcePatternGraph& source_pattern_graph, + const std::shared_ptr& source_pattern_match_ctx) const { + VLOG(6) << "MatchFromOutputToInput Start"; + std::unordered_set drr_visited; + std::unordered_set ir_visited; + std::queue drr_q; + std::queue ir_q; + bool matched = true; + size_t step = 0; + for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { + VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" + << it->second << ") in source_pattern_graph "; + drr_q.push(it->first); + drr_visited.insert(it->first); + ir_q.push(it->second); + ir_visited.insert(it->second); + } + while (!drr_q.empty()) { + if (!matched) break; + auto* drr_node = drr_q.front(); + auto* ir_node = ir_q.front(); + drr_q.pop(); + ir_q.pop(); + if (drr_node->name() != ir_node->name()) { + matched = false; + break; + } + const auto& drr_input_tensors = drr_node->inputs(); + auto ir_input_value_size = ir_node->num_operands(); + if (drr_input_tensors.size() != ir_input_value_size) { + matched = false; + break; + } + if (drr_node->outputs().size() != ir_node->num_results()) { + matched = false; + break; + } + source_pattern_match_ctx->BindIrOperation( + drr_node, std::make_shared(ir_node)); + // binding input_tensor of current_op + for (size_t i = 0; i < drr_input_tensors.size(); ++i) { + source_pattern_match_ctx->BindIrValue( + drr_input_tensors[i]->name(), + std::make_shared(ir_node->operand(i).source())); + auto* drr_producer_op = drr_input_tensors[i]->producer(); + if (drr_producer_op == nullptr) { + continue; + } + auto* ir_producer_op = + ir_node->operand(i).source().dyn_cast().owner(); + if (drr_input_tensors[i]->consumers().size() != + ir_node->operand(i).source().use_count()) { + matched = false; + break; + } + // bfs producer_op of current_op + if (drr_visited.count(drr_producer_op) && + ir_visited.count(ir_producer_op)) { + continue; + } + if (!drr_visited.count(drr_producer_op) && + !ir_visited.count(ir_producer_op)) { + drr_q.push(drr_producer_op); + ir_q.push(ir_producer_op); + drr_visited.insert(drr_producer_op); + ir_visited.insert(ir_producer_op); + } else { + matched = false; + break; + } + } + // binding output tensor of current_op + auto drr_op_output_tensor = drr_node->outputs(); + for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { + source_pattern_match_ctx->BindIrValue( + drr_op_output_tensor[j]->name(), + std::make_shared(ir_node->result(j))); + } + ++step; + } + + if (matched) { + PADDLE_ENFORCE_EQ( + step, + source_pattern_graph.CountOfOpCalls(), + phi::errors::PreconditionNotMet( + "Pattern matching failed." + "The number of successful matches and the number of OpCalls in the " + "source pattern graph are not equal.")); + } else { + return matched; + } + + MatchContext match_context{source_pattern_match_ctx}; + for (const auto& constraint : constraints_) { + matched = constraint(match_context); + if (!matched) break; + } + + return matched; +} + +void DrrRewritePattern::PatternGraphRewrite( + const MatchContextImpl& source_pattern_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + VLOG(6) << "Create Operations in result_pattern_graph"; + MatchContextImpl res_match_ctx = CreateOperations(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); + VLOG(6) << "Process Assign Tensor"; + RebindIrTensorForAssignTensor(*result_pattern_graph_, &res_match_ctx); + VLOG(6) << "Replace Output Values in source_pattern_graph by Output Values " + "in result_pattern_graph"; + ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter); + VLOG(6) << "Delete Operations in source_pattern_graph"; + DeleteSourcePatternOp(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); +} + +MatchContextImpl DrrRewritePattern::CreateOperations( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + MatchContextImpl res_match_ctx; + // add input tensors info for res_match_ctx + for (const auto& in_tensor : result_pattern_graph.input_tensors()) { + PADDLE_ENFORCE_NE( + result_pattern_graph.id2owend_tensor().count(in_tensor), + 0, + phi::errors::NotFound("Not found the input tensor." + "Drr input tensor [%s] must exist in the result " + "pattern graph to be obtained.", + in_tensor)); + if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { + res_match_ctx.BindIrValue( + in_tensor, + std::make_shared(src_match_ctx.GetIrValue(in_tensor))); + } + } + + if (result_pattern_graph.CountOfOpCalls() == 1) { + CreateOperation(*result_pattern_graph.owned_op_call()[0], + src_match_ctx, + rewriter, + &res_match_ctx); + return res_match_ctx; + } + + std::vector> temp_program; + std::unordered_map op_2_temp_program_index; + for (Operation* op : *rewriter.block()) { + op_2_temp_program_index[op] = temp_program.size(); + temp_program.push_back({op}); + } + + // topo order visit result_pattern_graph + GraphTopo graph_topo_visit(&result_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { + // set insert point + size_t max_input_op_index = 0; + Operation* max_index_op = nullptr; + for (const Tensor* input : op_call.inputs()) { + if (input->is_none()) { + continue; + } + Value ir_val = res_match_ctx.GetIrValue(input->name()).get(); + if (ir_val) { + Operation* ir_input_op = ir_val.dyn_cast().owner(); + if (max_input_op_index < op_2_temp_program_index[ir_input_op]) { + max_input_op_index = op_2_temp_program_index[ir_input_op]; + max_index_op = ir_input_op; + } else if (max_input_op_index == op_2_temp_program_index[ir_input_op]) { + const auto& ops_vec = temp_program[max_input_op_index]; + for (auto it = ops_vec.rbegin(); it != ops_vec.rend(); it++) { + if (*it == max_index_op) { + break; + } else if (*it == ir_input_op) { + max_index_op = ir_input_op; + break; + } else { + // do nothing + } + } + } else { + // do nothing + } + } + } + if (max_input_op_index == 0UL) { + VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; + Operation* source_patter_first_op = + src_match_ctx.Operation(source_pattern_graph.owned_op_call()[0].get()) + .get(); + max_input_op_index = op_2_temp_program_index[source_patter_first_op]; + rewriter.SetInsertionPoint(source_patter_first_op); + } else { + rewriter.SetInsertionPointAfter(max_index_op); + } + + Operation* new_op = + CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); + op_2_temp_program_index[new_op] = max_input_op_index + 1; + temp_program[max_input_op_index + 1].push_back(new_op); + }); + + return res_match_ctx; +} + +void DrrRewritePattern::RebindIrTensorForAssignTensor( + const ResultPatternGraph& result_pattern_graph, + MatchContextImpl* res_match_ctx) const { + const auto& tensor_assign_map = result_pattern_graph.tensor_assign_map(); + for (const auto& kv : tensor_assign_map) { + const auto& src_tensor_name = kv.first; + const auto& dst_tensor_name = kv.second; + res_match_ctx->BindIrValue( + src_tensor_name, + std::make_shared(res_match_ctx->GetIrValue(dst_tensor_name))); + } +} + +void DrrRewritePattern::ReplaceOutputTensor( + const MatchContextImpl& src_match_ctx, + const MatchContextImpl& res_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + for (const auto& output_name : result_pattern_graph_->output_tensors()) { + if (source_pattern_graph_->id2owend_tensor().count(output_name)) { + const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); + const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); + rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + } else { + LOG(WARNING) << "The output tensor (" << output_name + << ") in the result_pattern_graph is not the tensor" + " in source_pattern_graph."; + } + } +} + +void DrrRewritePattern::DeleteSourcePatternOp( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + std::vector topo_order_ops; + GraphTopo graph_topo_visit(&source_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder( + [&topo_order_ops](const OpCall& op_call) { + topo_order_ops.push_back(&op_call); + }); + + // Filter the operations which are replaced by result pattern + // 1. Filter operations by forward walk + std::unordered_set forward_visited_tensor_set( + result_pattern_graph.input_tensors()); + std::unordered_set forward_deleted_ops; + std::for_each(topo_order_ops.begin(), + topo_order_ops.end(), + [&forward_deleted_ops, + &forward_visited_tensor_set](const OpCall* op_call) { + if (op_call->inputs().empty()) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + } + for (const auto* input : op_call->inputs()) { + if (forward_visited_tensor_set.count(input->name())) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + break; + } + } + }); + // 2. Filter operations by backward walk and merge the forward result + std::unordered_set backward_visited_tensor_set( + result_pattern_graph.output_tensors()); + std::vector deleted_ops; + std::unordered_set deleted_ops_set; + std::for_each(topo_order_ops.rbegin(), + topo_order_ops.rend(), + [&deleted_ops, + &deleted_ops_set, + &backward_visited_tensor_set, + &forward_deleted_ops](const OpCall* op_call) { + bool all_comsumer_deleted = true; + bool from_backward_visited_tensor = false; + for (const auto* output : op_call->outputs()) { + if (backward_visited_tensor_set.count(output->name())) { + from_backward_visited_tensor = true; + } else if (output->consumers().empty()) { + continue; + } else { + all_comsumer_deleted = false; + } + } + if (all_comsumer_deleted && from_backward_visited_tensor && + forward_deleted_ops.count(op_call)) { + deleted_ops_set.insert(op_call); + deleted_ops.push_back(op_call); + for (const auto* input : op_call->inputs()) { + backward_visited_tensor_set.insert(input->name()); + } + } + }); + + // Delete Operation with topo order from output tensors. + for (const auto* op_call : deleted_ops) { + PADDLE_ENFORCE_NE(src_match_ctx.operation_map().count(op_call), + 0, + phi::errors::NotFound( + "Not found the OpCall." + "Only Opcall [%s] that exist in match context can be " + "deleted.", + op_call->name())); + auto* op = src_match_ctx.operation_map().at(op_call)->get(); + VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" << op + << ") in source_pattern_graph "; + rewriter.EraseOp(op); + } +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/drr_rewrite_pattern.h new file mode 100644 index 0000000000000..5d20a5947f13b --- /dev/null +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.h @@ -0,0 +1,116 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/ir_operation_factory.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/type_name.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" + +namespace pir { +namespace drr { + +class DrrRewritePattern : public pir::RewritePattern { + public: + explicit DrrRewritePattern(const std::string& pattern_name, + const DrrPatternContext& drr_context, + pir::IrContext* context, + pir::PatternBenefit benefit = 1) + : pir::RewritePattern( + drr_context.source_pattern_graph()->AnchorNode()->name(), + benefit, + context, + {}), + pattern_name_(pattern_name), + source_pattern_graph_(drr_context.source_pattern_graph()), + constraints_(drr_context.constraints()), + result_pattern_graph_(drr_context.result_pattern_graph()) { + PADDLE_ENFORCE_NE( + source_pattern_graph_->owned_op_call().empty(), + true, + phi::errors::InvalidArgument("Source pattern graph is empty." + "Suggested fix: Please check the DRR " + "source pattern definition code.")); + } + + bool MatchAndRewrite(pir::Operation* op, + PatternRewriter& rewriter) const override; // // NOLINT + + private: + bool PatternGraphMatch(pir::Operation* op, + MatchContextImpl* source_pattern_match_ctx) const; + + std::unordered_map> + FindCandidateIrOutputOp(pir::Operation* op, + const OpCall* anchor, + const SourcePatternGraph& source_pattern_graph) const; + + void DfsVisitor( + const OpCall* drr_op, + pir::Operation* ir_op, + const std::unordered_set& drr_output_op_set, + std::unordered_set* drr_visited_ops, + std::unordered_map>* + output_op_bind_map) const; + + bool MatchFromOutputToInput( + std::unordered_map output_op_map, + const SourcePatternGraph& source_pattern_graph, + const std::shared_ptr& source_pattern_match_ctx) const; + + void PatternGraphRewrite(const MatchContextImpl& source_pattern_match_ctx, + pir::PatternRewriter& rewriter) const; // NOLINT + + private: + MatchContextImpl CreateOperations( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const; // NOLINT + + void RebindIrTensorForAssignTensor( + const ResultPatternGraph& result_pattern_graph, + MatchContextImpl* res_match_ctx) const; + + void ReplaceOutputTensor(const MatchContextImpl& src_match_ctx, + const MatchContextImpl& res_match_ctx, + pir::PatternRewriter& rewriter) const; // NOLINT + + void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const; // NOLINT + + private: + const std::string pattern_name_; + const std::shared_ptr source_pattern_graph_; + const std::vector constraints_; + const std::shared_ptr result_pattern_graph_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.cc b/paddle/fluid/pir/drr/ir_operation.h similarity index 70% rename from paddle/pir/dialect/control_flow/ir/cf_ops.cc rename to paddle/fluid/pir/drr/ir_operation.h index 7981a6ab96396..2764bc9245417 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.cc +++ b/paddle/fluid/pir/drr/ir_operation.h @@ -12,15 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#pragma once + +#include "paddle/pir/core/operation.h" namespace pir { +namespace drr { -void YieldOp::Build(Builder &builder, - OperationArgument &argument, - const std::vector &inputs) { - argument.AddInputs(inputs); -} -} // namespace pir +class IrOperation { + public: + explicit IrOperation(pir::Operation* op) : op_(op) {} + + pir::Operation* get() const { return op_; } -IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp) + private: + pir::Operation* op_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc new file mode 100644 index 0000000000000..665976838cff3 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/ir_operation_factory.h" + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +namespace pir { +namespace drr { + +void OperationFactory::RegisterManualOpCreator() { + RegisterOperationCreator( + "pd_op.fused_gemm_epilogue", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build( + inputs[0].dyn_cast(), + inputs[1].dyn_cast(), + inputs[2].dyn_cast(), + attrs); + }); + RegisterOperationCreator( + "pd_op.fused_gemm_epilogue_grad", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build( + inputs[0].dyn_cast(), + inputs[1].dyn_cast(), + inputs[2].dyn_cast(), + inputs[3].dyn_cast(), + attrs); + }); + RegisterOperationCreator("builtin.combine", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build(inputs); + }); +} + +static pir::Attribute CreateIrAttribute(const std::any& obj) { + if (obj.type() == typeid(bool)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int32_t)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int64_t)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(float)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::string)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(const char*)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(phi::DataType)) { + return IrAttrbuteCreator()( + std::any_cast(obj)); + } else if (obj.type() == typeid(phi::Place)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttrbuteCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttrbuteCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(phi::IntArray)) { + return IrAttrbuteCreator()( + std::any_cast(obj)); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("Type error. CreateIrAttribute for type(%s) " + "is unimplemented CreateInCurrently.", + obj.type().name())); + } +} + +pir::AttributeMap CreateAttributeMap(const OpCall& op_call, + const MatchContextImpl& src_match_ctx) { + pir::AttributeMap attr_map; + for (const auto& kv : op_call.attributes()) { + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + attr_map[kv.first] = src_match_ctx.GetIrAttr(arg.name()); + } + if constexpr (std::is_same_v, + ComputeAttribute>) { + MatchContext ctx(std::make_shared(src_match_ctx)); + attr_map[kv.first] = + CreateIrAttribute(arg.attr_compute_func()(ctx)); + } + }, + kv.second); + } + return attr_map; +} + +Value GetIrValueByDrrTensor(const Tensor& tensor, + const MatchContextImpl& res_match_ctx) { + if (tensor.is_none()) { + return Value{}; + } + return res_match_ctx.GetIrValue(tensor.name()).get(); +} + +std::vector GetIrValuesByDrrTensors( + const std::vector& tensors, + const MatchContextImpl& res_match_ctx) { + std::vector ir_values; + ir_values.reserve(tensors.size()); + for (const auto* tensor : tensors) { + ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx)); + } + return ir_values; +} + +void BindIrOutputs(const OpCall& op_call, + pir::Operation* op, + MatchContextImpl* match_ctx) { + for (size_t i = 0; i < op_call.outputs().size(); ++i) { + std::shared_ptr ir_value = nullptr; + if (op->result(i)) { + ir_value = std::make_shared(op->result(i)); + } + match_ctx->BindIrValue(op_call.outputs()[i]->name(), ir_value); + } +} + +pir::Operation* CreateOperation(const OpCall& op_call, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter, // NOLINT + MatchContextImpl* res_match_ctx) { + VLOG(6) << "Drr create [" << op_call.name() << "] op..."; + const auto& inputs = op_call.inputs(); + std::vector ir_values = + GetIrValuesByDrrTensors(inputs, *res_match_ctx); + pir::Operation* op = OperationFactory::Instance().CreateOperation( + op_call.name(), + ir_values, + CreateAttributeMap(op_call, src_match_ctx), + rewriter); + BindIrOutputs(op_call, op, res_match_ctx); + VLOG(6) << "Drr create [" << op_call.name() << "] op done."; + return op; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h new file mode 100644 index 0000000000000..adc76efb99b2d --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" + +namespace pir { +namespace drr { + +class OperationFactory { + public: + static OperationFactory& Instance() { + static OperationFactory operation_factory; + return operation_factory; + } + + using operation_create_fn = + std::function&, + const pir::AttributeMap&, + pir::PatternRewriter&)>; + + void RegisterOperationCreator(const std::string& op_name, + const operation_create_fn& create_fn) { + op_creator_map.emplace(op_name, create_fn); + } + + pir::Operation* CreateOperation( + const std::string& op_name, + const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) const { // NOLINT + auto iter = op_creator_map.find(op_name); + PADDLE_ENFORCE_NE( + iter, + op_creator_map.end(), + phi::errors::NotFound( + "The op to be created is not found." + "Suggest fix: Place check if the op named %s has been registered.", + op_name)); + return iter->second(inputs, attrs, rewriter); + } + + private: + OperationFactory() { + RegisterPdOpGeneratedOpCreator(); +#ifdef PADDLE_WITH_CINN + RegisterCinnOpGeneratedOpCreator(); +#endif + RegisterManualOpCreator(); + } + + void RegisterManualOpCreator(); + void RegisterPdOpGeneratedOpCreator(); +#ifdef PADDLE_WITH_CINN + void RegisterCinnOpGeneratedOpCreator(); +#endif + std::unordered_map op_creator_map; +}; + +pir::Operation* CreateOperation(const OpCall& op_call, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter, // NOLINT + MatchContextImpl* res_match_ctx); + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h new file mode 100644 index 0000000000000..907df9dfd24eb --- /dev/null +++ b/paddle/fluid/pir/drr/ir_value.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" + +namespace pir { +namespace drr { + +class IrShape { + public: + explicit IrShape(const phi::DDim& dims) : dims_(dims) {} + + bool operator==(const IrShape& other) const { return dims_ == other.dims_; } + + int size() const { return dims_.size(); } + + int64_t at(int idx) const { return dims_.at(idx); } + + private: + const phi::DDim dims_; +}; + +class IrDtype { + public: + explicit IrDtype(pir::Type dtype) : dtype_(dtype) {} + + bool operator==(IrDtype other) const { return dtype_ == other.dtype_; } + + private: + const pir::Type dtype_; +}; + +class IrValue : public TensorInterface { + public: + explicit IrValue(const pir::Value& value) + : value_(value), + shape_((value && value.type() && + value.type().dyn_cast()) + ? value.type() + .dyn_cast() + .dims() + : phi::DDim{}), + dtype_((value && value.type() && + value.type().dyn_cast()) + ? value.type() + .dyn_cast() + .dtype() + : pir::Type{}) {} + + ShapeInterface Shape() const override { return ShapeInterface(&shape_); } + DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); } + + const Value& get() const { return value_; } + + private: + const Value value_; + const IrShape shape_; + const IrDtype dtype_; +}; + +class IrAttr; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h new file mode 100644 index 0000000000000..37b06914cd2bd --- /dev/null +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -0,0 +1,134 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/ir_value.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace pir { +namespace drr { + +class MatchContextImpl final { + public: + MatchContextImpl() = default; + ~MatchContextImpl() = default; + + const TensorInterface& Tensor(const std::string& tensor_name) const { + PADDLE_ENFORCE_NE( + tensor_map_.count(tensor_name), + 0, + phi::errors::NotFound( + "Not found tensor." + "The Drr tensor [%s] must exist in pattern graph to be obtained.", + tensor_name)); + return *tensor_map_.at(tensor_name); + } + + const IrOperation& Operation(const OpCall* op_call) const { + PADDLE_ENFORCE_NE( + operation_map_.count(op_call), + 0, + phi::errors::NotFound("Not found operation." + "The Drr operation [%s] must exist in the " + "pattern graph to be obtained.", + op_call->name())); + return *operation_map_.at(op_call); + } + + template + T Attr(const std::string& attr_name) const { + return IrAttrTypeCast::To(GetIrAttr(attr_name)); + } + + const IrValue& GetIrValue(const std::string& tensor_name) const { + auto iter = tensor_map_.find(tensor_name); + PADDLE_ENFORCE_NE( + iter, + tensor_map_.end(), + phi::errors::NotFound("Not found tensor." + "The Drr tensor [%s] is not found in the map, " + "unable to obtain the corresponding IrValue.", + tensor_name)); + return *iter->second; + } + + pir::Attribute GetIrAttr(const std::string& attr_name) const { + auto iter = attr_map_.find(attr_name); + PADDLE_ENFORCE_NE( + iter, + attr_map_.end(), + phi::errors::NotFound("Not found attr." + "The Drr attr [%s] is not found in the map, " + "unable to obtain the corresponding Attribute.", + attr_name)); + return iter->second; + } + + const std::unordered_map>& + operation_map() const { + return operation_map_; + } + + const std::unordered_map& attr_map() const { + return attr_map_; + } + + const std::unordered_map>& tensor_map() + const { + return tensor_map_; + } + + void BindIrValue(const std::string& value_name, + const std::shared_ptr& value) { + tensor_map_.emplace(value_name, value); + } + + void BindIrOperation(const OpCall* op_call, + const std::shared_ptr& op) { + operation_map_.emplace(op_call, op); + const auto& attrs = op_call->attributes(); + for (const auto& kv : attrs) { + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + } + }, + kv.second); + } + } + + private: + void BindIrAttr(const std::string& attr_name, pir::Attribute attr) { + attr_map_.emplace(attr_name, attr); + } + + std::unordered_map> tensor_map_; + std::unordered_map> + operation_map_; + std::unordered_map attr_map_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc new file mode 100644 index 0000000000000..7d732b6576f68 --- /dev/null +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -0,0 +1,241 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/pattern_graph.h" + +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/phi/core/enforce.h" + +namespace pir { +namespace drr { + +const drr::OpCall &PatternGraph::AddOpCall( + const std::shared_ptr &op_call) { + owned_op_call_.push_back(op_call); + for (const auto *input : op_call->inputs()) { + const auto &tensor_name = input->name(); + PADDLE_ENFORCE_NE(id2owned_tensor_.count(tensor_name), + 0, + phi::errors::NotFound("Not found tensor." + "The intput tensor [%s] must exist " + "in pattern graph to be obtained.", + tensor_name)); + id2owned_tensor_.at(tensor_name)->AddConsumer(op_call.get()); + + if (input->producer() == nullptr) { + input_tensors_.insert(tensor_name); + } + if (output_tensors_.find(tensor_name) != output_tensors_.end()) { + output_tensors_.erase(tensor_name); + } + } + for (auto &output : op_call->outputs()) { + const auto &out_tensor_name = output->name(); + PADDLE_ENFORCE_NE(id2owned_tensor_.count(out_tensor_name), + 0, + phi::errors::NotFound("Not found tensor." + "The output tensor [%s] must exist " + "in pattern graph to be obtained.", + out_tensor_name)); + id2owned_tensor_[output->name()]->set_producer(op_call.get()); + } + return *owned_op_call_.back(); +} + +drr::Tensor &PatternGraph::AddTensor( + const std::shared_ptr &tensor) { + if (id2owned_tensor_.find(tensor->name()) == id2owned_tensor_.end()) { + id2owned_tensor_[tensor->name()] = tensor; + output_tensors_.insert(tensor->name()); + } + return *id2owned_tensor_[tensor->name()]; +} + +drr::Tensor &PatternGraph::AddTmpTensor( + const std::shared_ptr &tensor) { + PADDLE_ENFORCE_EQ(id2owned_tensor_.count(tensor->name()), + 0, + phi::errors::AlreadyExists( + "Tensor already exists." + "The tensor [%s] must not exist in pattern graph.", + tensor->name())); + id2owned_tensor_[tensor->name()] = tensor; + output_tensors_.insert(tensor->name()); + return *id2owned_tensor_[tensor->name()]; +} + +void PatternGraph::UpdateTmpTensor(const std::string &tmp_tensor_name, + const std::string &new_tensor_name) { + if (input_tensors_.count(tmp_tensor_name)) { + input_tensors_.erase(tmp_tensor_name); + input_tensors_.insert(new_tensor_name); + } + + output_tensors_.erase(new_tensor_name); + if (output_tensors_.count(tmp_tensor_name)) { + output_tensors_.erase(tmp_tensor_name); + output_tensors_.insert(new_tensor_name); + } + + auto tmp_tensor = id2owned_tensor_[tmp_tensor_name]; + id2owned_tensor_.erase(tmp_tensor_name); + tmp_tensor->set_name(new_tensor_name); + id2owned_tensor_[new_tensor_name] = tmp_tensor; +} + +size_t PatternGraph::CountOfOpCalls() const { return owned_op_call_.size(); } + +OpCall *SourcePatternGraph::AnchorNode() const { + for (const auto &output_tensor : output_tensors_) { + OpCall *output_op_candidate = + id2owned_tensor_.at(output_tensor)->producer(); + if (std::all_of(output_op_candidate->outputs().begin(), + output_op_candidate->outputs().end(), + [this](const Tensor *output) -> bool { + return this->output_tensors().count(output->name()); + })) + return output_op_candidate; + } + IR_THROW("Unable to find a valid anchor"); +} + +std::unordered_set SourcePatternGraph::OutputNodes() const { + std::unordered_set output_op_set; + for (const auto &output_tensor : output_tensors_) { + OpCall *output_op_candidate = + id2owned_tensor_.at(output_tensor)->producer(); + if (std::all_of(output_op_candidate->outputs().begin(), + output_op_candidate->outputs().end(), + [this](const Tensor *output) -> bool { + return this->output_tensors().count(output->name()); + })) + output_op_set.insert(output_op_candidate); + } + return output_op_set; +} + +void ResultPatternGraph::AssignTensor(const Tensor &from, const Tensor &to) { + if (to.producer() == nullptr) { + input_tensors_.insert(to.name()); + } + output_tensors_.erase(to.name()); + PADDLE_ENFORCE_EQ( + output_tensors_.count(from.name()), + 1, + phi::errors::PreconditionNotMet("The Tensor (%s) which be assigned must " + "be the output of result pattern graph.", + from.name())); + tensor_assign_map_[from.name()] = to.name(); +} + +void GraphTopo::WalkGraphNodesTopoOrder( + const std::function &VisitNode) const { + // graph data + const std::unordered_set &inputs_tensor = + graph_->input_tensors(); + const std::unordered_map> + &id2owned_tensor = graph_->id2owend_tensor(); + const std::vector> &owend_opcall = + graph_->owned_op_call(); + + std::queue opcall_queue; + std::unordered_map> + opcall_dependent; + + // init opcall_dependent + for (const std::shared_ptr &opcall_sptr : owend_opcall) { + if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty + opcall_queue.push(opcall_sptr.get()); + } else { + for (const auto &pre_depd_tensor : opcall_sptr.get()->inputs()) { + opcall_dependent[opcall_sptr.get()].insert(pre_depd_tensor->name()); + } + } + } + + // init queue + for (const auto &tensor_name : inputs_tensor) { + PADDLE_ENFORCE_NE(id2owned_tensor.count(tensor_name), + 0, + phi::errors::NotFound("Not found tensor." + "The input tensor [%s] must exists " + "in pattern graph to be obtained.", + tensor_name)); + for (const auto &tensor_comsumer : + id2owned_tensor.at(tensor_name).get()->consumers()) { + opcall_dependent[tensor_comsumer].erase(tensor_name); + if (opcall_dependent[tensor_comsumer].empty()) { + opcall_queue.push(tensor_comsumer); + } + } + } + + while (!opcall_queue.empty()) { + const OpCall *opcall = opcall_queue.front(); + opcall_queue.pop(); + VisitNode(*opcall); + + // update opcall_dependent + for (const auto &output_tensor : opcall->outputs()) { + for (const auto &tensor_comsumer : output_tensor->consumers()) { + opcall_dependent[tensor_comsumer].erase(output_tensor->name()); + if (opcall_dependent[tensor_comsumer].empty()) { + opcall_queue.push(tensor_comsumer); + } + } + } + } +} + +std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { + os << "\nAll Tensors:\n"; + for (const auto &kv : pattern_graph.id2owend_tensor()) { + os << " " << kv.first; + } + os << "\n\n"; + + os << "Input Tensors:\n"; + for (const auto &tensor_name : pattern_graph.input_tensors()) { + os << " " << tensor_name; + } + os << "\n\n"; + + os << "Output Tensors:\n"; + for (const auto &tensor_name : pattern_graph.output_tensors()) { + os << " " << tensor_name; + } + os << "\n\n"; + + for (const auto &op_call : pattern_graph.owned_op_call()) { + os << " " << op_call->name() << " : "; + os << "inputs[ "; + for (const auto *input : op_call->inputs()) { + os << input->name() << " "; + } + os << "], "; + + os << "outputs[ "; + for (const auto &output : op_call->outputs()) { + os << output->name() << " "; + } + os << "]\n"; + } + os << "\n"; + return os; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/pattern_graph.h b/paddle/fluid/pir/drr/pattern_graph.h new file mode 100644 index 0000000000000..63bd60eadf17f --- /dev/null +++ b/paddle/fluid/pir/drr/pattern_graph.h @@ -0,0 +1,108 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace pir { +namespace drr { + +class Constraint; +class MatchContext; +class OpCall; +class Tensor; + +class PatternGraph { + public: + virtual ~PatternGraph() {} + + const drr::OpCall& AddOpCall(const std::shared_ptr& op_call); + + drr::Tensor& AddTensor(const std::shared_ptr& tensor); + + drr::Tensor& AddTmpTensor(const std::shared_ptr& tensor); + + void UpdateTmpTensor(const std::string& tmp_tensor_name, + const std::string& new_tensor_name); + + const std::unordered_set& input_tensors() const { + return input_tensors_; + } + + const std::unordered_set& output_tensors() const { + return output_tensors_; + } + + size_t CountOfOpCalls() const; + + const std::vector>& owned_op_call() const { + return owned_op_call_; + } + + const std::unordered_map>& + id2owend_tensor() const { + return id2owned_tensor_; + } + + protected: + std::unordered_map> id2owned_tensor_; + std::vector> owned_op_call_; + std::unordered_set input_tensors_; + std::unordered_set output_tensors_; +}; + +std::ostream& operator<<(std::ostream& os, const PatternGraph& pattern_graph); + +class SourcePatternGraph : public PatternGraph { + public: + OpCall* AnchorNode() const; + + std::unordered_set OutputNodes() const; + + private: + friend class DrrPatternContext; +}; + +class ResultPatternGraph : public PatternGraph { + public: + void AssignTensor(const Tensor& from, const Tensor& to); + + const std::unordered_map& tensor_assign_map() + const { + return tensor_assign_map_; + } + + private: + std::unordered_map tensor_assign_map_; +}; + +class GraphTopo { + public: + explicit GraphTopo(const PatternGraph* graph) : graph_(graph) {} + + void WalkGraphNodesTopoOrder( + const std::function& VisitNode) const; + + private: + const PatternGraph* graph_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt index e1903c903de34..c38140237e74e 100644 --- a/paddle/fluid/pir/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/transforms/CMakeLists.txt @@ -1,26 +1,16 @@ -cc_library( - transform_general_functions - SRCS transform_general_functions.cc - DEPS pd_op_dialect_core) - -cc_library( - pd_op_to_kernel_pass - SRCS pd_op_to_kernel_pass.cc - DEPS pd_kernel_dialect pd_op_dialect pd_op_dialect_utils) - -cc_library( - _constant_folding_pass - SRCS constant_folding_pass.cc - DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions) +file(GLOB_RECURSE transforms_srcs "*.cc") +if(NOT WITH_CINN) + list(REMOVE_ITEM transforms_srcs + ${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc) +endif() -cc_library( - pd_inplace_pass - SRCS inplace_pass.cc - DEPS pd_op_dialect_core op_yaml_info_parser) +set(transforms_deps drr op_dialect op_dialect_vjp standalone_executor pir) if(WITH_CINN) - cc_library( - pd_build_cinn_pass - SRCS build_cinn_pass.cc - DEPS pd_op_dialect cinn_op_dialect pir_control_flow cinnapi) + set(transforms_deps ${transforms_deps} cinn_op_dialect cinnapi) endif() + +cc_library( + transform + SRCS ${transforms_srcs} + DEPS ${transforms_deps}) diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc index 245e26cabad7e..bae676086171f 100644 --- a/paddle/fluid/pir/transforms/build_cinn_pass.cc +++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc @@ -14,7 +14,6 @@ #include "paddle/fluid/pir/transforms/build_cinn_pass.h" -#include #include #include #include @@ -27,12 +26,12 @@ #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/cinn/frontend/op_mapper_registry.h" -#include "paddle/cinn/hlir/framework/new_ir/utils.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/utils/flags.h" PD_DECLARE_string(allow_cinn_ops); @@ -43,7 +42,7 @@ using GroupOpsVec = std::vector; // The delim(`;`) that is used to split the FLAGS_allow_cinn_ops // & FLAGS_deny_cinn_ops. constexpr char kDelim[] = ";"; -using CompatibleInfo = cinn::hlir::framework::newir::CompatibleInfo; +using CompatibleInfo = cinn::hlir::framework::pir::CompatibleInfo; // OpTransInfo contains informations used to detect subgraphs // supported by the CINN compiler. @@ -136,7 +135,12 @@ bool IsSupportCinn(pir::Operation* op) { VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops); VLOG(4) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops); // Strip the dialect, like pd_op.abs -> abs - const auto& op_name = CompatibleInfo::OpName(*op); + const auto op_name = CompatibleInfo::OpName(*op); + if (CompatibleInfo::IsSupportCinn(*op)) { + VLOG(4) << "Found special supported op for CINN: " << op_name; + return true; + } + bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(op_name) != nullptr; @@ -212,6 +216,32 @@ std::vector InverselyTopologicalSort(pir::Block* block) { struct SubGraph; using SubGraphPtr = std::shared_ptr; +std::vector GetProducerOpsReverseSort( + pir::Operation* op, + const std::unordered_map& op2id) { + std::unordered_set producers; + + std::vector vec_res; + for (auto& operand : op->operands()) { + auto* source_op = operand.source().dyn_cast().owner(); + if (!producers.count(source_op)) { + producers.insert(source_op); + PADDLE_ENFORCE( + op2id.count(source_op), + phi::errors::PreconditionNotMet("source op MUST in op2id map")); + vec_res.emplace_back(source_op); + } + } + + std::sort(vec_res.begin(), + vec_res.end(), + [&op2id](pir::Operation* a, pir::Operation* b) { + return op2id.at(a) > op2id.at(b); + }); + + return vec_res; +} + std::unordered_set GetProducerOps(pir::Operation* op) { std::unordered_set producers; @@ -267,7 +297,12 @@ class CinnSubgraphDetector { using OpClassifier = std::function; CinnSubgraphDetector(pir::Block* block, const OpClassifier& classifier) - : block_(block), op_classifier_(classifier) {} + : block_(block), op_classifier_(classifier) { + sort_ops_ = InverselyTopologicalSort(block_); + for (size_t i = 0; i < sort_ops_.size(); ++i) { + op2id_[sort_ops_[i]] = i; + } + } std::vector operator()() { DoOpFusion(); @@ -287,7 +322,6 @@ class CinnSubgraphDetector { protected: // Do Op Fusion void DoOpFusion() { - sort_ops_ = InverselyTopologicalSort(block_); // do fusion for (auto* op : sort_ops_) { auto subgraph = subgraph_map_.count(op) @@ -296,7 +330,7 @@ class CinnSubgraphDetector { if (!subgraph_map_.count(op)) { subgraph_map_[op] = subgraph; } - auto producers = GetProducerOps(op); + auto producers = GetProducerOpsReverseSort(op, op2id_); for (auto* producer : producers) { if (op_classifier_(producer) != subgraph->substitute) { @@ -315,8 +349,10 @@ class CinnSubgraphDetector { continue; } // fuse producer to sub-graph - subgraph->Insert(producer); - subgraph_map_[producer] = subgraph; + if (!subgraph->op_set.count(producer)) { + subgraph->Insert(producer); + subgraph_map_[producer] = subgraph; + } } } } @@ -516,30 +552,36 @@ class CinnSubgraphDetector { OpClassifier op_classifier_; std::vector sort_ops_; + std::unordered_map op2id_; std::vector subgraph_list_; std::unordered_map subgraph_map_; }; std::vector AnalysisOutputs(GroupOpsVec& group_ops) { // NOLINT - std::set inputs; - std::set outputs; + // Get output by ud chain + std::unordered_set used_by_outside; + std::unordered_set op_set; + for (auto* op : group_ops) { - VLOG(4) << "AnalysisOutputs from " << op->name(); - for (auto& operand : op->operands()) { - inputs.emplace(operand.source()); - } - for (auto& result : op->results()) { - outputs.emplace(result); + op_set.insert(op); + } + + std::vector vec_res; + for (auto* op : group_ops) { + for (size_t i = 0; i < op->num_results(); ++i) { + auto result = op->result(i); + + for (auto use_iter = result.use_begin(); use_iter != result.use_end(); + ++use_iter) { + if (!op_set.count(use_iter->owner())) { + vec_res.push_back(result); + break; + } + } } } - std::vector results; - std::set_symmetric_difference(outputs.begin(), - outputs.end(), - inputs.begin(), - inputs.end(), - std::back_inserter(results)); - VLOG(3) << "Outputs size for GroupOp " << results.size(); - return results; + + return vec_res; } void ReplaceWithGroupOp(pir::Block* block, @@ -551,7 +593,6 @@ void ReplaceWithGroupOp(pir::Block* block, // step 1: Ensure the insert point and create GroupOp here. auto* laste_input_op = group_ops.back(); builder.SetInsertionPointAfter(laste_input_op); - // TODO(Aurelius84): Need confirm how many YieldOps we need. std::vector output_types; std::vector outputs = AnalysisOutputs(group_ops); for (auto& value : outputs) { @@ -563,23 +604,25 @@ void ReplaceWithGroupOp(pir::Block* block, for (auto* op : group_ops) { op->MoveTo(group_block, group_block->begin()); } - // step 3: Insert YieldOp for outputs - builder.SetInsertionPointToEnd(group_block); - builder.Build<::pir::YieldOp>(outputs); - // step 4: Replace outputs of inner ops + + // step 3: Replace outputs of inner ops std::vector group_outs = new_group_op->results(); for (size_t i = 0; i < outputs.size(); ++i) { outputs[i].ReplaceAllUsesWith(group_outs[i]); } + + // step 4: Insert YieldOp for outputs + builder.SetInsertionPointToEnd(group_block); + builder.Build<::pir::YieldOp>(outputs); } class BuildCinnPass : public pir::Pass { public: - BuildCinnPass() : pir::Pass("BuildCinnPass", /*opt_level=*/1) {} + BuildCinnPass() : pir::Pass("build_cinn_pass", /*opt_level=*/1) {} void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "InplacePass should run on module op."); + IR_ENFORCE(module_op, "build_cinn_pass should run on module op."); auto* block = module_op.block(); std::vector groups = diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 3b40960373a2f..4e36f1df9defa 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -33,6 +33,7 @@ #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/op_result.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/core/parameter.h" #include "paddle/pir/core/program.h" @@ -46,22 +47,22 @@ namespace { class ConstantFoldingPattern : public pir::RewritePattern { public: ConstantFoldingPattern(pir::IrContext* context, + paddle::framework::Scope* scope, pir::PatternBenefit benefit = 1, const std::vector& generated_names = {}) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) { - } + : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names), + scope_(scope) {} bool Match(pir::Operation* op) const override { // TODO(liuyuanle): Use trait to improve robustness. - if (op->dyn_cast() || - op->dyn_cast() || - op->dyn_cast()) + if (op->isa() || op->isa() || + op->isa() || + op->isa()) return false; // Inputs must come from get parameter op. for (uint32_t i = 0; i < op->num_operands(); ++i) - if (pir::GetDefiningOpForInput(op, i)->dyn_cast() == - nullptr) + if (!pir::GetDefiningOpForInput(op, i)->isa()) return false; return true; } @@ -101,7 +102,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { paddle::framework::InterpreterCore core(phi::CPUPlace{}, fetch_var_names, kernel_program->block(), - &scope_, + scope_, exe_config_); paddle::framework::FetchList fetch_list = core.Run({}); @@ -118,7 +119,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { "@constant_folding_pass@_" + std::to_string(suffix_++); exe_config_.skip_gc_vars.insert(param_name); - auto* param_var = scope_.Var(param_name); + auto* param_var = scope_->Var(param_name); auto* param_tensor = param_var->GetMutable(); *param_tensor = out_tensor; program->SetParameter(param_name, std::move(parameter)); @@ -150,7 +151,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { program->SetParameter(param_name, std::make_unique(*param)); - auto* param_var = scope_.FindVar(param_name); + auto* param_var = scope_->FindVar(param_name); PADDLE_ENFORCE_NOT_NULL( param_var, phi::errors::InvalidArgument("Parameter var not in scope.")); @@ -163,7 +164,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { // prepare op outputs std::vector output_types; for (uint32_t i = 0; i < op->num_results(); i++) { - output_types.push_back(op->result(i).type()); + output_types.push_back(op->result_type(i)); } auto* temp_op = @@ -185,19 +186,18 @@ class ConstantFoldingPattern : public pir::RewritePattern { } private: + paddle::framework::Scope* scope_{nullptr}; inline static size_t suffix_{0}; - inline static paddle::framework::Scope scope_{}; inline static paddle::framework::interpreter::ExecutionConfig exe_config_{}; }; class ConstantFoldingPass : public pir::Pass { public: - // TODO(liuyuanle): Naming convention for pass. - ConstantFoldingPass() : pir::Pass("ConstantFoldingPass", 1) {} + ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {} bool Initialize(pir::IrContext* context) override { pir::RewritePatternSet ps(context); - ps.Add(context); + ps.Add(context, &scope_); patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); return true; } @@ -215,6 +215,7 @@ class ConstantFoldingPass : public pir::Pass { private: pir::FrozenRewritePatternSet patterns_; + paddle::framework::Scope scope_; }; } // namespace diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc new file mode 100644 index 0000000000000..7535ddeb513db --- /dev/null +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class DeadCodeEliminationPattern : public pir::RewritePattern { + public: + DeadCodeEliminationPattern( + pir::IrContext* context, + pir::PatternBenefit benefit = 1, + const std::vector& generated_names = {}) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) { + } + + bool Match(pir::Operation* op) const override { + if (op->isa() || + op->isa()) + return false; + + return op->use_empty(); + } + + void Rewrite(pir::Operation* op, + pir::PatternRewriter& rewriter) const override { // NOLINT + if (op->dyn_cast()) { + // Delete parameter from program. + pir::GetParameterOp get_parameter_op = + op->dyn_cast(); + get_parameter_op->GetParentProgram()->parameters().erase( + get_parameter_op->attributes() + .at(get_parameter_op.attributes_name[0]) + .dyn_cast() + .AsString()); + } + rewriter.EraseOp(op); + } +}; + +class DeadCodeEliminationPass : public pir::Pass { + public: + DeadCodeEliminationPass() : pir::Pass("dead_code_elimination_pass", 0) {} + + bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation* op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateDeadCodeEliminationPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(dead_code_elimination_pass, DeadCodeEliminationPass); diff --git a/paddle/pir/transforms/dead_code_elimination_pass.h b/paddle/fluid/pir/transforms/dead_code_elimination_pass.h similarity index 100% rename from paddle/pir/transforms/dead_code_elimination_pass.h rename to paddle/fluid/pir/transforms/dead_code_elimination_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc new file mode 100644 index 0000000000000..0bd8c5e29e7ef --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -0,0 +1,253 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" + +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class MultiHeadMatmulFusePattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // + // Source Pattern. + // + pir::drr::SourcePattern src = ctx->SourcePattern(); + // The first path to matmul with scale (q). + const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + src.Tensor("matmul_1_out") = + matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2")); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2")); + const auto &full_int_array_1 = + src.Op("pd_op.full_int_array", + {{"value", src.Attr("full_int_array_1_value")}}); + const auto &reshape_1 = src.Op("pd_op.reshape"); + reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()}, + {&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")}); + const auto &transpose_1 = src.Op("pd_op.transpose"); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out")); + const auto &full_1 = + src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}}); + const auto &scale = src.Op("pd_op.scale"); + src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1()); + + // The second path to matmul (k). + const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_transpose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2")); + const auto &full_int_array_2 = src.Op("pd_op.full_int_array"); + const auto &reshape_2 = src.Op("pd_op.reshape"); + reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()}, + {&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")}); + const auto &transpose_2 = src.Op("pd_op.transpose"); + src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out")); + + // The third path to matmul (v). + const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2")); + const auto &add_3 = src.Op("pd_op.add"); + src.Tensor("add_3_out") = + add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2")); + const auto &full_int_array_3 = src.Op("pd_op.full_int_array"); + const auto &reshape_3 = src.Op("pd_op.reshape"); + reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()}, + {&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")}); + const auto &transpose_3 = src.Op("pd_op.transpose"); + src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out")); + + // softmax(qk)v + const auto &matmul_4 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_4_transpose_x")}, + {"transpose_y", src.Attr("matmul_4_transpose_y")}}); + src.Tensor("matmul_4_out") = + matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out")); + const auto &add_4 = src.Op("pd_op.add"); + src.Tensor("add_4_out") = + add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2")); + const auto &softmax = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}}); + src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out")); + const auto &matmul_5 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_5_transpose_x")}, + {"transpose_y", src.Attr("matmul_5_transpose_y")}}); + src.Tensor("matmul_5_out") = + matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out")); + const auto &transpose_4 = src.Op("pd_op.transpose"); + src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out")); + const auto &full_int_array_4 = src.Op("pd_op.full_int_array"); + const auto &reshape_4 = src.Op("pd_op.reshape"); + reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()}, + {&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")}); + + // + // Constraints. + // + src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool matmul_1_transpose_x = match_ctx.Attr("matmul_1_transpose_x"); + bool matmul_1_transpose_y = match_ctx.Attr("matmul_1_transpose_y"); + if (matmul_1_transpose_x || matmul_1_transpose_y) return false; + + bool matmul_2_transpose_x = match_ctx.Attr("matmul_2_transpose_x"); + bool matmul_2_transpose_y = match_ctx.Attr("matmul_2_transpose_y"); + if (matmul_2_transpose_x || matmul_2_transpose_y) return false; + + bool matmul_3_transpose_x = match_ctx.Attr("matmul_3_transpose_x"); + bool matmul_3_transpose_y = match_ctx.Attr("matmul_3_transpose_y"); + if (matmul_3_transpose_x || matmul_3_transpose_y) return false; + + bool matmul_4_transpose_x = match_ctx.Attr("matmul_4_transpose_x"); + bool matmul_4_transpose_y = match_ctx.Attr("matmul_4_transpose_y"); + if (matmul_4_transpose_x || !matmul_4_transpose_y) return false; + + bool matmul_5_transpose_x = match_ctx.Attr("matmul_5_transpose_x"); + bool matmul_5_transpose_y = match_ctx.Attr("matmul_5_transpose_y"); + if (matmul_5_transpose_x || matmul_5_transpose_y) return false; + + return true; + }); + + // + // Result Pattern. + // + pir::drr::ResultPattern res = src.ResultPattern(); + // W combine. + const auto &combine_1 = res.Op("builtin.combine"); + combine_1({&res.Tensor("matmul_1_in_2"), + &res.Tensor("matmul_2_in_2"), + &res.Tensor("matmul_3_in_2")}, + {&res.Tensor("combine_1_out")}); + const auto &concat_axis = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> int { return 0; }); + const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); + const auto &reshape_5_shape = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector { + auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); + return {-1, 3, matmul_1_in_2.at(1)}; + }); + const auto &reshape_5 = + res.Op("pd_op.reshape", {{"shape", reshape_5_shape}}); + reshape_5({&res.Tensor("concat_1_out")}, + {&res.Tensor("reshape_5_out"), &res.NoneTensor()}); + + // Bias combine. + const auto &combine_2 = res.Op("builtin.combine"); + combine_2({&res.Tensor("add_1_in_2"), + &res.Tensor("add_2_in_2"), + &res.Tensor("add_3_in_2")}, + {&res.Tensor("combine_2_out")}); + const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); + const auto &reshape_6_shape = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector { + return {3, -1}; + }); + const auto &reshape_6 = + res.Op("pd_op.reshape", {{"shape", reshape_6_shape}}); + reshape_6({&res.Tensor("concat_2_out")}, + {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); + + const auto &head_number = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + const auto &full_int_array_1_value = + match_ctx.Attr>("full_int_array_1_value"); + return full_int_array_1_value.at(2); + }); + const auto &alpha = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + return match_ctx.Attr("full_1_value"); + }); + const auto &multihead_matmul = res.Op( + "pd_op.multihead_matmul", + {{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return false; + })}, + {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return true; + })}, + {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return false; + })}, + {"head_number", head_number}, + {"alpha", alpha}}); + multihead_matmul({&res.Tensor("matmul_1_in_1"), + &res.Tensor("reshape_5_out"), + &res.Tensor("reshape_6_out"), + &res.Tensor("add_4_in_2")}, + {&res.Tensor("reshape_4_out")}); + } +}; + +class AttentionFusePass : public pir::Pass { + public: + AttentionFusePass() : pir::Pass("attention_fuse_pass", 2) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(MultiHeadMatmulFusePattern().Build(context)); + // Add other attention variant fuse pattern. + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { +std::unique_ptr CreateAttentionFusePass() { + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(attention_fuse_pass, AttentionFusePass); diff --git a/paddle/pir/transforms/reorder_block_ops_pass.h b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h similarity index 92% rename from paddle/pir/transforms/reorder_block_ops_pass.h rename to paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h index 51ab110bb3ac0..0c0d2e84952ca 100644 --- a/paddle/pir/transforms/reorder_block_ops_pass.h +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h @@ -21,6 +21,6 @@ namespace pir { class Pass; -IR_API std::unique_ptr CreateReorderBlockOpsPass(); +IR_API std::unique_ptr CreateAttentionFusePass(); } // namespace pir diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc new file mode 100644 index 0000000000000..6025a3f7d1c3a --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class FusedDropoutAddPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), + {{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + dropout({&pat.Tensor("x"), &pat.Tensor("seed_tensor")}, + {&pat.Tensor("dropout_out"), &pat.Tensor("mask")}); + pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y")); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &fused_dropout_add = + res.Op(paddle::dialect::FusedDropoutAddOp::name(), + {{{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}}); + fused_dropout_add( + {&res.Tensor("x"), &res.Tensor("y"), &res.Tensor("seed_tensor")}, + {&res.Tensor("add_out"), &res.Tensor("mask")}); + } +}; + +class FusedDropoutGradAddGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), + {{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &dropout_grad = pat.Op(paddle::dialect::DropoutGradOp::name(), + {{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}}); + + dropout({&pat.Tensor("x"), &pat.Tensor("seed_tensor")}, + {&pat.Tensor("dropout_out"), &pat.Tensor("mask")}); + pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y")); + add_grad({&pat.Tensor("dropout_out"), + &pat.Tensor("y"), + &pat.Tensor("add_out_grad")}, + {&pat.Tensor("dropout_out_grad"), &pat.Tensor("y_grad")}); + dropout_grad({&pat.Tensor("mask"), &pat.Tensor("dropout_out_grad")}, + {&pat.Tensor("x_grad")}); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &fused_dropout_add = + res.Op(paddle::dialect::FusedDropoutAddOp::name(), + {{{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}}); + + const auto &fused_dropout_add_grad = + res.Op(paddle::dialect::FusedDropoutAddGradOp::name(), + {{{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"fix_seed", pat.Attr("fix_seed")}}}); + + fused_dropout_add( + {&res.Tensor("x"), &res.Tensor("y"), &res.Tensor("seed_tensor")}, + {&res.Tensor("add_out"), &res.Tensor("mask")}); + fused_dropout_add_grad({&res.Tensor("mask"), &res.Tensor("add_out_grad")}, + {&res.Tensor("x_grad"), &res.Tensor("y_grad")}); + } +}; + +class FusedDropoutAddPass : public pir::Pass { + public: + FusedDropoutAddPass() : pir::Pass("fused_dropout_add_pass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + + ps.Add(FusedDropoutAddPattern().Build(context)); + ps.Add(FusedDropoutGradAddGradPattern().Build(context)); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateFusedDropoutAddPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(fused_dropout_add_pass, FusedDropoutAddPass); diff --git a/paddle/cinn/hlir/framework/convert_to_dialect.h b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h similarity index 73% rename from paddle/cinn/hlir/framework/convert_to_dialect.h rename to paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h index 7ea0a2ace40c7..3d78e6fe7b3b2 100644 --- a/paddle/cinn/hlir/framework/convert_to_dialect.h +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h @@ -15,19 +15,12 @@ #pragma once #include +#include "paddle/pir/core/dll_decl.h" namespace pir { -class Program; -} // namespace pir -namespace cinn { -namespace hlir { -namespace framework { -class Program; +class Pass; -std::unique_ptr<::pir::Program> ConvertToRuntimeDialect( - const hlir::framework::Program& program); +IR_API std::unique_ptr CreateFusedDropoutAddPass(); -} // namespace framework -} // namespace hlir -} // namespace cinn +} // namespace pir diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc new file mode 100644 index 0000000000000..71fe6b6476302 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -0,0 +1,298 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class FusedLinearPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + } +}; + +class FusedLinearGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + add_grad({&pat.Tensor("tmp"), &pat.Tensor("bias"), &pat.Tensor("out_grad")}, + {&pat.Tensor("tmp_grad"), &pat.Tensor("bias_grad")}); + matmul_grad({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("tmp_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + const auto &fused_gemm_epilogue_grad = + res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation_grad", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + fused_gemm_epilogue_grad({&res.Tensor("x"), + &res.Tensor("w"), + &res.NoneTensor(), + &res.Tensor("out_grad")}, + {&res.Tensor("x_grad"), + &res.Tensor("w_grad"), + &res.Tensor("bias_grad")}); + } +}; + +class FusedLinearGeluGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &fused_gemm_epilogue = + pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", pat.Attr("act1")}}}); + const auto &fused_gemm_epilogue_grad1 = + pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", pat.Attr("act2")}}}); + fused_gemm_epilogue( + {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); + pat.Tensor("out") = + pat.Op(paddle::dialect::GeluOp::name())(pat.Tensor("fuse_out")); + + fused_gemm_epilogue_grad1({&pat.Tensor("x1"), + &pat.Tensor("w1"), + &pat.Tensor("reserve_space1"), + &pat.Tensor("out_grad")}, + {&pat.Tensor("x1_grad"), + &pat.Tensor("w1_grad"), + &pat.Tensor("bias1_grad")}); + pat.Tensor("gelu_dx") = pat.Op(paddle::dialect::GeluGradOp::name())( + pat.Tensor("fuse_out"), pat.Tensor("x1_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return match_ctx.Attr("act1") == "none" && + match_ctx.Attr("act2") == "none"; + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "gelu"; + }); + const auto &fused_gemm_epilogue_new = + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", act_attr}}}); + const auto &act_grad_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "gelu_grad"; + }); + const auto &fused_gemm_epilogue_grad_new = + res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", act_grad_attr}}}); + fused_gemm_epilogue_new( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out"), &res.Tensor("reserve_space2")}); + fused_gemm_epilogue_grad_new({&res.Tensor("x1"), + &res.Tensor("w1"), + &res.Tensor("reserve_space2"), + &res.Tensor("out_grad")}, + {&res.Tensor("gelu_dx"), + &res.Tensor("w1_grad"), + &res.Tensor("bias1_grad")}); + } +}; + +class FusedLinearReluGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &fused_gemm_epilogue = + pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", pat.Attr("act1")}}}); + const auto &fused_gemm_epilogue_grad = + pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", pat.Attr("act2")}}}); + const auto &fused_gemm_epilogue_grad1 = + pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), + {{{"trans_x", pat.Attr("trans_x3")}, + {"trans_y", pat.Attr("trans_y3")}, + {"activation_grad", pat.Attr("act3")}}}); + fused_gemm_epilogue( + {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); + pat.Tensor("out") = pat.Op("pd_op.relu")(pat.Tensor("fuse_out")); + + fused_gemm_epilogue_grad1({&pat.Tensor("x1"), + &pat.Tensor("w1"), + &pat.Tensor("reserve_space2"), + &pat.Tensor("out_grad")}, + {&pat.Tensor("x1_grad"), + &pat.Tensor("w1_grad"), + &pat.Tensor("bias1_grad")}); + pat.Tensor("relu_dx") = + pat.Op("pd_op.relu_grad")(pat.Tensor("x1"), pat.Tensor("x1_grad")); + fused_gemm_epilogue_grad({&pat.Tensor("x"), + &pat.Tensor("w"), + &pat.Tensor("reserve_space1"), + &pat.Tensor("relu_dx")}, + {&pat.Tensor("x_grad"), + &pat.Tensor("w_grad"), + &pat.Tensor("bias_grad")}); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return match_ctx.Attr("act1") == "none" && + match_ctx.Attr("act3") == "none"; + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "relu"; + }); + const auto &fused_gemm_epilogue_new = + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", act_attr}}}); + const auto &act_grad_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "relu_grad"; + }); + const auto &fused_gemm_epilogue_grad1_new = + res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", act_grad_attr}}}); + fused_gemm_epilogue_new( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out"), &res.Tensor("reserve_space3")}); + fused_gemm_epilogue_grad1_new({&res.Tensor("x1"), + &res.Tensor("w1"), + &res.Tensor("reserve_space3"), + &res.Tensor("out_grad")}, + {&res.Tensor("relu_dx"), + &res.Tensor("w1_grad"), + &res.Tensor("bias1_grad")}); + } +}; + +class FusedGemmEpiloguePass : public pir::Pass { + public: + FusedGemmEpiloguePass() : pir::Pass("fused_gemm_epilogue_pass", 2) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(FusedLinearGradPattern().Build(context)); + ps.Add(FusedLinearPattern().Build(context)); + ps.Add(FusedLinearGeluGradPattern().Build(context)); + ps.Add(FusedLinearReluGradPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateFusedGemmEpiloguePass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(fused_gemm_epilogue_pass, FusedGemmEpiloguePass); diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h new file mode 100644 index 0000000000000..61f503a530f72 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateFusedGemmEpiloguePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc new file mode 100644 index 0000000000000..8ae7b542f3dcd --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -0,0 +1,361 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +namespace { + +// add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add +class FusedMatmulAddGradAddPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + add_grad( + {&pat.Tensor("out"), &pat.Tensor("bias"), &pat.Tensor("addout_grad")}, + {&pat.Tensor("out_grad"), &pat.Tensor("dbias")}); + matmul_grad( + {&pat.Tensor("x"), &pat.Tensor("weight"), &pat.Tensor("out_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + const auto &x_trans = match_ctx.Attr("trans_x"); + const auto &y_trans = match_ctx.Attr("trans_y"); + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + match_ctx.Tensor("out").Shape() == + match_ctx.Tensor("addout_grad").Shape() && + x_trans == false && y_trans == false); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &matmul = + res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", false_attr}, {"transpose_y", true_attr}}); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + + matmul({&res.Tensor("addout_grad"), &res.Tensor("weight")}, + {&res.Tensor("x_grad")}); + fused_linear_param_grad_add({&res.Tensor("x"), + &res.Tensor("addout_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias")}); + } +}; + +// matmul_grad + add_ -> matmul + fused_liner_param_gard_add +class FusedMatmulGradAddPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + matmul_grad( + {&pat.Tensor("x"), &pat.Tensor("weight"), &pat.Tensor("out_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + const auto &x_trans = match_ctx.Attr("trans_x"); + const auto &y_trans = match_ctx.Attr("trans_y"); + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + x_trans == false && y_trans == false); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &matmul = + res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", false_attr}, {"transpose_y", true_attr}}); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + + matmul({&res.Tensor("out_grad"), &res.Tensor("weight")}, + {&res.Tensor("x_grad")}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("out_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); + } +}; + +// matmul + 0 = add_(0,1) -> fused_liner_param_gard_add +class FusedMatmulAddaPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + matmul({&pat.Tensor("x"), &pat.Tensor("out_grad")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("out_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); + } +}; + +// matmul + 1 = add_(1,0) -> fused_liner_param_gard_add +class FusedMatmulAddbPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + matmul({&pat.Tensor("x"), &pat.Tensor("out_grad")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("out_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); + } +}; + +// add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add +class FusedMatmulAddGradAddaPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + add_grad({&pat.Tensor("out"), &pat.Tensor("bias"), &pat.Tensor("dadd_out")}, + {&pat.Tensor("dout"), &pat.Tensor("dbias")}); + matmul({&pat.Tensor("x"), &pat.Tensor("dout")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("dweight_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + match_ctx.Tensor("out").Shape() == + match_ctx.Tensor("dadd_out").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("dadd_out"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); + } +}; + +// add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add +class FusedMatmulAddGradAddbPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + add_grad({&pat.Tensor("out"), &pat.Tensor("bias"), &pat.Tensor("dadd_out")}, + {&pat.Tensor("dout"), &pat.Tensor("dbias")}); + matmul({&pat.Tensor("x"), &pat.Tensor("dout")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("dweight_out") = + add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + match_ctx.Tensor("out").Shape() == + match_ctx.Tensor("dadd_out").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("dadd_out"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); + } +}; + +class FusedLinearParamGradAddPass : public pir::Pass { + public: + FusedLinearParamGradAddPass() + : pir::Pass("fused_linear_param_grad_add_pass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(FusedMatmulAddGradAddPattern().Build(context)); + ps.Add(FusedMatmulGradAddPattern().Build(context)); + ps.Add(FusedMatmulAddaPattern().Build(context)); + ps.Add(FusedMatmulAddbPattern().Build(context)); + ps.Add(FusedMatmulAddGradAddaPattern().Build(context)); + ps.Add(FusedMatmulAddGradAddbPattern().Build(context)); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateFusedLinearParamGradAddPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(fused_linear_param_grad_add_pass, FusedLinearParamGradAddPass); diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h new file mode 100644 index 0000000000000..f4b17e8993a18 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateFusedLinearParamGradAddPass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index f70fc12568991..2532ffb7ea748 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include +#include +#include + #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" @@ -21,11 +24,16 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/phi/core/flags.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" +PHI_DECLARE_string(ir_inplace_kernel_blacklist); + namespace details { // NOTE(zhangbo): Which kind of value can be deleted? // (1) Value's type needs to be AllocatedDenseTensorType or @@ -53,7 +61,44 @@ static bool CanBeDeleted(pir::Value value) { static bool CanDoInplace(const std::unordered_set& eager_dels, pir::Value input, pir::Value output) { - if (input.type() != output.type()) { + if (!input.type() || !output.type()) { + return false; + } + + if (input.type().isa() && + output.type().isa()) { + auto input_alloc_tensor_type = + input.type().dyn_cast(); + auto output_alloc_tensor_type = + output.type().dyn_cast(); + + if (input_alloc_tensor_type.dtype() != output_alloc_tensor_type.dtype()) { + VLOG(9) << " -- input's dtype != output's dtype, can't do inplace"; + return false; + } + + int64_t in_numel = 1; + int64_t out_numel = 1; + for (int i = 0; i < input_alloc_tensor_type.dims().size(); i++) { + if (input_alloc_tensor_type.dims()[i] == -1) { + VLOG(9) << " -- input's shape has -1, can't do inplace"; + return false; + } + in_numel *= input_alloc_tensor_type.dims()[i]; + } + + for (int i = 0; i < output_alloc_tensor_type.dims().size(); i++) { + if (output_alloc_tensor_type.dims()[i] == -1) { + VLOG(9) << " -- output's shape has -1, can't do inplace"; + return false; + } + out_numel *= output_alloc_tensor_type.dims()[i]; + } + if (in_numel != out_numel) { + VLOG(9) << " -- input's numel != output's numel, can't do inplace"; + return false; + } + } else if (input.type() != output.type()) { VLOG(9) << " -- input's type != output's type, can't do inplace"; return false; } @@ -80,7 +125,7 @@ static bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { op_info.GetInterfaceImpl(); if (info_interface) { paddle::dialect::OpYamlInfoParser info_parser( - info_interface->get_op_info_()); + info_interface->get_op_info_(), paddle::dialect::IsLegacyOp(op_name)); auto& no_need_buffer_ids = info_parser.NoNeedBufferIds(); for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { if (value == op->operand_source(no_need_buffer_ids[id])) { @@ -140,21 +185,19 @@ static void GetEagerDelValueOfOp( for (size_t i = 0; i < op->num_operands(); ++i) { auto input = op->operand_source(i); - if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input) || - IsNoNeedBuffer(op, input)) { + if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input)) { VLOG(6) << "The " << i << "-th input value of the Operation(" << upper_op_name << ") can not be deleted."; VLOG(8) << " -- skip dels: " << skip_dels.count(input); VLOG(8) << " -- value is null: " << !input; VLOG(8) << " -- can be deleted: " << !CanBeDeleted(input); - VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(op, input); continue; } (*del_value_2_op)[input] = op; } - for (size_t i = 0; i < op->num_results(); ++i) { - pir::Value output = op->result(i); + for (auto& result : op->results()) { + pir::Value output = result; if (output && CanBeDeleted(output)) { (*del_value_2_op)[output] = op; } @@ -206,8 +249,8 @@ static std::unordered_map GetInplaceOps( VLOG(6) << op->name() << "is not a kernel_dialect op, inplace only support " "kernel_dialect operators"; - for (size_t i = 0; i < op->num_results(); ++i) { - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + visited_values.insert(result); } continue; } @@ -226,8 +269,8 @@ static std::unordered_map GetInplaceOps( .dyn_cast() .data() .backend() == phi::Backend::CPU)) { - for (size_t i = 0; i < op->num_results(); ++i) { - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + visited_values.insert(result); } continue; } @@ -238,9 +281,9 @@ static std::unordered_map GetInplaceOps( for (size_t i = 0; i < op->num_operands(); ++i) { reused_input_values.insert(op->operand_source(i)); } - for (size_t i = 0; i < op->num_results(); ++i) { - reused_output_values.insert(op->result(i)); - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + reused_output_values.insert(result); + visited_values.insert(result); } continue; } @@ -248,7 +291,16 @@ static std::unordered_map GetInplaceOps( pir::OpInfo upper_inplace_op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name + "_"); - if (eager_dels.count(op) == 0 || (!upper_inplace_op_info)) { + std::regex reg(","); + std::unordered_set elems{ + std::sregex_token_iterator(FLAGS_ir_inplace_kernel_blacklist.begin(), + FLAGS_ir_inplace_kernel_blacklist.end(), + reg, + -1), + std::sregex_token_iterator()}; + elems.erase(""); + + if (elems.count(upper_op_name)) { VLOG(6) << upper_op_name << "'s value can't delete or doesn't have inplace op, so that " "can't do inplace."; @@ -257,6 +309,19 @@ static std::unordered_map GetInplaceOps( } continue; } + if (eager_dels.count(op) == 0 || (!upper_inplace_op_info) || + upper_op_name == "pd_op.transpose") { + // NOTE(wanghuancoder): pd_op.transpose is not an + // inplace op, only strided transpose support + // inplace in dygraph + VLOG(6) << upper_op_name + << "'s value can't delete or doesn't have inplace op, so that " + "can't do inplace."; + for (auto& result : op->results()) { + visited_values.insert(result); + } + continue; + } auto upper_inplace_op_interface = upper_inplace_op_info @@ -310,8 +375,13 @@ static std::unordered_map GetInplaceOps( << " will change to inplace version op: " << upper_op_name + "_"; } - for (size_t i = 0; i < op->num_results(); ++i) { - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + visited_values.insert(result); + } + } + if (!FLAGS_ir_inplace_kernel_blacklist.empty()) { + for (auto i : inplace_ops) { + std::cout << i.second << std::endl; } } return inplace_ops; @@ -320,11 +390,11 @@ static std::unordered_map GetInplaceOps( class InplacePass : public pir::Pass { public: - InplacePass() : pir::Pass("InplacePass", 3) {} + InplacePass() : pir::Pass("inplace_pass", 3) {} void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "InplacePass should run on module op."); + IR_ENFORCE(module_op, "inplace_pass should run on module op."); auto* block = module_op.block(); auto inplace_ops = details::GetInplaceOps(block); @@ -365,4 +435,4 @@ std::unique_ptr CreateInplacePass() { } // namespace pir -REGISTER_IR_PASS(inplace, InplacePass); +REGISTER_IR_PASS(inplace_pass, InplacePass); diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 633d1493db5c2..6bb3d12ca6756 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" + #include +#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" @@ -28,14 +31,15 @@ #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" #include "paddle/utils/flags.h" PHI_DECLARE_bool(print_ir); @@ -63,15 +67,16 @@ const std::unordered_set UnchangeOutputOps = { "pd_op.fetch", "builtin.set_parameter", "builtin.get_parameter", - "builtin.shadow_output"}; - -const std::unordered_set SpecialLowerOps = {"builtin.combine", - "builtin.slice", - "builtin.split", - "pd_op.if", - "pd_op.while", - "cf.yield", - "cf.cond_yield"}; + "builtin.shadow_output", + "cinn_runtime.jit_kernel"}; +const std::unordered_set SpecialLowerOps = { + "builtin.combine", + "builtin.slice", + "builtin.split", + "pd_op.if", + "pd_op.while", + "cf.yield", + "cinn_runtime.jit_kernel"}; bool NeedFallBackCpu(const pir::Operation* op, const std::string& kernel_fn_name, @@ -284,6 +289,135 @@ pir::OpResult AddPlaceTransferOp(pir::Value in, } } +bool NeedTransformDataType(const phi::DataType& l, const phi::DataType& r) { + return l != phi::DataType::ALL_DTYPE && r != phi::DataType::ALL_DTYPE && + l != r; +} + +const phi::DataType GetKernelTypeforVar( + pir::Operation* op, + const std::string& var_name, + const phi::DataType& tensor_dtype, + const phi::KernelKey* expected_kernel_key) { + pir::IrContext* ctx = pir::IrContext::Instance(); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op->name()); + + if (op_info + .GetInterfaceImpl()) { + auto get_kernel_type_for_var_interface = + op_info + .GetInterfaceImpl(); + phi::DataType kernel_dtype_for_var = + get_kernel_type_for_var_interface->get_kernel_type_for_var_( + var_name, tensor_dtype, (*expected_kernel_key).dtype()); + return kernel_dtype_for_var; + } + return (*expected_kernel_key).dtype(); +} + +pir::Type BuildDtypeTransferOutputType(pir::Type type, + const phi::Place& place, + phi::DataType data_dtype, + pir::IrContext* ctx) { + if (type.isa()) { + auto dense_tensor_type = + type.dyn_cast(); + + auto out_dtype = paddle::dialect::TransToIrDataType(data_dtype, ctx); + return paddle::dialect::AllocatedDenseTensorType::get( + ctx, + place, + out_dtype, + dense_tensor_type.dims(), + dense_tensor_type.data_layout(), + dense_tensor_type.lod(), + dense_tensor_type.offset()); + + } else if (type.isa()) { + auto selected_rows_type = + type.dyn_cast(); + auto out_dtype = paddle::dialect::TransToIrDataType(data_dtype, ctx); + return paddle::dialect::AllocatedSelectedRowsType::get( + ctx, + place, + out_dtype, + selected_rows_type.dims(), + selected_rows_type.data_layout(), + selected_rows_type.lod(), + selected_rows_type.offset()); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "BuildOutputType only support DenseTensorType and SelectedRowsType")); + } +} + +pir::OpResult AddDtypeTransferOp(pir::Value in, + pir::Block* block, + const phi::KernelKey& kernel_key, + const phi::Place& origin_place, + const phi::Place& out_place, + const phi::DataType& src_dtype, + const phi::DataType& dst_dtype) { + pir::IrContext* ctx = pir::IrContext::Instance(); + + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(paddle::dialect::PhiKernelOp::name()); + + // Get kernelkey (backend、layout) + phi::Backend kernel_backend = phi::Backend::UNDEFINED; + phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED; + + if (in.type().isa()) { + kernel_backend = paddle::experimental::ParseBackend( + in.type() + .dyn_cast() + .place()); + kernel_layout = paddle::experimental::ParseLayout( + in.type() + .dyn_cast() + .data_layout()); + } else if (in.type().isa()) { + kernel_backend = paddle::experimental::ParseBackend( + in.type() + .dyn_cast() + .place()); + kernel_layout = paddle::experimental::ParseLayout( + in.type() + .dyn_cast() + .data_layout()); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("Get kernelkey for CastOp only support " + "DenseTensorType and SelectedRowsType")); + } + if (kernel_backend == phi::Backend::UNDEFINED) { + kernel_backend = paddle::experimental::ParseBackend(origin_place); + } + + phi::KernelKey cast_kernel_key(kernel_backend, kernel_layout, src_dtype); + + // Create CastOp + std::unordered_map op_attribute{ + {"op_name", pir::StrAttribute::get(ctx, "pd_op.cast")}, + {"kernel_name", pir::StrAttribute::get(ctx, "cast")}, + {"kernel_key", + paddle::dialect::KernelAttribute::get(ctx, cast_kernel_key)}, + {"dtype", paddle::dialect::DataTypeAttribute::get(ctx, dst_dtype)}}; + + pir::Type output_types = + BuildDtypeTransferOutputType(in.type(), out_place, dst_dtype, ctx); + + pir::Operation* op = pir::Operation::Create( + {in}, op_attribute, {output_types}, phi_kernel_op_info); + + auto in_op = in.dyn_cast().owner(); + if (in_op && in_op->HasAttribute(kAttrIsPersisable)) { + op->set_attribute(kAttrIsPersisable, in_op->attribute(kAttrIsPersisable)); + } + block->push_back(op); + pir::OpResult new_in = op->result(0); + return new_in; +} pir::Type BuildOutputType(pir::Type type, const phi::Place& place, phi::DataType data_type, @@ -340,6 +474,12 @@ phi::DataType GetKernelDataTypeByYamlInfo( auto slot_name = data_type_info[i]; auto& input_map = op_info_parser->InputName2Id(); + bool is_complex_tag = false; + if (slot_name.find("complex:") == 0) { + slot_name = slot_name.substr(8); + is_complex_tag = true; + } + auto find_it = Str2PhiDataType.find(slot_name); if (find_it != Str2PhiDataType.end()) { kernel_data_type = find_it->second; @@ -380,6 +520,9 @@ phi::DataType GetKernelDataTypeByYamlInfo( PADDLE_THROW(phi::errors::Unimplemented( "Only support DenseTensorType, SelectedRows, VectorType")); } + if (is_complex_tag) { + kernel_data_type = phi::dtype::ToComplex(kernel_data_type); + } } else { PADDLE_ENFORCE_EQ(attr_map.count(slot_name), @@ -517,10 +660,7 @@ phi::KernelKey GetKernelKey( if (op->isa()) { VLOG(6) << "SeedOp doesn't need a kernel"; auto backend = paddle::experimental::ParseBackend(place); - return {backend, - phi::DataLayout::ANY, - TransToPhiDataType( - op->result(0).type().dyn_cast().dtype())}; + return {backend, phi::DataLayout::ANY, phi::DataType::INT32}; } if (op->isa()) { @@ -612,19 +752,16 @@ phi::KernelKey GetKernelKey( paddle::experimental::BackendSet(data_op_backend); VLOG(8) << "Update kernel backend set from owner op (DataOp): " << data_op_backend; - } else if (op->operand_source(i) - .dyn_cast() - .owner() - ->isa()) { - auto combine_op = - op->operand_source(i).dyn_cast().owner(); + } else if (op_res.owner()->isa()) { + auto combine_op = op_res.owner(); for (size_t j = 0; j < combine_op->num_operands(); ++j) { - if (combine_op->operand_source(j) - .dyn_cast() - .owner() - ->isa()) { - auto data_op = - combine_op->operand_source(j).dyn_cast().owner(); + auto combine_op_res = + combine_op->operand_source(j).dyn_cast(); + if (!combine_op_res) { + continue; + } + if (combine_op_res.owner()->isa()) { + auto data_op = combine_op_res.owner(); auto data_place = data_op->attribute("place").data(); @@ -675,10 +812,34 @@ phi::KernelKey GetKernelKey( phi::KernelKey res(kernel_backend, kernel_layout, kernel_data_type); + // kernel backend infered incorrectly from memcpy op operands, + // case that place from (not GPU) to GPU. + // We handle this special case by following code to fix up the problem. + // This could be further improved if we had another method. + if (!platform::is_gpu_place(place)) { + if (op->isa()) { + VLOG(6) << "MemcpyOp need a special handle"; + int dst_place_type = op->attribute("dst_place_type") + .dyn_cast() + .data(); + if (dst_place_type == 1) { + res.set_backend(phi::Backend::GPU); + } + } + } + if (op->isa()) { res.set_dtype(phi::DataType::FLOAT32); VLOG(8) << "LoadCombineOp's kernel data type must be FLOAT32"; } + + if (op->isa() || + op->isa()) { + res.set_dtype(phi::DataType::FLOAT32); + VLOG(8) << "CSyncCommStream_Op/CSyncCommStreamOp's kernel data type must " + "be FLOAT32"; + } + if (NeedFallBackCpu((op), kernel_fn_str, res)) { res.set_backend(phi::Backend::CPU); VLOG(8) << "kernel backend must be on CPU when need fallback"; @@ -788,14 +949,14 @@ void HandleForWhileOp( phi::errors::PreconditionNotMet( "[%d]'s input of [%s] op MUST in map pair", 0, op_item->name())); auto new_in = map_value_pair->at(cur_in); - if (i == 0) + if (i == 0) { cond_val = new_in; - else + } else { vec_in.push_back(new_in); + } } pir::Builder builder(ctx, block); - auto base_while_op = op_item->dyn_cast(); auto new_while_op = builder.Build(cond_val, vec_in); pir::Block* body_block = new_while_op.body_block(); @@ -811,13 +972,22 @@ void HandleForWhileOp( ctx, map_op_pair, map_value_pair); + + (*map_op_pair)[op_item] = new_while_op; + + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = new_while_op->result(i); + } + } } pir::Value GetNewInput( const pir::Value cur_in, const std::unordered_map& map_value_pair, const int index, - const std::string op_name) { + const std::string& op_name) { PADDLE_ENFORCE_EQ( map_value_pair.count(cur_in), true, @@ -918,7 +1088,7 @@ void HandleForSpecialOp( } } - if (op_item->name() == "cf.yield" || op_item->name() == "cf.cond_yield") { + if (op_item->isa<::pir::YieldOp>()) { if (op_item->num_operands() > 0) { for (size_t i = 0; i < op_item->num_operands(); ++i) { auto cur_in = op_item->operand_source(i); @@ -933,6 +1103,28 @@ void HandleForSpecialOp( } } + if (op_item->name() == "cinn_runtime.jit_kernel") { + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + } + + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_output_types.push_back(paddle::dialect::AllocatedDenseTensorType::get( + ctx, + place, + op_item->result(i).type().dyn_cast())); + } + } + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); // Generate new op pir::Operation* op = pir::Operation::Create( @@ -1075,6 +1267,7 @@ std::vector BuildOpInputList( } } + // 1.backend transfer bool check_place_transfer = (op_item->isa<::pir::SetParameterOp>()) || (kernel.IsValid() && (!UnchangeOutputOps.count(op_item->name()))); @@ -1095,7 +1288,8 @@ std::vector BuildOpInputList( op_info_parser, kernel.InputAt(tensor_param_index).backend, i); - VLOG(6) << "Infer kernel backend from input " << i << " of op "; + VLOG(6) << "Infer kernel backend from input " << i << " of op " + << op_item->name(); bool need_trans = (in_place.GetType() != phi::AllocationType::UNDEFINED) && @@ -1248,9 +1442,39 @@ std::vector BuildOpInputList( "type and selected rows for now")); } } + + // 2. dtype transfer + if (op_info_parser != nullptr) { + std::string var_name = op_info_parser->InputNames()[i]; + auto fake_tensors = GetFakeTensorList(new_in); + if (!fake_tensors.empty()) { + const phi::KernelKey expected_kernel_key = kernel_key; + const phi::DataType kernel_dtype_for_var = + GetKernelTypeforVar(op_item, + var_name, + (*fake_tensors[0]).dtype(), + &expected_kernel_key); + + bool check_dtype_transfer = NeedTransformDataType( + expected_kernel_key.dtype(), kernel_dtype_for_var); + if (check_dtype_transfer) { + VLOG(4) << "trans input: " << var_name << "'s dtype from " + << kernel_dtype_for_var << " to " + << expected_kernel_key.dtype(); + + auto out_place = phi::TransToPhiPlace(expected_kernel_key.backend()); + new_in = AddDtypeTransferOp(new_in, + block, + kernel_key, + place, + out_place, + kernel_dtype_for_var, + expected_kernel_key.dtype()); + } + } + } vec_inputs.push_back(new_in); } - return vec_inputs; } @@ -1311,8 +1535,8 @@ std::unique_ptr GetOpYamlInfoParser(pir::Operation* op) { std::unique_ptr op_info_parser(nullptr); if (op_info_interface) { - op_info_parser = - std::make_unique(op_info_interface.GetOpInfo()); + op_info_parser = std::make_unique( + op_info_interface.GetOpInfo(), paddle::dialect::IsLegacyOp(op->name())); } return op_info_parser; diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc new file mode 100644 index 0000000000000..1af9cdb39cb45 --- /dev/null +++ b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class ReplaceFetchWithShadowOutputPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite( + paddle::dialect::FetchOp op, + pir::PatternRewriter& rewriter) const override { // NOLINT + rewriter.Build( + op->operand_source(0).dyn_cast(), + op->attributes().at("name").dyn_cast().AsString()); + rewriter.EraseOp(op); + return true; + } +}; + +class ReplaceFetchWithShadowOutputPass : public pir::Pass { + public: + ReplaceFetchWithShadowOutputPass() + : pir::Pass("replace_fetch_with_shadow_output_pass", 1) {} + + bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation* op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateReplaceFetchWithShadowOutputPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(replace_fetch_with_shadow_output_pass, + ReplaceFetchWithShadowOutputPass); diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h new file mode 100644 index 0000000000000..8c96f2d8abe0e --- /dev/null +++ b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateReplaceFetchWithShadowOutputPass(); + +} // namespace pir diff --git a/paddle/fluid/platform/device/gpu/nccl_helper.h b/paddle/fluid/platform/device/gpu/nccl_helper.h index 6afcd2eb7cd97..8afcfc9f2b700 100644 --- a/paddle/fluid/platform/device/gpu/nccl_helper.h +++ b/paddle/fluid/platform/device/gpu/nccl_helper.h @@ -32,6 +32,8 @@ #ifdef PADDLE_WITH_RCCL #include "paddle/fluid/platform/dynload/rccl.h" #endif +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h index 03fd7d4bb13f0..828b54c44a2dd 100644 --- a/paddle/fluid/platform/device_event_base.h +++ b/paddle/fluid/platform/device_event_base.h @@ -18,6 +18,7 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#include "paddle/utils/test_macros.h" namespace paddle { namespace platform { @@ -213,7 +214,7 @@ struct EventCreateFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventCreateFunctionRegisterer \ __reg_event_create_##device_type##__(func); \ - int TouchDeviceEventCreate##device_type() { \ + TEST_API int TouchDeviceEventCreate##device_type() { \ __reg_event_create_##device_type##__.Touch(); \ return 0; \ } @@ -233,7 +234,7 @@ struct EventRecordFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventRecordFunctionRegisterer \ __reg_event_record_##device_type##__(func); \ - int TouchDeviceEventRecord##device_type() { \ + TEST_API int TouchDeviceEventRecord##device_type() { \ __reg_event_record_##device_type##__.Touch(); \ return 0; \ } @@ -253,7 +254,7 @@ struct EventQueryFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventQueryFunctionRegisterer \ __reg_event_query_##device_type##__(func); \ - int TouchDeviceEventQuery##device_type() { \ + TEST_API int TouchDeviceEventQuery##device_type() { \ __reg_event_query_##device_type##__.Touch(); \ return 0; \ } @@ -273,7 +274,7 @@ struct EventFinishFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventFinishFunctionRegisterer \ __reg_event_finish_##device_type##__(func); \ - int TouchDeviceEventFinish##device_type() { \ + TEST_API int TouchDeviceEventFinish##device_type() { \ __reg_event_finish_##device_type##__.Touch(); \ return 0; \ } @@ -293,7 +294,7 @@ struct EventSetFinishedFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventSetFinishedFunctionRegisterer \ __reg_event_finished_setter_##device_type##__(func); \ - int TouchDeviceEventSetFinished##device_type() { \ + TEST_API int TouchDeviceEventSetFinished##device_type() { \ __reg_event_finished_setter_##device_type##__.Touch(); \ return 0; \ } @@ -315,7 +316,7 @@ struct EventWaitFunctionRegisterer : public framework::Registrar { static ::paddle::platform::EventWaitFunctionRegisterer \ __reg_event_wait_##waiter_type##event_type##__(func); \ - int TouchDeviceEventWait##waiter_type##event_type() { \ + TEST_API int TouchDeviceEventWait##waiter_type##event_type() { \ __reg_event_wait_##waiter_type##event_type##__.Touch(); \ return 0; \ } @@ -335,7 +336,7 @@ struct EventResetFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_RESET_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventResetFunctionRegisterer \ __reg_event_resetter_##device_type##__(func); \ - int TouchDeviceEventReset##device_type() { \ + TEST_API int TouchDeviceEventReset##device_type() { \ __reg_event_resetter_##device_type##__.Touch(); \ return 0; \ } diff --git a/paddle/fluid/platform/dynload/cuda_driver.cc b/paddle/fluid/platform/dynload/cuda_driver.cc index c6851594b803b..c0c752cab5fc5 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.cc +++ b/paddle/fluid/platform/dynload/cuda_driver.cc @@ -24,6 +24,7 @@ namespace dynload { #if CUDA_VERSION >= 10020 CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); +CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP); #endif CUDA_ROUTINE_EACH(DEFINE_WRAP); diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h index b696ffc1a3be8..8f3896530b430 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.h +++ b/paddle/fluid/platform/dynload/cuda_driver.h @@ -60,7 +60,13 @@ extern bool HasCUDADriver(); __macro(cuMemRelease); \ __macro(cuMemAddressFree) +#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \ + __macro(cuGraphNodeGetType); \ + __macro(cuGraphKernelNodeGetParams); \ + __macro(cuGraphExecKernelNodeSetParams) + CUDA_ROUTINE_EACH_VVM(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP); +CUDA_ROUTINE_EACH_CUDA_GRAPH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP); #endif CUDA_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP); diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index c2052719dd56c..d9516c9f4de4e 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -31,6 +31,7 @@ namespace dynload { __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ __macro(ncclCommCuDevice); \ @@ -42,6 +43,7 @@ namespace dynload { __macro(ncclGroupEnd); \ __macro(ncclReduce); \ __macro(ncclReduceScatter); \ + __macro(ncclCommGetAsyncError); \ __macro(ncclGetErrorString); NCCL_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_NCCL_WRAP) diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 64c431b3d237f..6e12d6fa464cc 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1212,8 +1212,8 @@ void scatter_grad(const Tensor& index, template void batch_norm_grad(const Tensor& x, - const Tensor& scale, - const Tensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean_out, const paddle::optional& variance_out, const Tensor& saved_mean, @@ -1306,14 +1306,20 @@ void batch_norm_grad(const Tensor& x, if (x_grad) { if (use_global_stats) { - auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad; + auto nhwc_x_grad = rsqrt_var * nhwc_out_grad; + if (scale) { + nhwc_x_grad = scale.get() * nhwc_x_grad; + } auto nchw_x_grad = transpose(nhwc_x_grad, nhwc_to_nchw_dim); if (need_cast) { nchw_x_grad = cast(nchw_x_grad, x.dtype()); } set_output(nchw_x_grad, x_grad); } else { - auto part1 = scale * rsqrt_var; + auto part1 = rsqrt_var; + if (scale) { + part1 = scale.get() * part1; + } auto mean_temp1 = nhwc_out_grad_sum / nhw; auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto part2 = @@ -1343,14 +1349,19 @@ void batch_norm_grad(const Tensor& x, auto nhwc_sum_dout_mul_diff = sum( out_grad_data * (x_data - mean_data), reduce_axis, dtype, false); if (use_global_stats) { - auto x_grad_data = scale * rsqrt_var * out_grad_data; + auto x_grad_data = rsqrt_var * out_grad_data; + if (scale) { + x_grad_data = scale.get() * x_grad_data; + } if (need_cast) { x_grad_data = cast(x_grad_data, x.dtype()); } set_output(x_grad_data, x_grad); } else { - auto part1 = scale * rsqrt_var; - + auto part1 = rsqrt_var; + if (scale) { + part1 = scale.get() * part1; + } auto mean_temp1 = out_grad_data_sum / nhw; auto mean_temp2 = nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; diff --git a/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc index 547590c2a9892..f89a898ca1a58 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc +++ b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc @@ -27,9 +27,9 @@ namespace paddle { namespace prim { using Tensor = paddle::Tensor; template <> -Tensor empty(const paddle::experimental::IntArray& shape, - phi::DataType dtype, - const paddle::Place& place) { +TEST_API Tensor empty(const paddle::experimental::IntArray& shape, + phi::DataType dtype, + const paddle::Place& place) { framework::VarDesc* new_var = StaticCompositeContext::Instance().GetBlock()->Var( StaticCompositeContext::Instance().GenerateUniqueName()); diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h index c08405bb18dbe..b88292d488ab6 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.h +++ b/paddle/fluid/prim/utils/static/static_global_utils.h @@ -25,7 +25,6 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/type_defs.h" - namespace paddle { namespace prim { @@ -109,7 +108,7 @@ class StaticCompositeContext { static thread_local bool enable_bwd_prim_; static thread_local bool enable_fwd_prim_; static thread_local bool enable_eager_prim_; - static StaticCompositeContext* static_composite_context_; + TEST_API static StaticCompositeContext* static_composite_context_; DISABLE_COPY_AND_ASSIGN(StaticCompositeContext); }; diff --git a/paddle/fluid/primitive/backend/CMakeLists.txt b/paddle/fluid/primitive/backend/CMakeLists.txt index d352880871121..ec3d39c8739c1 100644 --- a/paddle/fluid/primitive/backend/CMakeLists.txt +++ b/paddle/fluid/primitive/backend/CMakeLists.txt @@ -12,4 +12,4 @@ set(static_backend_files cc_library( primitive_backend_static_experimental SRCS ${static_backend_files} - DEPS pd_op_dialect_api) + DEPS op_dialect) diff --git a/paddle/fluid/primitive/codegen/CMakeLists.txt b/paddle/fluid/primitive/codegen/CMakeLists.txt index d01d21829ca1e..56019296e90ae 100644 --- a/paddle/fluid/primitive/codegen/CMakeLists.txt +++ b/paddle/fluid/primitive/codegen/CMakeLists.txt @@ -4,6 +4,15 @@ set(fwd_path ${parsed_yaml_path}/ops.parsed.yaml) set(fwd_legacy_path ${parsed_yaml_path}/legacy_ops.parsed.yaml) set(rev_path ${parsed_yaml_path}/backward_ops.parsed.yaml) set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml) +set(fwd_pd_op_path + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops.parsed.yaml +) +set(update_fwd_pd_op_path + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/update_ops.parsed.yaml +) +set(rev_pd_op_path + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops_backward.parsed.yaml +) set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml") set(templates_dir "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/") @@ -17,7 +26,9 @@ execute_process( COMMAND ${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path ${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path - ${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir} + ${rev_legacy_path} --fwd_pd_op_path ${fwd_pd_op_path} + --update_fwd_pd_op_path ${update_fwd_pd_op_path} --rev_pd_op_path + ${rev_pd_op_path} --prim_path ${prim_path} --templates_dir ${templates_dir} --compat_path ${compat_path} --destination_dir ${destination_dir} RESULT_VARIABLE _result) if(${_result}) @@ -26,3 +37,21 @@ if(${_result}) "Automatic code generation for paddle/fluid/primitive failed, exiting.") endif() message("Automatic code generation for paddle/fluid/primitive succeed.") + +execute_process( + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/primitive/codegen + COMMAND + ${PYTHON_EXECUTABLE} + ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/decomp_gen.py --fwd_path + ${fwd_path} --fwd_legacy_path ${fwd_legacy_path} --fwd_pd_op_path + ${fwd_pd_op_path} --templates_dir ${templates_dir} --compat_path + ${compat_path} --destination_dir + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir/op_decomp.cc + RESULT_VARIABLE _result) +if(${_result}) + message( + FATAL_ERROR + "Automatic code generation for build/paddle/fluid/pir/dialect/operator/ir/op_decomp.cc failed." + ) +endif() +message("Automatic code generation for decomp interface succeed.") diff --git a/paddle/fluid/primitive/codegen/decomp_gen.py b/paddle/fluid/primitive/codegen/decomp_gen.py new file mode 100644 index 0000000000000..15dbdd30539a1 --- /dev/null +++ b/paddle/fluid/primitive/codegen/decomp_gen.py @@ -0,0 +1,261 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import pathlib +import sys + +import jinja2 +import yaml + +# fmt: off +# import from paddle/fluid/operators/generator +sys.path.append( + str(pathlib.Path(__file__).resolve().parents[2] / 'operators/generator') +) +import filters as op_gen_filters +import tests_utils as op_gen_tests +from gen import extend_compat_info, filter_compat_info +from parse_utils import to_named_dict +from type_mapping import output_type_map + +# import from paddle/fluid/pir/dialect/op_generator/api_gen.py +sys.path.append( + str(pathlib.Path(__file__).resolve().parents[2] / 'pir/dialect/op_generator') +) + +from decomp_interface_gen_op_list import ( + decomp_interface_implementation_gen_op_list, +) +from op_gen import attr_types_map, to_pascal_case + +# fmt: on + + +def load(path: pathlib.Path): + """Load config from yaml file. + + Args: + path (pathlib.Path): The path of yaml config. + + Returns: + dict: The config info. + + """ + with open(path, 'rt') as f: + return yaml.safe_load(f) + + +def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs): + """Render and save Jinja2 templates to the destination directory. + + Args: + src_dir (pathlib.Path): The source directory containing Jinja2 templates. + dst_dir (pathlib.Path): The destination directory to save rendered files. + *args: Additional positional arguments passed to the `render` function. + **kwargs: Additional keyword arguments passed to the `render` function. + + Returns: + None + """ + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(src_dir), + keep_trailing_newline=True, + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + extensions=['jinja2.ext.do'], + ) + env.filters.update( + { + 'to_paddle_attr_type': op_gen_filters.to_paddle_attr_type, + 'to_paddle_input_type': op_gen_filters.to_paddle_input_type, + 'to_paddle_output_type': op_gen_filters.to_paddle_output_type, + 'trip_intermediate': op_gen_filters.filter_intermediate, + } + ) + env.tests.update( + { + 'scalar': op_gen_tests.is_scalar, + 'intarray': op_gen_tests.is_intarray, + 'datatype': op_gen_tests.is_datatype, + 'exist_mutable_attribute': op_gen_tests.exist_mutable_attribute, + 'mutable_attribute': op_gen_tests.is_mutable_attribute, + 'only_composite_op': op_gen_tests.is_only_composite_op, + } + ) + + decomp_temp = "decomp/generated_decomp.j2" + save( + env.get_template(decomp_temp).render(*args, **kwargs), + pathlib.Path(dst_dir), + ) + + +def save(content: str, path: pathlib.Path): + """Saves the given string contents to a file in the specified path. + + Args: + content (str): The string content that needs to be saved. + path (pathlib.Path): The path to save the file, a Pathlib path object + + Returns: + None + """ + path.parent.mkdir(parents=True, exist_ok=True) + + dst_content = '' + if path.is_file(): + with open(path, 'r') as f: + dst_content = f.read() + + if ( + hashlib.md5(content.encode("UTF-8")).hexdigest() + != hashlib.md5(dst_content.encode("UTF-8")).hexdigest() + ): + with open(path, 'w') as f: + f.write(content) + + +def process_optional_output_info(apis): + for api in apis: + inputs_dict = to_named_dict(api['inputs']) + for output in api['outputs']: + if ( + api.get("inplace", None) + and output['name'] in api['inplace'] + and inputs_dict[api['inplace'][output['name']]]['optional'] + ): + output['optional'] = True + else: + output['optional'] = False + + +def gen( + fwd_path: pathlib.Path, + fwd_legacy_path: pathlib.Path, + compat_path: pathlib.Path, + fwd_pd_op_path: pathlib.Path, + templates_dir: pathlib.Path, + destination_dir: pathlib.Path, +): + """The `gen` load jinja2 templates and relative config info, use jinja2 + templating engine to generate c++ code, and save the code into destination. + + Args: + prim_path (pathlib.Path): The YAML file path of the primitive API. + fwd_path (pathlib.Path): The YAML file path of the forwad API. + fwd_legacy_path (pathlib.Path): The YAML file path of the legacy + forwad API. + rev_path (pathlib.Path): The YAML file path of the backward API. + rev_legacy_path (pathlib.Path): The YAML file path of the legacy + backward API. + compat_path: (pathlib.Path): The YAML file path of the ops compat. + fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API. + rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API. + templates_dir (pathlib.Path): The directory of the templates. + destination_dir (pathlib.Path): The Directory of the generated file. + + Returns: + None + """ + ( + fwds, + legacy_fwds, + compats, + ir_fwds, + ) = ( + load(fwd_path), + load(fwd_legacy_path), + load(compat_path), + load(fwd_pd_op_path), + ) + filter_compat_info(compats) + apis = [ + {**api, **{'class_name': to_pascal_case(api["name"]) + "Op"}} + for api in fwds + legacy_fwds + ir_fwds + ] + + apis = extend_compat_info(apis, compats) + + process_optional_output_info(apis) + + for item in apis: + for attr_item in item["attrs"]: + if attr_item["typename"] not in attr_types_map.keys(): + raise TypeError + attr_item["mapped_type"] = attr_types_map[attr_item["typename"]][0] + for out_item in item["outputs"]: + if out_item["typename"] not in output_type_map.keys(): + name = out_item["typename"] + raise TypeError(f"err type {name}") + if out_item["optional"]: + out_item["mapped_type"] = ( + "paddle::optional<" + + output_type_map[out_item["typename"]] + + ">" + ) + else: + out_item["mapped_type"] = output_type_map[out_item["typename"]] + render( + templates_dir, + destination_dir, + apis=apis, + decomp_white_list=decomp_interface_implementation_gen_op_list, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Generate Static Primitive API' + ) + parser.add_argument( + '--fwd_path', type=str, help='The parsed ops yaml file.' + ) + parser.add_argument( + '--fwd_legacy_path', + type=str, + help='The parsed ops yaml file.', + ) + parser.add_argument( + '--compat_path', + type=str, + help='The parsed ops compat yaml file.', + ) + parser.add_argument( + '--fwd_pd_op_path', + type=str, + help='The ir forward ops parsed yaml file.', + ) + parser.add_argument( + '--templates_dir', + type=str, + help='JinJa2 templates base directory.', + ) + parser.add_argument( + '--destination_dir', + type=str, + help='Destination base directory for generated file.', + ) + args = parser.parse_args() + + gen( + pathlib.Path(args.fwd_path), + pathlib.Path(args.fwd_legacy_path), + pathlib.Path(args.compat_path), + pathlib.Path(args.fwd_pd_op_path), + pathlib.Path(args.templates_dir), + pathlib.Path(args.destination_dir), + ) diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index f9f0d5c32b11c..ce7538e2ba999 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -14,6 +14,7 @@ import argparse import hashlib +import os import pathlib import sys @@ -42,7 +43,15 @@ 'add_n_grad', ] -BACKENDS_BLACK_LIST = ['copy_to', 'add_n_grad', "allclose", "isclose"] +BACKENDS_BLACK_LIST = [ + 'copy_to', + 'add_n_grad', + "allclose", + "isclose", + "send_v2", + "assert", + "embedding_grad_sparse", +] PRIM_VJP = [ @@ -74,6 +83,7 @@ 'silu_grad', 'softmax_grad', 'sqrt_grad', + 'relu_grad', ] # custom vjp list of composite op VJP_COMPS = PRIM_VJP + CUSTOM_VJP @@ -280,18 +290,29 @@ def process_backward_invoke_info(apis): def process_optional_output_info(apis): for api in apis: - if not api['is_fwd']: - continue inputs_dict = to_named_dict(api['inputs']) for output in api['outputs']: - if ( - api.get("inplace", None) - and output['name'] in api['inplace'] - and inputs_dict[api['inplace'][output['name']]]['optional'] - ): - output['optional'] = True - else: + if not api['is_fwd']: output['optional'] = False + else: + if ( + api.get("inplace", None) + and output['name'] in api['inplace'] + and inputs_dict[api['inplace'][output['name']]]['optional'] + ): + output['optional'] = True + else: + output['optional'] = False + + +def update_apis(op_yaml_items, update_yaml_file): + with open(update_yaml_file, "r") as f: + update_apis = yaml.safe_load(f) + for i in range(len(op_yaml_items)): + for update_api in update_apis: + if op_yaml_items[i]['name'] == update_api['name']: + op_yaml_items[i] = update_api + break def gen( @@ -301,6 +322,9 @@ def gen( rev_path: pathlib.Path, rev_legacy_path: pathlib.Path, compat_path: pathlib.Path, + fwd_pd_op_path: pathlib.Path, + update_fwd_pd_op_path: pathlib.Path, + rev_pd_op_path: pathlib.Path, templates_dir: pathlib.Path, destination_dir: pathlib.Path, ): @@ -316,23 +340,45 @@ def gen( rev_legacy_path (pathlib.Path): The YAML file path of the legacy backward API. compat_path: (pathlib.Path): The YAML file path of the ops compat. + fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API. + update_fwd_pd_op_path (pathlib.Path): The YAML file path of the ir update_ops. + rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API. templates_dir (pathlib.Path): The directory of the templates. destination_dir (pathlib.Path): The Directory of the generated file. Returns: None """ - prims, fwds, legacy_fwds, revs, legacy_revs, compats = ( + ( + prims, + fwds, + legacy_fwds, + revs, + legacy_revs, + compats, + ir_fwds, + ir_revs, + ) = ( load(prim_path), load(fwd_path), load(fwd_legacy_path), load(rev_path), load(rev_legacy_path), load(compat_path), + load(fwd_pd_op_path), + load(rev_pd_op_path), ) filter_compat_info(compats) - apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds] - apis = apis + [{**api, **{'is_fwd': False}} for api in revs + legacy_revs] + + fwd_apis = fwds + legacy_fwds + ir_fwds + # replace old ir ops with pir ops + if os.path.exists(update_fwd_pd_op_path): + update_apis(fwd_apis, update_fwd_pd_op_path) + + apis = [{**api, **{'is_fwd': True}} for api in fwd_apis] + apis = apis + [ + {**api, **{'is_fwd': False}} for api in revs + legacy_revs + ir_revs + ] apis = [ {**api, **{'is_prim': True}} if api['name'] in prims @@ -383,6 +429,21 @@ def gen( type=str, help='The parsed ops compat yaml file.', ) + parser.add_argument( + '--fwd_pd_op_path', + type=str, + help='The ir forward ops parsed yaml file.', + ) + parser.add_argument( + '--update_fwd_pd_op_path', + type=str, + help='The ir update forward ops parsed yaml file.', + ) + parser.add_argument( + '--rev_pd_op_path', + type=str, + help='The ir backward ops parsed yaml file.', + ) parser.add_argument( '--templates_dir', type=str, @@ -402,6 +463,9 @@ def gen( pathlib.Path(args.rev_path), pathlib.Path(args.rev_legacy_path), pathlib.Path(args.compat_path), + pathlib.Path(args.fwd_pd_op_path), + pathlib.Path(args.update_fwd_pd_op_path), + pathlib.Path(args.rev_pd_op_path), pathlib.Path(args.templates_dir), pathlib.Path(args.destination_dir), ) diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 index 863bbb7de633f..e422bd61a9618 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 @@ -20,7 +20,7 @@ using IntArray = paddle::experimental::IntArray; using DataType = phi::DataType; {% for api in apis %} - {%- if api is only_composite_op -%}{#- render nothing -#} + {%- if api is only_composite_op or "infer_meta" not in api and "composite" not in api and "invoke" not in api -%}{#- render nothing -#} {%- elif api.name not in backend_black_list -%} {%- if 'invoke' not in api or 'invoke' in api and api.is_fwd -%} {% if api.attrs is exist_mutable_attribute %} diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 index 3b9a94993eaa4..7f9f4b5b8676f 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 @@ -12,7 +12,8 @@ namespace backend { {%- macro args(inputs, attrs) -%} {#- Arguments are variable pass into method -#} {{common.sequence('', '', ', ', inputs)}} - {%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#} + {%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between + nputs and attrs -#} {{common.sequence('', '', ', ', attrs)}} {%- endmacro -%} diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 index 97b150b0d2dfc..36adc8ac964c4 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 @@ -139,7 +139,7 @@ auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}} {% for api in apis %} -{%- if api is only_composite_op -%}{#- render nothing -#} +{%- if api is only_composite_op or "infer_meta" not in api and "composite" not in api and "invoke" not in api -%}{#- render nothing -#} {% elif api.name not in backend_black_list %} {%- if 'invoke' not in api or 'invoke' in api and api.is_fwd-%} {% set api_outputs = api.outputs | trip_intermediate %} diff --git a/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 b/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 new file mode 100644 index 0000000000000..e038e61e7a861 --- /dev/null +++ b/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 @@ -0,0 +1,154 @@ +{% import "common.j2" as common %} +// Auto Generated by decomp_gen.py, DO NOT EDIT! + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/primitive/composite/composite.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +using IntArray = paddle::experimental::IntArray; +{% macro sig(fwd_name, class_name, inputs, attrs, outputs) %} +{% set input_names=[] %} +{% set attr_names=[] %} +{% set output_names=[] %} +{% set output_types=[] %} + +std::vector> {{class_name}}::Decomp(pir::Operation* op) { + {{class_name}} op_obj = op->dyn_cast<{{class_name}}>(); + (void)op_obj; + + FLAGS_tensor_operants_mode = "static"; + + VLOG(4) << "Decomp Prepare inputs of {{fwd_name}}"; + + {% for item in inputs -%} + {% do input_names.append(item.name) %} + {% if item.typename == "Tensor" %} {#- Tensor or Tensor[] #} + {% if item.optional %} + paddle::optional {{item.name}}; + if (!IsEmptyValue(op_obj.{{item.name}}())){ + {{item.name}} = paddle::make_optional(Tensor(std::make_shared(op_obj.{{item.name}}()))); + } + {% else %} + {{item.typename}} {{item.name}}(std::make_shared(op_obj.{{item.name}}())); + {% endif %} + {% elif item.typename == "Tensor[]" %} + {% if item.optional %} + + paddle::optional> {{item.name}}; + if (!IsEmptyValue(op_obj.{{item.name}}())){ + pir::CombineOp combine_op_obj = + op_obj.{{item.name}}().dyn_cast().owner()->dyn_cast(); + std::vector optional_{{item.name}}; + for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { + optional_{{item.name}}.emplace_back( + std::make_shared(combine_op_obj.inputs()[idx])); + } + {{item.name}} = paddle::make_optional>(optional_{{item.name}}); + } + + {% else %} + pir::CombineOp combine_op_obj_{{item.name}} = + op_obj.{{item.name}}().dyn_cast().owner()->dyn_cast(); + std::vector {{item.name}}; + for (size_t idx = 0; idx < combine_op_obj_{{item.name}}.inputs().size(); idx++) { + {{item.name}}.emplace_back( + std::make_shared(combine_op_obj_{{item.name}}.inputs()[idx])); + } + {% endif %} + {% endif %} + {% endfor %} + + VLOG(4) << "Decomp prepare attributes of {{fwd_name}}"; + {% if attrs %} + {% for item in attrs %} + {% do attr_names.append(item.name) %} + {% if item.typename == "Scalar" and item.support_tensor %} + + Tensor {{item.name}}_(std::make_shared(op_obj.{{item.name}}())); + + auto* {{item.name}}_define_op = + std::static_pointer_cast({{item.name}}_.impl()) + ->value() + .dyn_cast() + .owner(); + if ({{item.name}}_define_op->name() != "pd_op.full") { + PADDLE_THROW( + platform::errors::Unimplemented("We don't support dynamic tensors " + "attribute {{item.name}} for {{fwd_name}} decomposition " + "for now. ")); + } + Scalar {{item.name}} = {{item.name}}_define_op->attribute("value").dyn_cast().data(); + + {% elif item.typename == "IntArray" and item.support_tensor %} + + Tensor {{item.name}}_(std::make_shared(op_obj.{{item.name}}())); + + auto* {{item.name}}_define_op = + std::static_pointer_cast({{item.name}}_.impl()) + ->value() + .dyn_cast() + .owner(); + if ({{item.name}}_define_op->name() != "pd_op.full_int_array") { + PADDLE_THROW( + platform::errors::Unimplemented("We don't support dynamic tensors " + "attribute {{item.name}} for {{fwd_name}} decomposition " + "for now. ")); + } + IntArray {{item.name}} = phi::IntArray( + paddle::dialect::GetInt64Vector({{item.name}}_define_op->attribute("value"))); + + {% else %} + {{item.typename}} {{item.name}} = op->attribute("{{item.name}}").dyn_cast<{{item.mapped_type}}>().data(); + {% endif %} + {% endfor %} + {% endif %} + + VLOG(4) << "Decomp prepare call {{fwd_name}}'s decomp interface"; + + auto org_res = op->results(); + std::vector> res(org_res.size()); + + {% if outputs|length == 1 %} + Tensor op_res = paddle::primitive::details::{{fwd_name}}_decomp({{common.args(input_names, attr_names)}}); + res[0].push_back( + std::static_pointer_cast(op_res.impl()) + ->value() + .dyn_cast()); + {% else %} + {% for item in outputs %} + {% do output_names.append(item.name) %} + {% do output_types.append(item.mapped_type) %} + {% endfor %} + std::tuple<{{common.sequence('', '', ', ', output_types)}}> op_res = paddle::primitive::details::{{fwd_name}}_decomp( + {{common.args(input_names, attr_names)}}); + {% for k in range(outputs|length) %} + {% if outputs[k].intermediate %} + pir::OpResult {{outputs[k].name}}; + res[{{k}}].push_back({{outputs[k].name}}); + {% else %} + res[{{k}}].push_back(std::static_pointer_cast(std::get<{{k}}>(op_res).impl())->value().dyn_cast()); + {% endif %} + {% endfor %} + {% endif %} + + return res; +} +{% endmacro %} + +{% for api in apis -%} + {% if api.name in decomp_white_list %} + {{sig(api.name, api.class_name, api.inputs, api.attrs, api.outputs)}} + {% else %} {# render nothing #} + {% endif %} +{% endfor %} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index 02e6c58f97af6..fbf082c72fddd 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -12,6 +12,7 @@ #include "paddle/pir/core/operation.h" #include "paddle/phi/core/flags.h" #include "paddle/utils/optional.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" PHI_DECLARE_string(tensor_operants_mode); @@ -57,7 +58,7 @@ if({{i.name}}_define_op->name() != "pd_op.full_int_array"){ "We don't support dynamic tensors attribute {{i.name}} for {{api_name}} composite " "for now. ")); } -auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast().data(); +auto {{i.name}} = phi::IntArray(paddle::dialect::GetInt64Vector({{i.name}}_define_op->attribute("value"))); {% endif %} {% endif %} {% endfor %} @@ -102,7 +103,7 @@ vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients); {% macro body_prim(api) %} FLAGS_tensor_operants_mode = "static"; -VLOG(4) << "Call PIR Decomposed backward op {{api.name}}"; +VLOG(4) << "Call Pir Decomposed backward op {{api.name}}"; {% for i in range(api.outputs|length) %} {% if api.outputs[i].typename=='Tensor' %} paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{{i}}][0] : nullptr; diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 7ac642573ca79..374dafe8f8a4a 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -14,11 +14,79 @@ #pragma once -namespace paddle { +#include "paddle/fluid/primitive/primitive/primitive.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" +namespace paddle { namespace primitive { +namespace details { + +template +Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) { + auto org_dtype = x.dtype(); + auto x_tmp = x; + bool need_cast = org_dtype == phi::DataType::FLOAT16 || + org_dtype == phi::DataType::BFLOAT16; + if (need_cast) { + x_tmp = cast(x, phi::DataType::FLOAT32); + } + std::vector x_dim = phi::vectorize(x_tmp.dims()); + int64_t axis_size = axis.size(); + int64_t x_dim_size = x_dim.size(); + auto axis_ = std::vector(); + if (axis_size == 0) { + for (int64_t i = 0; i < x_dim_size; i++) { + axis_.push_back(i); + } + } else { + axis_ = axis.GetData(); + for (int64_t i = 0; i < axis_size; i++) { + if (axis[i] < 0) { + axis_[i] = axis[i] + x_dim_size; + } + } + } + + int64_t value = 1; + for (size_t i = 0; i < axis_.size(); i++) { + value *= x_dim[axis_[i]]; + } + auto sum_x = sum(x_tmp, IntArray(axis_), x_tmp.dtype(), keepdim); + auto res = + sum_x / full(phi::vectorize(sum_x.dims()), value, sum_x.dtype()); + if (need_cast) { + return cast(res, org_dtype); + } else { + return res; + } +} + +template +Tensor relu_decomp(const Tensor& x) { + return maximum(x, full(phi::vectorize(x.dims()), 0.0, x.dtype())); +} + +template +std::tuple squeeze_decomp(const Tensor& x, + const IntArray& axis) { + auto axis_ = process_dims(x, axis.GetData()); + auto out_shape = get_squeeze_dims(x, axis_); + Tensor out = reshape(x, out_shape); + Tensor xshape; + return std::make_tuple(out, xshape); +} + +template +Tensor add_n_decomp(const std::vector& x) { + Tensor res = x[0]; + for (size_t i = 1; i < x.size(); i++) { + res = res + x[i]; + } + return res; +} -namespace experimental {} +} // namespace details } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt index 4b790fd07900b..c0f52d0d55df5 100644 --- a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt +++ b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt @@ -5,4 +5,4 @@ cc_library( primitive_vjp_experimental SRCS ${VJP_SRCS} DEPS primitive_backend_static_experimental static_global_utils - primitive_static_utils_experimental pd_op_dialect_core) + primitive_static_utils_experimental op_dialect) diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 5e8863027a78d..0a02f15aeea10 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -453,6 +453,14 @@ void layer_norm_grad(const Tensor& x, } // cast dtype to float32 if dtype =float16 or bfloat16 + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + x_cast = cast(x_cast, phi::DataType::FLOAT32); + out_grad_cast = cast(out_grad_cast, phi::DataType::FLOAT32); + if (scale_ptr) { + scale_cast = cast(scale_cast, phi::DataType::FLOAT32); + } + } auto x_sub_mean = x_cast - mean_; // M,N auto tmp = (1.0 / (variance_ + epsilon)); // M,1 @@ -480,6 +488,10 @@ void layer_norm_grad(const Tensor& x, auto x_grad_tmp = dx_end - d_mean_d_std; x_grad_tmp = reshape(x_grad_tmp, phi::vectorize(x.dims())); + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + x_grad_tmp = cast(x_grad_tmp, x.dtype()); + } set_output(x_grad_tmp, x_grad); } @@ -489,6 +501,10 @@ void layer_norm_grad(const Tensor& x, (x_sub_mean_mul_sqrt_var_1 * out_grad_cast) .sum(std::vector({0}), x_cast.dtype(), true); scale_grad_tmp = reshape(scale_grad_tmp, scale_ptr->shape()); + if (scale_ptr->dtype() == phi::DataType::FLOAT16 || + scale_ptr->dtype() == phi::DataType::BFLOAT16) { + scale_grad_tmp = cast(scale_grad_tmp, scale_ptr->dtype()); + } set_output(scale_grad_tmp, scale_grad); } else { scale_grad = nullptr; @@ -500,6 +516,10 @@ void layer_norm_grad(const Tensor& x, auto bias_grad_tmp = out_grad_cast.sum(std::vector({0}), x_cast.dtype(), true); bias_grad_tmp = reshape(bias_grad_tmp, bias_ptr->shape()); + if (bias_ptr->dtype() == phi::DataType::FLOAT16 || + bias_ptr->dtype() == phi::DataType::BFLOAT16) { + bias_grad_tmp = cast(bias_grad_tmp, bias_ptr->dtype()); + } set_output(bias_grad_tmp, bias_grad); } else { bias_grad = nullptr; @@ -654,6 +674,18 @@ void softmax_grad(const Tensor& out, } } +template +void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + auto condition = greater_than( + out, full(phi::vectorize(out.dims()), 0.0, out.dtype())); + auto res = where(condition, + out_grad, + full(phi::vectorize(out.dims()), 0.0, out.dtype())); + set_output(res, x_grad); + } +} + template void gather_nd_grad(const Tensor& x, const Tensor& index, diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc index 6b3b1050448ef..999f9bfb0306b 100644 --- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc @@ -55,7 +55,7 @@ std::vector> reshape_vjp( if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) { FLAGS_tensor_operants_mode = "static"; - VLOG(4) << "Call PIR Decomposed backward op reshape_grad"; + VLOG(4) << "Call Pir Decomposed backward op reshape_grad"; paddle::Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; details::reshape_grad(xshape, out_grad, x_grad); diff --git a/paddle/fluid/primitive/utils/CMakeLists.txt b/paddle/fluid/primitive/utils/CMakeLists.txt index d33bb8fe15d5d..babaa5cd7da7f 100644 --- a/paddle/fluid/primitive/utils/CMakeLists.txt +++ b/paddle/fluid/primitive/utils/CMakeLists.txt @@ -7,4 +7,4 @@ endif() cc_library( primitive_static_utils_experimental SRCS static_utils.cc - DEPS phi common_infer_shape_functions pd_op_dialect_api) + DEPS phi common_infer_shape_functions op_dialect) diff --git a/paddle/fluid/primitive/utils/utils.h b/paddle/fluid/primitive/utils/utils.h index 73fa68a0bb937..da6cd28bfa476 100644 --- a/paddle/fluid/primitive/utils/utils.h +++ b/paddle/fluid/primitive/utils/utils.h @@ -55,6 +55,47 @@ static std::vector get_unsqueeze_dims( return result; } +// This fucction compute unsqueeze dims for reshape to replace unsqueeze. +static std::vector get_squeeze_dims(const Tensor& origin, + const std::vector& axis) { + auto origin_dims = origin.shape(); + auto total_shape_size = origin_dims.size(); + std::vector result; + for (size_t i = 0; i < total_shape_size; ++i) { + if (origin_dims[i] != 1) { + result.push_back(origin_dims[i]); + } else if (origin_dims[i] == 1 && + std::find(axis.begin(), axis.end(), int64_t(i)) == axis.end()) { + result.push_back(1); + } else { + continue; + } + } + return result; +} + +static std::vector process_dims(const Tensor& origin, + const std::vector& axis) { + auto origin_dims = origin.shape(); + auto total_shape_size = origin_dims.size(); + std::vector result; + auto axis_size = axis.size(); + if (axis_size == 0) { + for (size_t i = 0; i < total_shape_size; ++i) { + result.push_back(i); + } + } else { + for (size_t i = 0; i < axis_size; ++i) { + if (axis[i] < 0) { + result.push_back(axis[i] + total_shape_size); + } else { + result.push_back(axis[i]); + } + } + } + return result; +} + // These method don't need to be specified static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims, const phi::DDim& in_dims) { diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 2dfeb89bef5c4..abd92642cd6f8 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -39,9 +39,9 @@ set(PYBIND_DEPS phi_utils phi pir_adaptor - pd_op_dialect + op_dialect_vjp program_translator - pd_inplace_pass + transform pir new_profiler jit_layer @@ -51,6 +51,10 @@ set(PYBIND_DEPS type_info auto_parallel) +if(WITH_CINN) + set(PYBIND_DEPS ${PYBIND_DEPS} transform op_with_group_merge_pass) +endif() + if(WITH_PSCORE) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) if(WITH_HETERPS) @@ -145,6 +149,7 @@ set(PYBIND_SRCS xpu_streams_py.cc jit.cc auto_parallel_py.cc + eval_frame_tools.cc eval_frame.c) if(WITH_CUSTOM_DEVICE) diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 09d76e33d69c1..8bcee4fede371 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -35,17 +35,18 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" +#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h" #include "paddle/phi/core/enforce.h" #ifdef PADDLE_WITH_DISTRIBUTE @@ -73,6 +74,7 @@ using paddle::distributed::auto_parallel::SPMDRuleMap; using paddle::framework::BlockDesc; using paddle::framework::OpDesc; using paddle::framework::VarDesc; +using phi::distributed::ArgDistAttr; using phi::distributed::ProcessMesh; using phi::distributed::TensorDistAttr; using phi::distributed::auto_parallel::Device; @@ -142,9 +144,9 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) { dist_attr->clear_annotated(); } -static std::pair, std::vector> +static std::pair, std::vector> infer_forward(const phi::distributed::SpmdRule &self, const py::args &args); -static std::pair, std::vector> +static std::pair, std::vector> infer_backward(const phi::distributed::SpmdRule &self, const py::args &args); void BindAutoParallel(py::module *m) { @@ -181,10 +183,18 @@ void BindAutoParallel(py::module *m) { *m, "RToSReshardFunction", ReshardFunction) .def(py::init<>()); + py::class_( + *m, "RToSReshardFunctionCrossMesh", ReshardFunction) + .def(py::init<>()); + py::class_( *m, "SToRReshardFunction", ReshardFunction) .def(py::init<>()); + py::class_( + *m, "SToRReshardFunctionCrossMesh", ReshardFunction) + .def(py::init<>()); + py::class_( *m, "RToPReshardFunction", ReshardFunction) .def(py::init<>()); @@ -480,6 +490,9 @@ void BindAutoParallel(py::module *m) { static_cast &( OperatorDistAttr::*)()>(&OperatorDistAttr::output_dist_attrs), &OperatorDistAttr::set_output_dist_attrs) + .def_property("run_time_us", + &OperatorDistAttr::run_time_us, + &OperatorDistAttr::set_run_time_us) .def("get_input_dist_attr", static_cast( @@ -571,33 +584,7 @@ void BindAutoParallel(py::module *m) { "reshard", [](py::handle py_tensor, const TensorDistAttr &dist_attr) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - auto dev_ctx = phi::DeviceContextPool::Instance().Get(tensor.place()); - std::shared_ptr dist_out_ptr = nullptr; - if (phi::distributed::DistTensor::classof(tensor.impl().get())) { - auto tensor_in = tensor.impl(); - if (tensor_in) { - phi::distributed::DistTensor *dist_tensor = - static_cast(tensor_in.get()); - if (dist_tensor->dist_attr() != dist_attr) { - VLOG(6) << "reshard func, reshard tensor from " - << dist_tensor->dist_attr() << " to " << dist_attr; - auto *func = phi::distributed::ChooseProperReshardFunction( - *dist_tensor, dist_attr); - dist_out_ptr = func->Eval(dev_ctx, *dist_tensor, dist_attr); - } else { - dist_out_ptr = - std::static_pointer_cast( - tensor_in); - } - } - return paddle::Tensor(dist_out_ptr); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "The input tensor of shard function should be " - "``phi::distributed::DistTensor``. " - "However it's %s", - typeid(tensor.impl().get()).name())); - } + return reshard_ad_function(tensor, dist_attr); }, py::return_value_policy::reference); @@ -728,7 +715,7 @@ static void prepare_ctx(phi::distributed::InferSpmdContext *ctx, parse_single_pyobject(obj, ctx, i); } } -static std::pair, std::vector> +static std::pair, std::vector> infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) { VLOG(6) << "infer_forward "; phi::distributed::InferSpmdContext ctx; @@ -736,7 +723,7 @@ infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) { return self.InferForward(ctx); } -static std::pair, std::vector> +static std::pair, std::vector> infer_backward(const phi::distributed::SpmdRule &self, const py::args &args) { VLOG(6) << "infer_backward "; phi::distributed::InferSpmdContext ctx; diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 82408b5236936..64ff801d464f4 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -38,6 +38,9 @@ namespace paddle { namespace pybind { void BindCommContextManager(py::module *m) { + auto P2POption = py::class_(*m, "P2POption") + .def(py::init<>()); + auto CommContextManager = py::class_>( @@ -49,6 +52,12 @@ void BindCommContextManager(py::module *m) { .def_static( "create_nccl_comm_context", &phi::distributed::CommContextManager::CreateNCCLCommContext, + py::arg("store"), + py::arg("unique_comm_key"), + py::arg("rank"), + py::arg("size"), + py::arg("hash_key") = "", + py::arg("p2p_opt") = nullptr, py::call_guard()) #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 259aa1f5dac49..5c492815f108d 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -23,7 +23,6 @@ limitations under the License. */ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/reducer.h" -#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/imperative/layer.h" @@ -31,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/process_group_utils.h" #include "paddle/phi/api/all.h" +#include "paddle/phi/core/distributed/types.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group_nccl.h" @@ -265,8 +265,8 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto in_dense = *p_in_tensor; - auto *dev_ctx = self.GetDeviceContext(in_tensor.place()); auto task = self.AllGather(out_dense, in_dense, sync_op); + auto *dev_ctx = self.GetDeviceContext(in_tensor.place()); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); task->UpdateWaitChain(*dev_ctx); return task; @@ -320,8 +320,6 @@ void BindDistributed(py::module *m) { auto in_dense = *p_in_tensor; // in_tensor_list should not be empty - auto *dev_ctx = - self.GetDeviceContext(in_tensor_list.back().place()); int world_size = self.GetSize(); auto task = self.AllToAll(out_dense, @@ -329,6 +327,8 @@ void BindDistributed(py::module *m) { GetDefaultSplitSizes(*out_dense, world_size), GetDefaultSplitSizes(in_dense, world_size), sync_op); + auto *dev_ctx = + self.GetDeviceContext(in_tensor_list.back().place()); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); task->UpdateWaitChain(*dev_ctx); return task; @@ -542,11 +542,11 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto in_dense = *p_in_tensor; - auto *dev_ctx = - self.GetDeviceContext(in_tensor.place(), use_calc_stream); distributed::GatherOptions gather_opts{dst}; auto task = self.Gather( out_dense, in_dense, gather_opts, sync_op, use_calc_stream); + auto *dev_ctx = + self.GetDeviceContext(in_tensor.place(), use_calc_stream); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); if (!use_calc_stream && dev_ctx->GetPlace() != platform::CPUPlace()) { @@ -582,8 +582,7 @@ void BindDistributed(py::module *m) { opts.reduce_op = op; auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.AllReduce(tensors, tensors, opts); + return self.AllReduce(dense.get(), *dense, opts, false); }, py::arg("tensor"), py::arg("op") = distributed::ReduceOp::SUM, @@ -599,8 +598,7 @@ void BindDistributed(py::module *m) { opts.source_rank = source_rank; auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Broadcast(tensors, tensors, opts); + return self.Broadcast(dense.get(), *dense, opts, false); }, py::arg("tensor"), py::arg("source_rank"), @@ -614,8 +612,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Send(tensors, dst); + return self.Send(*dense, dst, false); }, py::arg("tensor"), py::arg("dst"), @@ -629,8 +626,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Recv(tensors, src); + return self.Recv(dense.get(), src, false); }, py::arg("tensor"), py::arg("src"), @@ -647,9 +643,7 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto out_dense = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllGather(in_tensors, out_tensors); + return self.AllGather(out_dense.get(), *in_dense, false); }, py::arg("in"), py::arg("out"), @@ -695,9 +689,14 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto out_dense = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllToAll(in_tensors, out_tensors); + + int world_size = self.GetSize(); + return self.AllToAll( + out_dense.get(), + *in_dense, + GetDefaultSplitSizes(*out_dense, world_size), + GetDefaultSplitSizes(*in_dense, world_size), + false); }, py::arg("in"), py::arg("out"), @@ -741,8 +740,7 @@ void BindDistributed(py::module *m) { opts.root_rank = dst; auto dense = std::dynamic_pointer_cast( in_tensor.impl()); - std::vector tensors = {*dense}; - return self.Reduce(tensors, tensors, opts); + return self.Reduce(dense.get(), *dense, opts, false); }, py::arg("tensor"), py::arg("dst"), @@ -763,9 +761,7 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto out_dense = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.Scatter(in_tensors, out_tensors, opts); + return self.Scatter(out_dense.get(), *in_dense, opts, false); }, py::arg("in"), py::arg("out"), @@ -788,12 +784,11 @@ void BindDistributed(py::module *m) { auto p_in_tensor = std::dynamic_pointer_cast( in_tensor.impl()); auto in_dense = *p_in_tensor; - - auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true); auto task = self.AllGather(out_dense, in_dense, /*sync_op*/ true, /*use_calc_stream*/ true); + auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); return task; }, @@ -900,8 +895,6 @@ void BindDistributed(py::module *m) { auto in_dense = *p_in_tensor; // in_tensor_list should not be empty - auto *dev_ctx = self.GetDeviceContext( - in_tensor_list.back().place(), /*use_calc_stream*/ true); int world_size = self.GetSize(); auto task = self.AllToAll(out_dense, @@ -910,6 +903,8 @@ void BindDistributed(py::module *m) { GetDefaultSplitSizes(in_dense, world_size), /*sync_op*/ true, /*use_calc_stream*/ true); + auto *dev_ctx = self.GetDeviceContext( + in_tensor_list.back().place(), /*use_calc_stream*/ true); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); return task; }, @@ -1239,6 +1234,7 @@ void BindDistributed(py::module *m) { py::arg("rank"), py::arg("world_size"), py::arg("group_id") = 0, + py::arg("timeout") = 30 * 60 * 1000, py::call_guard()) .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart) .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd); diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index a30f01084a060..fb5fd57e26255 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -220,20 +220,46 @@ void InitTensorWithNumpyValue(TensorObject* self, "EmptyTensorInitializer is " "forbidden. Please check your code and make sure you new a " "eager tensor before init it with NumPy.")); + phi::DenseTensor* impl_ptr = static_cast(self->tensor.impl().get()); - if (platform::is_cpu_place(place)) { SetTensorFromPyArray(impl_ptr, array, place, zero_copy); } else if (platform::is_xpu_place(place)) { +#if defined(PADDLE_WITH_XPU) + phi::backends::xpu::SetXPUDeviceId(place.device); + VLOG(4) << "CurrentDeviceId: " + << phi::backends::xpu::GetXPUCurrentDeviceId() << " from " + << static_cast(place.device); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU if use XPUPlace.")); +#endif SetTensorFromPyArray(impl_ptr, array, place, zero_copy); } else if (platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + phi::backends::gpu::SetDeviceId(place.device); + VLOG(4) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() + << " from " << static_cast(place.device); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU if use CUDAPlace.")); +#endif SetTensorFromPyArray( impl_ptr, array, place, zero_copy); } else if (platform::is_cuda_pinned_place(place)) { SetTensorFromPyArray( impl_ptr, array, place, zero_copy); } else if (platform::is_custom_place(place)) { +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + phi::DeviceManager::SetDevice(place); + VLOG(4) << "CurrentDeviceId: " + << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " + << static_cast(place.device); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace.")); +#endif SetTensorFromPyArray( impl_ptr, array, place, zero_copy); } else { @@ -455,7 +481,7 @@ std::string ParseName(std::unordered_map kws_map, } } else { if (flag_kwargs) { - if ((kws_map["name"] == nullptr) || (kws_map["name"] == Py_None)) { + if ((kws_map["name"] == NULL) || (kws_map["name"] == Py_None)) { act_name = egr::Controller::Instance().GenerateUniqueName(unique_name_prefix); } else { diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 85afc274623ea..8552f1e7208b8 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -28,14 +28,26 @@ static PyObject *eager_api_linear(PyObject *self, auto x = GetTensorFromArgs("linear", "X", args, 0, false); auto weight = GetTensorFromArgs("linear", "weight", args, 1, false); auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true); + tstate = PyEval_SaveThread(); + if (bias.initialized()) { + const phi::distributed::ProcessMesh *mesh = nullptr; + if (InputsContainDistTensor(&mesh, x, weight, bias)) { + ConvertAllInputsToDistTensor(mesh, x, weight, bias); + } + auto mm_out = matmul_ad_func(x, weight, false, false); auto out = add_ad_func(mm_out, bias); PyEval_RestoreThread(tstate); tstate = nullptr; return ToPyObject(out); } else { + const phi::distributed::ProcessMesh *mesh = nullptr; + if (InputsContainDistTensor(&mesh, x, weight)) { + ConvertAllInputsToDistTensor(mesh, x, weight); + } + auto mm_out = matmul_ad_func(x, weight, false, false); PyEval_RestoreThread(tstate); tstate = nullptr; diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index df3e62b3bae47..6175e0fcae972 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -62,10 +62,16 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/pybind/cuda_streams_py.h" #endif +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/flags.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/api_gen_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif PHI_DECLARE_string(tensor_operants_mode); @@ -280,7 +286,7 @@ PyObject* eager_api_get_grads_types(PyObject* self, auto& grad = meta->Grad(); if (meta && grad.initialized()) { - if (grad.is_dense_tensor() && + if ((grad.is_dense_tensor() || grad.is_dist_tensor()) && (tensor.dtype() == phi::DataType::FLOAT32 || tensor.dtype() == phi::DataType::FLOAT16 || tensor.dtype() == phi::DataType::BFLOAT16)) { @@ -535,6 +541,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, const auto& attrs = paddle::OpMetaInfoHelper::GetAttrs(vec_map[0]); const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(vec_map[0]); const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[0]); + for (size_t i = 0; i < inputs.size(); ++i) { const auto& input = inputs.at(i); // Parse op_type first, so that use i + 1 @@ -552,17 +559,6 @@ static PyObject* eager_api_run_custom_op(PyObject* self, if (paddle::framework::detail::IsDuplicableVar(input)) { std::vector tensors = std::move(CastPyArg2VectorOfTensor(obj, i + 1)); // NOLINT - for (auto& tensor : tensors) { - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast( - tensor.impl())))))); - } - } ctx.EmplaceBackInputs(std::move(tensors)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add vector size = " @@ -570,19 +566,12 @@ static PyObject* eager_api_run_custom_op(PyObject* self, } else { paddle::Tensor tensor = std::move(CastPyArg2Tensor(obj, i + 1)); // NOLINT - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous(*( - std::dynamic_pointer_cast(tensor.impl())))))); - } ctx.EmplaceBackInput(std::move(tensor)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add Tensor for general case."; } } + // Parse op_type and inputs first, so that use 1 + inputs.size() + i int attr_start_idx = static_cast(1 + inputs.size()); for (size_t i = 0; i < attrs.size(); ++i) { @@ -628,6 +617,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, attr_type_str)); } } + { eager_gil_scoped_release guard; ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); @@ -671,11 +661,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); } - // handle inplace map - ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); VLOG(7) << "Run Kernel of Custom Op: " << op_type; - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); - ctx.AssignInplaceOutputs(); + egr::run_custom_op_impl(vec_map[0], true, false, ctx); // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { @@ -684,7 +671,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ctx.MutableOutputAt(ctx.OutputRangeAt(i).first); if (!out_tensor->initialized()) { PADDLE_ENFORCE( - paddle::framework::detail::IsOptionalVar(outputs.at(i)), + paddle::framework::detail::IsOptionalVar(outputs.at(i)) || + out_tensor->is_dist_tensor(), phi::errors::InvalidArgument( "Custom operator's %d-th output is not initialized. " "Please check your implementation again. If you are " diff --git a/paddle/fluid/pybind/eager_legacy_custom_python_api.h b/paddle/fluid/pybind/eager_legacy_custom_python_api.h index 1c40ce4275c42..682995f9874fc 100644 --- a/paddle/fluid/pybind/eager_legacy_custom_python_api.h +++ b/paddle/fluid/pybind/eager_legacy_custom_python_api.h @@ -31,14 +31,13 @@ static PyObject *eager_api_run_program(PyObject *self, // TOREMOVE auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true); auto OutScope = GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false); - auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true); framework::AttributeMap attrs; // TODO(zengjinle): support CUDA Graph on eager mode ConstructAttrMapFromPyArgs( - "run_program", args, 6, PyTuple_GET_SIZE(args), attrs); + "run_program", args, 5, PyTuple_GET_SIZE(args), attrs); tstate = PyEval_SaveThread(); - run_program_ad_func(X, Params, Out, OutScope, DOut, attrs); + run_program_ad_func(X, Params, Out, OutScope, attrs); PyEval_RestoreThread(tstate); tstate = nullptr; Py_RETURN_NONE; @@ -61,9 +60,9 @@ static PyObject *eager_api_run_program(PyObject *self, // TOREMOVE } } -static PyObject *newir_eager_api_run_program(PyObject *self, - PyObject *args, - PyObject *kwargs) { +static PyObject *pir_eager_api_run_program(PyObject *self, + PyObject *args, + PyObject *kwargs) { PyThreadState *tstate = nullptr; try { auto X = GetTensorListFromArgs("run_program", "X", args, 0, true); @@ -71,17 +70,16 @@ static PyObject *newir_eager_api_run_program(PyObject *self, auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true); auto OutScope = GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false); - auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true); framework::AttributeMap attrs; // TODO(zengjinle): support CUDA Graph on eager mode - VLOG(1) << "Start NewIR ConstructAttrMapFromPyArgs"; + VLOG(1) << "Start Pir ConstructAttrMapFromPyArgs"; ConstructAttrMapForRunProgram( - "run_program", args, 6, PyTuple_GET_SIZE(args), attrs); + "run_program", args, 5, PyTuple_GET_SIZE(args), attrs); - VLOG(1) << "Finish NewIR ConstructAttrMapFromPyArgs"; + VLOG(1) << "Finish Pir ConstructAttrMapFromPyArgs"; tstate = PyEval_SaveThread(); - newir_run_program_ad_func(X, Params, Out, OutScope, DOut, attrs); + pir_run_program_ad_func(X, Params, Out, OutScope, attrs); PyEval_RestoreThread(tstate); tstate = nullptr; Py_RETURN_NONE; @@ -109,8 +107,8 @@ static PyMethodDef CustomEagerMethods[] = { (PyCFunction)(void (*)(void))eager_api_run_program, METH_VARARGS | METH_KEYWORDS, "C++ interface function for run_program in dygraph."}, - {"newir_run_program", - (PyCFunction)(void (*)(void))newir_eager_api_run_program, + {"pir_run_program", + (PyCFunction)(void (*)(void))pir_eager_api_run_program, METH_VARARGS | METH_KEYWORDS, "C++ interface function for run_program in dygraph."}, {nullptr, nullptr, 0, nullptr}}; diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index ecae39fb43a49..21578110323ab 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -579,6 +579,11 @@ static PyObject* tensor__mul__method(TensorObject* self, } } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); + } + // 4. calculation VLOG(6) << "Calling multiply_ad_func in tensor__mul__method"; { diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 199d05d2c9800..8de8d4d64e18e 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -61,8 +61,8 @@ typedef SSIZE_T ssize_t; #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -893,10 +893,17 @@ static PyObject* tensor_clear_gradient(TensorObject* self, selected_rows->mutable_rows()->clear(); selected_rows->mutable_value()->clear(); } - } else if (grad->is_dense_tensor()) { + } else if (grad->is_dense_tensor() || grad->is_dist_tensor()) { if (grad->initialized()) { + phi::DenseTensor* grad_t = nullptr; + if (grad->is_dense_tensor()) { + grad_t = static_cast(grad->impl().get()); + } else { + grad_t = + static_cast(grad->impl().get()) + ->unsafe_mutable_value(); + } if (set_to_zero) { - auto* grad_t = static_cast(grad->impl().get()); auto* dev_ctx = platform::DeviceContextPool::Instance().Get(grad_t->place()); phi::funcs::set_constant(*dev_ctx, grad_t, 0.0); @@ -908,9 +915,7 @@ static PyObject* tensor_clear_gradient(TensorObject* self, } else { VLOG(4) << "Gradient of " << self->tensor.name() << " is initialized, will be released."; - auto dense_tensor = - std::dynamic_pointer_cast(grad->impl()); - dense_tensor->MoveMemoryHolder(); + grad_t->MoveMemoryHolder(); } } } @@ -937,8 +942,14 @@ static PyObject* tensor__zero_grads(TensorObject* self, "Please check if you have manually cleared" "the grad inside autograd_meta")); if (grad->initialized()) { - if (grad->is_dense_tensor()) { - auto* t = static_cast(grad->impl().get()); + if (grad->is_dense_tensor() || grad->is_dist_tensor()) { + phi::DenseTensor* t = nullptr; + if (grad->is_dense_tensor()) { + t = static_cast(grad->impl().get()); + } else { + t = static_cast(grad->impl().get()) + ->unsafe_mutable_value(); + } auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place()); phi::funcs::set_constant(*dev_ctx, t, 0.0); } else { @@ -949,9 +960,16 @@ static PyObject* tensor__zero_grads(TensorObject* self, eager_gil_scoped_release guard; auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor); if (meta->MutableGrad()->initialized()) { - if (meta->MutableGrad()->is_dense_tensor()) { - auto* t = - static_cast(meta->MutableGrad()->impl().get()); + if (meta->MutableGrad()->is_dense_tensor() || + meta->MutableGrad()->is_dist_tensor()) { + phi::DenseTensor* t = nullptr; + if (meta->MutableGrad()->is_dense_tensor()) { + t = static_cast(meta->MutableGrad()->impl().get()); + } else { + t = static_cast( + meta->MutableGrad()->impl().get()) + ->unsafe_mutable_value(); + } auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place()); phi::funcs::set_constant(*dev_ctx, t, 0.0); } else { @@ -978,13 +996,28 @@ static PyObject* tensor__share_buffer_to(TensorObject* self, "Tensor %s has not been initialized! please initialize " "src tensor before share_buffer_with to other.", self->tensor.name())); - auto* src_tensor = static_cast(self->tensor.impl().get()); - if (!dst_ptr->defined()) { - dst_ptr->set_impl(std::make_shared()); + if (self->tensor.is_dist_tensor()) { + auto* src_tensor = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + if (!dst_ptr->defined()) { + dst_ptr->set_impl(std::make_shared()); + } + auto dst_tensor = + static_cast(dst_ptr->impl().get()) + ->unsafe_mutable_value(); + dst_tensor->ShareBufferWith(*src_tensor); + dst_tensor->ShareDataTypeWith(*src_tensor); + } else { + auto* src_tensor = + static_cast(self->tensor.impl().get()); + if (!dst_ptr->defined()) { + dst_ptr->set_impl(std::make_shared()); + } + auto dst_tensor = static_cast(dst_ptr->impl().get()); + dst_tensor->ShareBufferWith(*src_tensor); + dst_tensor->ShareDataTypeWith(*src_tensor); } - auto dst_tensor = static_cast(dst_ptr->impl().get()); - dst_tensor->ShareBufferWith(*src_tensor); - dst_tensor->ShareDataTypeWith(*src_tensor); RETURN_PY_NONE EAGER_CATCH_AND_THROW_RETURN_NULL @@ -1006,10 +1039,21 @@ static PyObject* tensor__is_shared_buffer_with(TensorObject* self, if (!self->tensor.defined() || !dst_ptr->defined()) { return ToPyObject(res); } - auto* self_ptr = static_cast(self->tensor.impl().get()); - auto dst_tensor = static_cast(dst_ptr->impl().get()); - res = dst_tensor->IsSharedBufferWith(*self_ptr); - return ToPyObject(res); + if (self->tensor.is_dist_tensor()) { + auto* self_ptr = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + auto dst_tensor = + static_cast(dst_ptr->impl().get()) + ->unsafe_mutable_value(); + res = dst_tensor->IsSharedBufferWith(*self_ptr); + return ToPyObject(res); + } else { + auto* self_ptr = static_cast(self->tensor.impl().get()); + auto dst_tensor = static_cast(dst_ptr->impl().get()); + res = dst_tensor->IsSharedBufferWith(*self_ptr); + return ToPyObject(res); + } EAGER_CATCH_AND_THROW_RETURN_NULL } @@ -1974,15 +2018,24 @@ static PyObject* tensor__use_gpudnn(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(), - paddle::platform::errors::Fatal( - "function _use_gpudnn is only effective for DenseTensor")); + PADDLE_ENFORCE( + self->tensor.defined() && + (self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor()), + paddle::platform::errors::Fatal("Function _use_gpudnn is only effective " + "for DenseTensor and DistTensor.")); bool use_gpudnn = pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0); // Set the same use_gpudnn attribute, return directly - phi::DenseTensor* dense_tensor = - static_cast(self->tensor.impl().get()); + phi::DenseTensor* dense_tensor = nullptr; + if (self->tensor.is_dist_tensor()) { + dense_tensor = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + } else { + dense_tensor = static_cast(self->tensor.impl().get()); + } + phi::DenseTensorMeta* dense_tensor_meta = phi::DenseTensorUtils::GetMutableMeta(dense_tensor); if (use_gpudnn == dense_tensor_meta->use_gpudnn) { @@ -1996,10 +2049,20 @@ static PyObject* tensor__use_gpudnn(TensorObject* self, target_dense_tensor.ShareDataWith(*dense_tensor); target_dense_tensor.set_meta(target_dense_meta); // Construct returned tensor - paddle::Tensor target_tensor( - std::make_shared(target_dense_tensor), - self->tensor.name()); + paddle::Tensor target_tensor(self->tensor.name()); target_tensor.set_autograd_meta(self->tensor.mutable_autograd_meta()); + if (self->tensor.is_dist_tensor()) { + auto dist_tensor = + static_cast(self->tensor.impl().get()); + auto target_dist_tensor = std::make_shared( + dist_tensor->dims(), dist_tensor->dist_attr()); + *(target_dist_tensor->unsafe_mutable_value()) = target_dense_tensor; + target_tensor.set_impl(target_dist_tensor); + } else { + target_tensor.set_impl( + std::make_shared(target_dense_tensor)); + } + VLOG(4) << "Tensor: " << target_tensor.name() << " set use_gpudnn = " << use_gpudnn; @@ -2652,8 +2715,8 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self, } paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor); - if (grad && grad->defined() && grad->is_dense_tensor() && - grad->initialized()) { + if (grad && grad->defined() && grad->initialized() && + (grad->is_dense_tensor() || grad->is_dist_tensor())) { grad->reset_inplace_version(set_to_zero); } RETURN_PY_NONE @@ -2704,14 +2767,21 @@ static PyObject* tensor__offset(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - auto t = std::dynamic_pointer_cast(self->tensor.impl()); + phi::DenseTensor* dense_tensor = nullptr; + if (self->tensor.is_dist_tensor()) { + dense_tensor = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + } else { + dense_tensor = static_cast(self->tensor.impl().get()); + } PADDLE_ENFORCE_EQ( - t->IsInitialized(), + dense_tensor->IsInitialized(), true, platform::errors::InvalidArgument("Tensor %s has not been initialized!", self->tensor.name())); - return ToPyObject(t->offset()); + return ToPyObject(dense_tensor->offset()); EAGER_CATCH_AND_THROW_RETURN_NULL } @@ -2748,9 +2818,14 @@ static PyObject* tensor__grad_value(TensorObject* self, if (grad->is_dense_tensor()) { auto* grad_tensor = static_cast(grad->impl().get()); return ToPyObject(grad_tensor); + } else if (grad->is_dist_tensor()) { + auto* grad_tensor = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + return ToPyObject(grad_tensor); } else { PADDLE_THROW(paddle::platform::errors::Fatal( - "this method is only supported for DenseTensor")); + "This method is only supported for DenseTensor and DistTensor.")); RETURN_PY_NONE } EAGER_CATCH_AND_THROW_RETURN_NULL @@ -2833,7 +2908,15 @@ static PyObject* tensor_data_ptr(TensorObject* self, (int64_t)std::dynamic_pointer_cast( // NOLINT self->tensor.impl()) ->data()); + } else if (self->tensor.initialized() && self->tensor.is_dist_tensor()) { + return ToPyObject( + (int64_t) + std::dynamic_pointer_cast( // NOLINT + self->tensor.impl()) + ->unsafe_mutable_value() + ->data()); } + RETURN_PY_NONE EAGER_CATCH_AND_THROW_RETURN_NULL } @@ -2879,7 +2962,8 @@ static PyObject* tensor_method_strides(TensorObject* self, PyObject* kwargs) { EAGER_TRY std::vector value; - if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) { + if (!self->tensor.defined() || + (!self->tensor.is_dense_tensor() && !self->tensor.is_dist_tensor())) { return ToPyObject(value); } auto stride = self->tensor.strides(); @@ -2919,20 +3003,24 @@ static PyObject* tensor_contiguous(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - if (self->tensor.is_dense_tensor()) { - auto dense_tensor = - std::dynamic_pointer_cast(self->tensor.impl()); + if (self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor()) { + phi::DenseTensor* dense_tensor = nullptr; + if (self->tensor.is_dist_tensor()) { + dense_tensor = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + } else { + dense_tensor = static_cast(self->tensor.impl().get()); + } if (dense_tensor->meta().is_contiguous()) { Py_INCREF(self); return reinterpret_cast(self); } else { eager_gil_scoped_release guard; - self->tensor.set_impl(std::make_shared(std::move( - paddle::experimental::Trans2Contiguous(*(dense_tensor.get()))))); + *dense_tensor = paddle::experimental::Trans2Contiguous(*dense_tensor); Py_INCREF(self); return reinterpret_cast(self); } - } else { Py_INCREF(self); return reinterpret_cast(self); @@ -2967,6 +3055,11 @@ static PyObject* tensor_is_contiguous(TensorObject* self, auto dense_tensor = std::dynamic_pointer_cast(self->tensor.impl()); return ToPyObject(dense_tensor->meta().is_contiguous()); + } else if (self->tensor.is_dist_tensor()) { + auto dense_tensor = std::dynamic_pointer_cast( + self->tensor.impl()) + ->unsafe_mutable_value(); + return ToPyObject(dense_tensor->meta().is_contiguous()); } else { return ToPyObject(true); } @@ -2991,19 +3084,27 @@ static PyObject* tensor_method__uva(TensorObject* self, PyObject* kwargs) { EAGER_TRY VLOG(4) << "Running in tensor_method__uva."; - PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(), - true, - platform::errors::InvalidArgument( - "Unified virtual addressing only support " - "DenseTensor currently.")); + PADDLE_ENFORCE_EQ( + self->tensor.is_dense_tensor() || self->tensor.is_dist_tensor(), + true, + platform::errors::InvalidArgument( + "Unified virtual addressing only support " + "DenseTensor and DistTensor currently.")); PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()), true, platform::errors::InvalidArgument( "Unified virtual addressing only support " "CPU Tensor currently.")); int device_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0); - auto* self_tensor = static_cast(self->tensor.impl().get()); - tensor_uva(self_tensor, device_id); + phi::DenseTensor* dense_tensor = nullptr; + if (self->tensor.is_dist_tensor()) { + dense_tensor = + static_cast(self->tensor.impl().get()) + ->unsafe_mutable_value(); + } else { + dense_tensor = static_cast(self->tensor.impl().get()); + } + tensor_uva(dense_tensor, device_id); RETURN_PY_NONE diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 46170298ce42b..298e191f062a5 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -50,6 +50,7 @@ extern PyTypeObject* p_string_tensor_type; extern PyTypeObject* g_framework_scope_pytype; extern PyTypeObject* g_ir_opresult_pytype; +extern PyTypeObject* g_ir_value_pytype; extern PyTypeObject* g_vartype_pytype; extern PyTypeObject* g_data_type_pytype; extern PyTypeObject* g_place_pytype; @@ -1521,6 +1522,8 @@ pir::Value CastPyArg2Value(PyObject* obj, size_t arg_pos) { if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) { return ::pybind11::handle(obj).cast(); + } else if (PyObject_TypeCheck(obj, g_ir_value_pytype)) { + return ::pybind11::handle(obj).cast(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -1583,8 +1586,6 @@ std::vector CastPyArg2VectorOfValue(PyObject* obj, ->tp_name)); // NOLINT } } - } else if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) { - return {::pybind11::handle(obj).cast()}; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " diff --git a/paddle/fluid/pybind/eval_frame.c b/paddle/fluid/pybind/eval_frame.c index 5b4f216be24dc..492159fd1ae05 100644 --- a/paddle/fluid/pybind/eval_frame.c +++ b/paddle/fluid/pybind/eval_frame.c @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/pybind/eval_frame.h" +#include "paddle/fluid/pybind/eval_frame_tools.h" #include #include @@ -458,6 +459,7 @@ inline static PyObject *eval_custom_code_py311_plus(PyThreadState *tstate, // Create a new function object from code object. Refer to MAKE_FUNCTION. PyFunctionObject *func = (PyFunctionObject *)PyFunction_New((PyObject *)code, frame->f_globals); + Py_INCREF(func); #if PY_VERSION_HEX < 0x030c0000 Py_XINCREF(frame->f_func->func_closure); func->func_closure = frame->f_func->func_closure; @@ -559,10 +561,19 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, FrameObject *frame, int throw_flag, PyObject *callback) { + PyObject *out; + eval_frame_callback_set(Py_None); + // https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details // https://devguide.python.org/internals/interpreter/#all-sorts-of-variables #if PY_VERSION_HEX >= 0x030b0000 if (frame->owner == FRAME_OWNED_BY_GENERATOR) { + out = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); + return out; + } + if (PyBytes_GET_SIZE(frame->f_code->co_exceptiontable)) { + eval_frame_callback_set(callback); return eval_frame_default(tstate, frame, throw_flag); } // PyFrame_FastToLocalsWithError receives a PyFrameObject, but if we created a @@ -577,6 +588,11 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, if (Internal_PyFrame_FastToLocalsWithError(frame) < 0) { #endif #else + if (frame->f_code->co_flags & 0x20) { + out = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); + return out; + } if (PyFrame_FastToLocalsWithError(frame) < 0) { #endif return NULL; @@ -593,58 +609,70 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, // # <--- which Cause the PyObject_CallObject raise // SystemError. if (PyErr_ExceptionMatches(PyExc_GeneratorExit)) { - return eval_frame_default(tstate, frame, throw_flag); + out = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); + return out; } - // We don't run the current custom_eval_frame behavior for guards. - // So we temporarily set the callback to Py_None to drive the correct behavior - // in the shim. - eval_frame_callback_set(Py_None); + PyObject *code; + PyObject *disable_eval_frame; + // get code & disable_eval_frame + if (need_skip(frame)) { + Py_INCREF(Py_None); + code = Py_None; + Py_INCREF(Py_False); + disable_eval_frame = Py_False; + } else { + /* should calculate guards here if we want */ #if PY_VERSION_HEX >= 0x030b0000 - PyObject *args = Py_BuildValue("(O)", PyInterpreterFrameProxy_New(frame)); + PyObject *args = Py_BuildValue("(O)", PyInterpreterFrameProxy_New(frame)); #else - PyObject *args = Py_BuildValue("(O)", frame); + PyObject *args = Py_BuildValue("(O)", frame); #endif - PyObject *result = PyObject_CallObject(callback, args); - Py_DECREF(args); - // VLOG(7) << "After call eval_frame_function and decrease frame."; - // class CustomCode(Protocal): - // code: CodeType | None - // disable_eval_frame: bool - // result: CustomCode - if (result == NULL) { - // internal exception - // VLOG(7) << "Error happened."; - return NULL; - } else { - // NOTE: Cache is not supported now - PyCodeObject *code = (PyCodeObject *)PyObject_GetAttrString(result, "code"); - PyObject *disable_eval_frame = - PyObject_GetAttrString(result, "disable_eval_frame"); - PyObject *out = NULL; - // VLOG(7) << "Start eval new frame and code."; - if (disable_eval_frame != Py_True) { - // Re-enable custom behavior - eval_frame_callback_set(callback); - if ((PyObject *)code != Py_None) { - out = eval_custom_code(tstate, frame, code, throw_flag); - } else { - out = eval_frame_default(tstate, frame, throw_flag); - } - } else { - if ((PyObject *)code != Py_None) { - out = eval_custom_code(tstate, frame, code, throw_flag); - } else { - out = eval_frame_default(tstate, frame, throw_flag); - } - // Re-enable custom behavior - eval_frame_callback_set(callback); + PyObject *result = PyObject_CallObject(callback, args); + Py_DECREF(args); + if (result == NULL) { + return NULL; } + code = PyObject_GetAttrString(result, "code"); + disable_eval_frame = PyObject_GetAttrString(result, "disable_eval_frame"); Py_DECREF(result); + } + + // code status + if (is_code_without_graph(code == Py_None ? frame->f_code + : (PyCodeObject *)code) && + disable_eval_frame == Py_False) { + out = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); Py_DECREF(code); + Py_DECREF(disable_eval_frame); return out; } + + // run code + if (disable_eval_frame != Py_True) { + // Re-enable custom behavior + eval_frame_callback_set(callback); + if (code != Py_None) { + out = eval_custom_code(tstate, frame, (PyCodeObject *)code, throw_flag); + } else { + out = eval_frame_default(tstate, frame, throw_flag); + } + } else { + if (code != Py_None) { + out = eval_custom_code(tstate, frame, (PyCodeObject *)code, throw_flag); + } else { + out = eval_frame_default(tstate, frame, throw_flag); + } + // Re-enable custom behavior + eval_frame_callback_set(callback); + } + + Py_DECREF(code); + Py_DECREF(disable_eval_frame); + return out; } static PyObject *_custom_eval_frame_shim(PyThreadState *tstate, diff --git a/paddle/fluid/pybind/eval_frame_tools.cc b/paddle/fluid/pybind/eval_frame_tools.cc new file mode 100644 index 0000000000000..3b8df99eb2a3f --- /dev/null +++ b/paddle/fluid/pybind/eval_frame_tools.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/eval_frame_tools.h" + +#include + +#include + +#include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +/*============================ Dict Tree ================================*/ + +class TreeNode { + public: + TreeNode() = default; + ~TreeNode() { clear(); } + void clear(); + int add_prefix(const char* filename); + int check_filename(const char* filename); + + private: + int is_prefix; + TreeNode* children[256]; +}; + +void TreeNode::clear() { + for (int i = 0; i < 256; i++) { + if (children[i] != NULL) delete children[i]; + } +} + +int TreeNode::add_prefix(const char* filepath) { + if (is_prefix) return 0; + if (filepath[0] == '\0') return 1; + + int ch = (int)filepath[0]; // NOLINT + if (children[ch] == NULL) { + TreeNode* node = new TreeNode(); + children[ch] = node; + } + + if (children[ch]->add_prefix(filepath + 1)) is_prefix = 1; + + return 0; +} + +int TreeNode::check_filename(const char* filename) { + int cur_idx = 0; + TreeNode* cur_node = this; + + while (filename[cur_idx] != '\0') { + cur_node = cur_node->children[(int)filename[cur_idx]]; // NOLINT + if (cur_node == NULL) return 0; + if (cur_node->is_prefix) return 1; + cur_idx += 1; + } + + return 0; +} + +/*========================== utils ==========================*/ + +const char* pystr_to_cstr(PyObject* pystr) { + if (PyUnicode_Check(pystr)) + return PyUnicode_AsUTF8(pystr); + else + PADDLE_THROW(phi::errors::InvalidArgument("Input PyObject is not string!")); +} + +/*========================== SkipCodeInfo ===============================*/ + +class SkipCodeInfo { + public: + static SkipCodeInfo& Instance(); + void clear_code_info(); + + void add_no_skip_code(PyCodeObject* code); + void add_skip_file_prefix(PyObject* filepath); + + int is_no_skip_code(PyCodeObject* code); + int in_skip_path(PyObject* filename); + + private: + SkipCodeInfo() { + no_skip_codes = std::unordered_set(); + skip_codes = std::unordered_set(); + root = new TreeNode(); + } + ~SkipCodeInfo() { clear_code_info(); } + std::unordered_set no_skip_codes; + std::unordered_set skip_codes; + TreeNode* root; +}; + +SkipCodeInfo& SkipCodeInfo::Instance() { + static SkipCodeInfo _instance; + return _instance; +} + +void SkipCodeInfo::clear_code_info() { + no_skip_codes.clear(); + skip_codes.clear(); + root->clear(); +} + +void SkipCodeInfo::add_no_skip_code(PyCodeObject* code) { + no_skip_codes.insert(code); +} + +void SkipCodeInfo::add_skip_file_prefix(PyObject* filepath) { + const char* path = pystr_to_cstr(filepath); + root->add_prefix(path); +} + +int SkipCodeInfo::is_no_skip_code(PyCodeObject* code) { + return no_skip_codes.find(code) != no_skip_codes.end(); +} + +int SkipCodeInfo::in_skip_path(PyObject* filename) { + const char* name = pystr_to_cstr(filename); + return root->check_filename(name); +} + +/*========================== code status ==============================*/ +enum CodeState { UNKNOW, WITH_GRAPH, WITHOUT_GRAPH }; + +class CodeInfo { + public: + CodeState state; + int counter; +}; + +class CodeStatus { + public: + static CodeStatus& Instance(); + int is_code_without_graph(PyCodeObject* code); + void set_with_graph(PyCodeObject* code); + void add_with_graph_code(PyCodeObject* code); + void clear(); + + private: + CodeStatus() { code_map = std::unordered_map(); } + ~CodeStatus() { clear(); } + std::unordered_map code_map; +}; + +CodeStatus& CodeStatus::Instance() { + static CodeStatus _instance; + return _instance; +} + +int CodeStatus::is_code_without_graph(PyCodeObject* code) { + CodeInfo* code_info; + if (code_map.find(code) != code_map.end()) { + code_info = code_map[code]; + } else { + code_info = new CodeInfo(); + code_map.emplace(code, code_info); + } + if (code_info->state == WITHOUT_GRAPH) return 1; + if (code_info->state == UNKNOW) { + code_info->counter += 1; + if (code_info->counter >= 10) code_info->state = WITHOUT_GRAPH; + } + return 0; +} + +void CodeStatus::set_with_graph(PyCodeObject* code) { + CodeInfo* code_info; + if (code_map.find(code) != code_map.end()) { + code_info = code_map[code]; + code_info->state = WITH_GRAPH; + } +} + +void CodeStatus::add_with_graph_code(PyCodeObject* code) { + CodeInfo* code_info; + if (code_map.find(code) != code_map.end()) { + code_info = code_map[code]; + code_info->state = WITH_GRAPH; + } else { + code_info = new CodeInfo(); + code_info->state = WITH_GRAPH; + code_map.emplace(code, code_info); + } +} + +void CodeStatus::clear() { + for (auto iter = code_map.begin(); iter != code_map.end(); iter++) { + delete iter->second; + } + code_map.clear(); +} + +/*========================== interfaces ===============================*/ + +int need_skip(FrameObject* frame) { + auto& skip_info = SkipCodeInfo::Instance(); + PyCodeObject* code = frame->f_code; // NOLINT + PyObject* co_filename = code->co_filename; + + if (skip_info.is_no_skip_code(code)) { + return 0; + } + +#if PY_VERSION_HEX >= 0x030b0000 + const char* filename = pystr_to_cstr(co_filename); + PyObject* _filename = NULL; + if (memcmp(filename, "f_globals; + _filename = PyDict_GetItemString(f_globals, "__file__"); + if (_filename != NULL) { + Py_INCREF(_filename); + co_filename = _filename; + } + } +#endif + + int result = skip_info.in_skip_path(co_filename); + +#if PY_VERSION_HEX >= 0x030b0000 + if (_filename != NULL) Py_DECREF(_filename); +#endif + return result; +} + +int is_code_without_graph(PyCodeObject* code) { + auto& code_status = CodeStatus::Instance(); + return code_status.is_code_without_graph(code); +} + +/*========================== pybind ===============================*/ +PyObject* set_with_graph(PyObject* code) { + auto& code_status = CodeStatus::Instance(); + code_status.set_with_graph((PyCodeObject*)code); // NOLINT + return Py_None; +} + +PyObject* setup_codes_with_graph(PyObject* code_tuple) { + auto& code_status = CodeStatus::Instance(); + Py_ssize_t size = PyTuple_GET_SIZE(code_tuple); + for (Py_ssize_t i = 0; i < size; i++) { + PyCodeObject* code = + (PyCodeObject*)PyTuple_GetItem(code_tuple, i); // NOLINT + code_status.add_with_graph_code(code); + } + return Py_None; +} + +PyObject* no_skip_codes(PyObject* code_tuple) { + auto& skip_info = SkipCodeInfo::Instance(); + Py_ssize_t size = PyTuple_GET_SIZE(code_tuple); + for (Py_ssize_t i = 0; i < size; i++) { + PyCodeObject* code = + (PyCodeObject*)PyTuple_GetItem(code_tuple, i); // NOLINT + skip_info.add_no_skip_code(code); + } + return Py_None; +} + +PyObject* skip_file_prefix(PyObject* filepath_tuple) { + auto& skip_info = SkipCodeInfo::Instance(); + Py_ssize_t size = PyTuple_GET_SIZE(filepath_tuple); + for (Py_ssize_t i = 0; i < size; i++) { + PyObject* code = PyTuple_GetItem(filepath_tuple, i); + skip_info.add_skip_file_prefix(code); + } + return Py_None; +} diff --git a/paddle/fluid/pybind/eval_frame_tools.h b/paddle/fluid/pybind/eval_frame_tools.h new file mode 100644 index 0000000000000..cfcb5940dcfb1 --- /dev/null +++ b/paddle/fluid/pybind/eval_frame_tools.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#if PY_VERSION_HEX >= 0x030b0000 +#include +typedef _PyInterpreterFrame FrameObject; +#else +typedef PyFrameObject FrameObject; +#endif + +int need_skip(FrameObject* frame); +int is_code_without_graph(PyCodeObject* code); + +PyObject* set_with_graph(PyObject* code); +PyObject* setup_codes_with_graph(PyObject* code_tuple); +PyObject* no_skip_codes(PyObject* code_tuple); +PyObject* skip_file_prefix(PyObject* filepath_tuple); + +#ifdef __cplusplus +} +#endif diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 019b5098feb75..39c22f9301457 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -855,6 +855,9 @@ void BindAnalysisConfig(py::module *m) { .def("enable_memory_optim", &AnalysisConfig::EnableMemoryOptim, py::arg("x") = true) + .def("enable_new_executor", + &AnalysisConfig::EnableNewExecutor, + py::arg("x") = true) .def("enable_profile", &AnalysisConfig::EnableProfile) .def("disable_glog_info", &AnalysisConfig::DisableGlogInfo) .def("glog_info_disabled", &AnalysisConfig::glog_info_disabled) diff --git a/paddle/fluid/pybind/jit.cc b/paddle/fluid/pybind/jit.cc index 09e194bf0b7c8..15b73fda53002 100644 --- a/paddle/fluid/pybind/jit.cc +++ b/paddle/fluid/pybind/jit.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/jit/serializer.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/pybind/eval_frame.h" +#include "paddle/fluid/pybind/eval_frame_tools.h" #include "paddle/utils/pybind.h" namespace py = pybind11; @@ -70,6 +71,42 @@ void BindEvalFrame(pybind11::module *m) { return obj; }, py::arg("callback")); + + m->def( + "sot_setup_codes_with_graph", + [](const py::object &py_codes) { + auto ret = setup_codes_with_graph(py_codes.ptr()); + auto obj = py::reinterpret_borrow(ret); + return obj; + }, + py::arg("py_codes")); + + m->def( + "sot_set_with_graph", + [](const py::object &py_codes) { + auto ret = set_with_graph(py_codes.ptr()); + auto obj = py::reinterpret_borrow(ret); + return obj; + }, + py::arg("py_codes")); + + m->def( + "eval_frame_no_skip_codes", + [](const py::object &py_codes) { + auto ret = no_skip_codes(py_codes.ptr()); + auto obj = py::reinterpret_borrow(ret); + return obj; + }, + py::arg("py_codes")); + + m->def( + "eval_frame_skip_file_prefix", + [](const py::object &py_codes) { + auto ret = skip_file_prefix(py_codes.ptr()); + auto obj = py::reinterpret_borrow(ret); + return obj; + }, + py::arg("py_codes")); } } // namespace pybind diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 3e50bd64ca4ac..3a5716877a59d 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/pybind/pybind_variant_caster.h" #include "paddle/pir/core/builtin_op.h" @@ -36,8 +37,13 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" +#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/attribute.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/program.h" @@ -46,14 +52,20 @@ #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/transforms/dead_code_elimination_pass.h" #include "paddle/utils/flags.h" #include "pybind11/stl.h" +#ifdef PADDLE_WITH_CINN +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/fluid/pir/transforms/build_cinn_pass.h" +#endif + namespace py = pybind11; using paddle::dialect::APIBuilder; using paddle::dialect::DenseTensorType; using paddle::dialect::SelectedRowsType; +using pir::Attribute; using pir::Block; using pir::Operation; using pir::OpOperand; @@ -65,8 +77,13 @@ using pir::Type; using pir::Value; using pybind11::return_value_policy; -USE_PASS(dead_code_elimination); -USE_PASS(inplace); +USE_PIR_PASS(dead_code_elimination_pass); +USE_PIR_PASS(attention_fuse_pass); +USE_PIR_PASS(fused_gemm_epilogue_pass); +USE_PIR_PASS(fused_dropout_add_pass); +USE_PIR_PASS(fused_linear_param_grad_add_pass); +USE_PIR_PASS(inplace_pass); +USE_PIR_PASS(replace_fetch_with_shadow_output_pass); PHI_DECLARE_bool(print_ir); @@ -74,6 +91,7 @@ namespace paddle { namespace pybind { PyTypeObject *g_ir_opresult_pytype = nullptr; +PyTypeObject *g_ir_value_pytype = nullptr; void BindOpsAPI(pybind11::module *module); @@ -112,7 +130,8 @@ std::string GetValueInfo(Value v) { } void BindProgram(py::module *m) { - py::class_> program(*m, "Program", R"DOC( + py::class_> program( + *m, "Program", py::dynamic_attr(), R"DOC( Create Python Program. Program is an abstraction of model structure, divided into computational graphs and weights. The Program has a main block that stores the computational graphs. @@ -137,20 +156,41 @@ void BindProgram(py::module *m) { Examples: .. code-block:: python - import paddle - import paddle.static as static - - paddle.enable_static() + >>> import paddle + >>> import paddle.static as static + + >>> paddle.enable_static() + + >>> main_program = static.Program() + >>> startup_program = static.Program() + >>> with static.program_guard(main_program=main_program, startup_program=startup_program): + ... x = static.data(name="x", shape=[-1, 784], dtype='float32') + ... y = static.data(name="y", shape=[-1, 1], dtype='int32') + ... z = static.nn.fc(name="fc", x=x, size=10, activation="relu") + + >>> print("main program is: {}".format(main_program)) + main program is: { // block 0 + var x : LOD_TENSOR.shape(-1, 784).dtype(float32).stop_gradient(True) + var y : LOD_TENSOR.shape(-1, 1).dtype(int32).stop_gradient(True) + persist trainable param fc.w_0 : LOD_TENSOR.shape(784, 10).dtype(float32).stop_gradient(False) + var fc.tmp_0 : LOD_TENSOR.shape(-1, 10).dtype(float32).stop_gradient(False) + persist trainable param fc.b_0 : LOD_TENSOR.shape(10,).dtype(float32).stop_gradient(False) + var fc.tmp_1 : LOD_TENSOR.shape(-1, 10).dtype(float32).stop_gradient(False) + var fc.tmp_2 : LOD_TENSOR.shape(-1, 10).dtype(float32).stop_gradient(False) + + {Out=['fc.tmp_0']} = mul(inputs={X=['x'], Y=['fc.w_0']}, force_fp32_output = False, op_device = , op_namescope = /, op_role = 0, op_role_var = [], scale_out = 1.0, scale_x = 1.0, scale_y = [1.0], use_mkldnn = False, with_quant_attr = False, x_num_col_dims = 1, y_num_col_dims = 1) + {Out=['fc.tmp_1']} = elementwise_add(inputs={X=['fc.tmp_0'], Y=['fc.b_0']}, Scale_out = 1.0, Scale_x = 1.0, Scale_y = 1.0, axis = 1, mkldnn_data_type = float32, op_device = , op_namescope = /, op_role = 0, op_role_var = [], use_mkldnn = False, use_quantizer = False, with_quant_attr = False, x_data_format = , y_data_format = ) + {Out=['fc.tmp_2']} = relu(inputs={X=['fc.tmp_1']}, op_device = , op_namescope = /, op_role = 0, op_role_var = [], use_cudnn = False, use_mkldnn = False, with_quant_attr = False) + } - main_program = static.Program() - startup_program = static.Program() - with static.program_guard(main_program=main_program, startup_program=startup_program): - x = static.data(name="x", shape=[-1, 784], dtype='float32') - y = static.data(name="y", shape=[-1, 1], dtype='int32') - z = static.nn.fc(name="fc", x=x, size=10, activation="relu") + >>> print("start up program is: {}".format(startup_program)) + start up program is: { // block 0 + persist trainable param fc.w_0 : LOD_TENSOR.shape(784, 10).dtype(float32).stop_gradient(False) + persist trainable param fc.b_0 : LOD_TENSOR.shape(10,).dtype(float32).stop_gradient(False) - print("main program is: {}".format(main_program)) - print("start up program is: {}".format(startup_program)) + {Out=['fc.w_0']} = uniform_random(inputs={ShapeTensor=[], ShapeTensorList=[]}, diag_num = 0, diag_step = 0, diag_val = 1.0, dtype = 5, max = 0.08692913502454758, min = -0.08692913502454758, op_device = , op_namescope = /, op_role = 0, op_role_var = [], seed = 0, shape = [784, 10], with_quant_attr = False) + {Out=['fc.b_0']} = fill_constant(inputs={}, dtype = 5, force_cpu = False, op_device = , op_namescope = /, op_role = 0, op_role_var = [], place_type = -1, shape = [10], str_value = 0.0, use_mkldnn = False, value = 0.0, with_quant_attr = False) + } )DOC"); program .def("__init__", @@ -234,6 +274,24 @@ void BindBlock(py::module *m) { None )DOC") + .def( + "move_op", + [](Block &self, Operation *op, uint32_t offset) { + Block::Iterator position = self.begin(); + std::advance(position, offset); + op->MoveTo(&self, position); + }, + R"DOC( + Move an op to a specific position (block.begin() + offset). + + Args: + op (pir.Operation): the operator to be moved. + offset (uint32_t) : offset relative to the begin of the block + + Returns: + None + + )DOC") .def("all_parameters", [](Block &self) -> py::list { py::list param_list; for (auto iter = self.begin(); iter != self.end(); iter++) { @@ -327,6 +385,17 @@ void BindOperation(py::module *m) { } return op_list; }) + .def("get_output_intermediate_status", + [](Operation &self) -> py::list { + py::list op_list; + paddle::dialect::OpYamlInfoInterface yaml_interface = + self.dyn_cast(); + auto outputs_info = std::get<2>(yaml_interface.GetOpInfo()); + for (auto &output_info : outputs_info) { + op_list.append(output_info.intermediate); + } + return op_list; + }) .def("get_input_grad_semantics", [](Operation &self) -> py::list { py::list op_list; @@ -352,6 +421,30 @@ py::str Value2String(const Value &self) { return print_stream.str(); } +phi::DataType GetValueDtype(const Value &value) { + if (value.type().isa()) { + return paddle::dialect::TransToPhiDataType( + value.type().dyn_cast().dtype()); + } else if (value.type().isa()) { + return paddle::dialect::TransToPhiDataType( + value.type().dyn_cast().dtype()); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, we can only get phi::DataType from DenseTensorType and " + "SelectedRowsType.")); + } +} + +phi::DDim GetValueDims(const Value &value) { + if (value.type().isa()) { + return value.type().dyn_cast().dims(); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, we can only get shape for dense " + "tensor.")); + } +} + void BindValue(py::module *m) { py::class_ value(*m, "Value", R"DOC( Value class represents the SSA value in the IR system. It is a directed edge @@ -362,6 +455,7 @@ void BindValue(py::module *m) { when build network. )DOC"); + g_ir_value_pytype = reinterpret_cast(value.ptr()); value .def( "get_defining_op", @@ -387,7 +481,21 @@ void BindValue(py::module *m) { .def("__hash__", [](const Value &self) { return std::hash{}(self); }) .def("__str__", &Value2String) - .def("__repr__", &Value2String); + .def("__repr__", &Value2String) + .def_property( + "shape", + [](Value &self) { return phi::vectorize(GetValueDims(self)); }, + [](Value &self, const std::vector &shape) { + PADDLE_THROW(phi::errors::InvalidArgument( + "can't set shape when building static graph")); + }) + .def_property( + "dtype", + [](Value &self) { return GetValueDtype(self); }, + [](Value &self, phi::DataType dtype) { + PADDLE_THROW(phi::errors::InvalidArgument( + "can't set dtype when building static graph")); + }); } void BindOpOperand(py::module *m) { @@ -698,6 +806,15 @@ void BindType(py::module *m) { }); } +void BindAttribute(py::module *m) { + py::class_ ir_attr(*m, "Attribute", py::module_local()); + ir_attr.def("__str__", [](Attribute &self) { + std::ostringstream print_stream; + print_stream << self; + return print_stream.str(); + }); +} + Operation *BuildOpFrom( Operation *to_copy_op, std::unordered_map &value_map) { // NOLINT @@ -739,20 +856,6 @@ Operation *BuildOpFrom( return cloned_op; } -std::shared_ptr ProgramClone(const Program &program) { - // Limitation of this function: - // 1. don't support Parameters. - // 2. don't support Regions in operator. - pir::IrContext *ctx = pir::IrContext::Instance(); - auto cloned_program = std::make_shared(ctx); - std::unordered_map value_map; - for (auto &op : *program.block()) { - auto *cloned_op = BuildOpFrom(op, value_map); - cloned_program->block()->push_back(cloned_op); - } - return cloned_program; -} - std::list::const_iterator list_offset(const Block *block, int start_idx) { auto it = block->begin(); @@ -868,7 +971,31 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block, no_need_buffer_values.end()); } -SplitedResult ForwardBackwardSplit( +using OpResultMap = std::unordered_map; +std::pair, OpResultMap> CloneProgram( + const Program &program, + const std::vector &op_result_forward_inputs, + const std::vector &op_result_forward_params, + const std::vector &op_result_forward_outputs) { + // Limitation of this function: + // 1. don't support Parameters. + // 2. don't support Regions in operator. + pir::IrContext *ctx = pir::IrContext::Instance(); + auto cloned_program = std::make_shared(ctx); + std::unordered_map value_map; + for (auto &op : *program.block()) { + auto *cloned_op = BuildOpFrom(op, value_map); + cloned_program->block()->push_back(cloned_op); + } + std::unordered_map op_result_map; + for (auto &pair : value_map) { + op_result_map[pair.first.dyn_cast()] = + pair.second.dyn_cast(); + } + return std::make_pair(cloned_program, op_result_map); +} + +SplitedResult SplitForwardBackward( const Program &program, const std::vector &op_result_forward_inputs, const std::vector &op_result_forward_params, @@ -973,11 +1100,28 @@ SplitedResult ForwardBackwardSplit( if (v.impl() == nullptr) { return; } + // NOTE(Aurelius84): we should skip insert SetParameterOp repeatly by + // calling SplitForwardBackward multi-times. + std::string parameter_name = + std::string("output_") + std::to_string(counter); + for (auto it = forward_program->block()->rbegin(); + it != forward_program->block()->rend(); + ++it) { + auto *op = *it; + if (op->isa()) { + auto out_name = + op->attribute("parameter_name").AsString(); + if (out_name == parameter_name) { + VLOG(4) << out_name + << " has been inserted SetParameterOp, skip it now."; + return; + } + } + } + auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); pir::AttributeMap attribute_map = { - {"parameter_name", - pir::StrAttribute::get( - ctx, std::string("output_") + std::to_string(counter))}, + {"parameter_name", pir::StrAttribute::get(ctx, parameter_name)}, }; pir::Operation *operation = pir::Operation::Create( {forward_value_map[v]}, attribute_map, {}, op_info); @@ -1105,8 +1249,8 @@ SplitedResult ForwardBackwardSplit( } void BindUtils(pybind11::module *m) { - m->def("program_clone", ProgramClone); - m->def("program_split", ForwardBackwardSplit); + m->def("clone_program", CloneProgram); + m->def("split_program", SplitForwardBackward); m->def("fake_op_result", FakeOpResult); m->def("is_fake_op_result", IsFakeOpResult); m->def("set_global_program", @@ -1122,7 +1266,7 @@ void BindUtils(pybind11::module *m) { ->GetOrRegisterDialect(); }); m->def( - "translate_to_new_ir", + "translate_to_pir", [](const ::paddle::framework::ProgramDesc &legacy_program) { std::shared_ptr ret = std::move(paddle::TranslateLegacyProgramToProgram(legacy_program)); @@ -1144,24 +1288,37 @@ void BindUtils(pybind11::module *m) { Examples: .. code-block:: python - import paddle - from paddle import pir - paddle.enable_static() - - x = paddle.randn([4, 4]) - main_program, start_program = ( - paddle.static.Program(), - paddle.static.Program(), - ) - with paddle.static.program_guard(main_program, start_program): - x_s = paddle.static.data('x', [4, 4], x.dtype) - x_s.stop_gradient = False - y_s = paddle.matmul(x_s, x_s) - z_s = paddle.add(y_s, y_s) - k_s = paddle.tanh(z_s) - newir_program = pir.translate_to_new_ir(main_program.desc) - - print(newir_program) + >>> import os + >>> # Paddle will remove this flag in the next version + >>> pir_flag = 'FLAGS_enable_pir_in_executor' + >>> os.environ[pir_flag] = 'True' + + >>> import paddle + >>> from paddle import pir + >>> paddle.enable_static() + + >>> x = paddle.randn([4, 4]) + >>> main_program, start_program = ( + ... paddle.static.Program(), + ... paddle.static.Program(), + ...) + + >>> with paddle.static.program_guard(main_program, start_program): + ... x_s = paddle.static.data('x', [4, 4], x.dtype) + ... x_s.stop_gradient = False + ... y_s = paddle.matmul(x_s, x_s) + ... z_s = paddle.add(y_s, y_s) + ... k_s = paddle.tanh(z_s) + >>> pir_program = pir.translate_to_pir(main_program.desc) + + >>> print(pir_program) + { + (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persisable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,4],stop_gradient:[false]} : () -> pd_op.tensor<4x4xf32> + (%1) = "pd_op.matmul" (%0, %0) {is_persisable:[false],stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + (%2) = "pd_op.add" (%1, %1) {is_persisable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + (%3) = "pd_op.tanh" (%2) {is_persisable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + } + )DOC"); m->def( @@ -1180,17 +1337,17 @@ void BindUtils(pybind11::module *m) { list[str] : List of unregistered operators in paddle dialect, the name is expressed by origin op name. )DOC"); m->def( - "translate_to_new_ir_with_param_map", + "translate_to_pir_with_param_map", [](const framework::ProgramDesc &legacy_program) { auto ir_ctx = pir::IrContext::Instance(); auto program = std::make_shared(ir_ctx); translator::ProgramTranslator program_translator(&legacy_program, program.get()); program_translator.Translate(); - return std::make_pair(program, program_translator.VarDesc2Value()); + return std::make_pair(program, program_translator.VarDesc2OpResult()); }, R"DOC( - Convert Fluid Program to New IR Program and get the mappings of VarDesc -> pir::Value. + Convert Fluid Program to New IR Program and get the mappings of VarDesc -> pir::OpResult. Args: @@ -1198,7 +1355,7 @@ void BindUtils(pybind11::module *m) { Returns: Program: The New IR Program - dict[str, pir::Value]: Mapping between VarDesc(by name) and pir::Value. + dict[str, pir::OpResult]: Mapping between VarDesc(by name) and pir::OpResult. Raises: PreconditionNotMet: If legacy_program has multi block will raise error. @@ -1206,29 +1363,72 @@ void BindUtils(pybind11::module *m) { Examples: .. code-block:: python - import paddle - from paddle import pir - paddle.enable_static() - - x = paddle.randn([4, 4]) - main_program, start_program = ( - paddle.static.Program(), - paddle.static.Program(), - ) - with paddle.static.program_guard(main_program, start_program): - x_s = paddle.static.data('x', [4, 4], x.dtype) - x_s.stop_gradient = False - y_s = paddle.matmul(x_s, x_s) - z_s = paddle.add(y_s, y_s) - k_s = paddle.tanh(z_s) - newir_program, mappings = pir.translate_to_new_ir_with_param_map(main_program.desc) - - print(newir_program) - print(mappings) + >>> import os + >>> # Paddle will remove this flag in the next version + >>> pir_flag = 'FLAGS_enable_pir_in_executor' + >>> os.environ[pir_flag] = 'True' + + >>> import paddle + >>> from paddle import pir + >>> paddle.enable_static() + + >>> x = paddle.randn([4, 4]) + >>> main_program, start_program = ( + ... paddle.static.Program(), + ... paddle.static.Program(), + ... ) + + >>> with paddle.static.program_guard(main_program, start_program): + ... x_s = paddle.static.data('x', [4, 4], x.dtype) + ... x_s.stop_gradient = False + ... y_s = paddle.matmul(x_s, x_s) + ... z_s = paddle.add(y_s, y_s) + ... k_s = paddle.tanh(z_s) + >>> pir_program, mappings = pir.translate_to_pir_with_param_map(main_program.desc) + + >>> print(pir_program) + { + (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persisable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,4],stop_gradient:[false]} : () -> pd_op.tensor<4x4xf32> + (%1) = "pd_op.matmul" (%0, %0) {is_persisable:[false],stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + (%2) = "pd_op.add" (%1, %1) {is_persisable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + (%3) = "pd_op.tanh" (%2) {is_persisable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + } + + >>> print(mappings) + {'matmul_v2_0.tmp_0': [Value(define_op_name=pd_op.matmul, index=0, dtype=pd_op.tensor<4x4xf32>)], 'x': [Value(define_op_name=pd_op.data, index=0, dtype=pd_op.tensor<4x4xf32>)], 'tanh_0.tmp_0': [Value(define_op_name=pd_op.tanh, index=0, dtype=pd_op.tensor<4x4xf32>)], 'elementwise_add_0': [Value(define_op_name=pd_op.add, index=0, dtype=pd_op.tensor<4x4xf32>)]} )DOC"); + + m->def("clear_pir_compiler_manager", []() { +#ifdef PADDLE_WITH_CINN + pybind11::gil_scoped_release release; + VLOG(4) << "clear PirCompilerManager and free PirCompiler resources."; + cinn::hlir::framework::PirCompilerManager::Instance().clear(); +#endif + }); } +// TODO(Aurelius84): Need consider to make an agreement about +// what a Pass should receive and return. Existed Passes have +// mutable and immutable interface. +std::shared_ptr ApplyPirPass(Program &forward_program) { // NOLINT +#ifdef PADDLE_WITH_CINN + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::PassManager pass_manager(ctx); + pass_manager.AddPass(pir::CreateBuildCinnPass()); + pass_manager.Run(&forward_program); + VLOG(3) << "after BuildCinnPass, forward_program:\n" << forward_program; + std::unique_ptr new_program = + cinn::dialect::ir::CINNGroupLoweringPass(&forward_program); + VLOG(3) << "after CINNGroupLoweringPass, forward_program:\n" << *new_program; + return std::move(new_program); +#endif + PADDLE_THROW(platform::errors::Unimplemented( + "Currently we only support CINN Pass for Pir under @to_static, please " + "compile PaddlePaddle with CINN")); +} void BindIrPass(pybind11::module *m) { + m->def("apply_pir_pass", ApplyPirPass); + py::class_> pass(*m, "Pass", R"DOC( @@ -1274,7 +1474,7 @@ void BindPassManager(pybind11::module *m) { .def("empty", &PassManager::Empty); } -void BindPIR(pybind11::module *module) { +void BindPir(pybind11::module *module) { auto ir_module = module->def_submodule("pir"); BindProgram(&ir_module); BindBlock(&ir_module); @@ -1283,6 +1483,7 @@ void BindPIR(pybind11::module *module) { BindOpOperand(&ir_module); BindOpResult(&ir_module); BindType(&ir_module); + BindAttribute(&ir_module); BindUtils(&ir_module); BindIrPass(&ir_module); BindPassManager(&ir_module); diff --git a/paddle/fluid/pybind/pir.h b/paddle/fluid/pybind/pir.h index b64de63452f40..5bc01c63e62e7 100644 --- a/paddle/fluid/pybind/pir.h +++ b/paddle/fluid/pybind/pir.h @@ -18,6 +18,6 @@ namespace paddle { namespace pybind { -void BindPIR(pybind11::module *m); +void BindPir(pybind11::module *m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 0308d06d9305e..fefb64e25b3e9 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -94,7 +94,8 @@ static void DeserializeMessage(T *self, const std::string &str) { // Bind Methods void BindProgramDesc(pybind11::module *m) { - pybind11::class_(*m, "ProgramDesc", "") + pybind11::class_>( + *m, "ProgramDesc", "") .def(pybind11::init<>()) .def("__init__", [](pd::ProgramDesc &self, const pd::ProgramDesc &other) { @@ -191,6 +192,7 @@ void BindBlockDesc(pybind11::module *m) { std::string name = byte_name; return self.HasVarRecursive(name); }) + .def("set_parent_idx", &pd::BlockDesc::SetParent) .def( "find_var", [](pd::BlockDesc &self, pybind11::bytes byte_name) { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 5aa8a552e3437..87d6a029ccf78 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -195,6 +195,7 @@ limitations under the License. */ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/imperative/layout_autotune.h" +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h" @@ -694,6 +695,8 @@ void BindVjp(pybind11::module *m) { m->def( "call_vjp", [](pir::Operation &fwd_op, + const std::vector> &inputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { py::list res; @@ -703,8 +706,8 @@ void BindVjp(pybind11::module *m) { vjp_interface, phi::errors::InvalidArgument( "The vjp function is not registered in %s op ", fwd_op.name())); - std::vector> vjp_res = - vjp_interface.Vjp(&fwd_op, out_grads, stop_gradients); + std::vector> vjp_res = vjp_interface.Vjp( + &fwd_op, inputs, outputs, out_grads, stop_gradients); PADDLE_ENFORCE_EQ( stop_gradients.size(), vjp_res.size(), @@ -766,6 +769,42 @@ void BindVjp(pybind11::module *m) { out (bool): True means that the op has custom vjp rules, False means it does not. )DOC"); } + +void BindDecomp(pybind11::module *m) { + m->def("call_decomp", [](pir::Operation &fwd_op) { + py::list res; + paddle::dialect::DecompInterface decomp_interface = + fwd_op.dyn_cast(); + PADDLE_ENFORCE( + decomp_interface, + phi::errors::InvalidArgument( + "The decomp function is not registered in %s op ", fwd_op.name())); + std::vector> decomp_res = + decomp_interface.Decomp(&fwd_op); + for (size_t i = 0; i < decomp_res.size(); ++i) { + py::list sub_res; + for (size_t j = 0; j < decomp_res[i].size(); ++j) { + if (!decomp_res[i][j]) { + sub_res.append(nullptr); + } else { + sub_res.append(decomp_res[i][j]); + } + } + res.append(sub_res); + } + return res; + }); + + m->def("has_decomp", [](pir::Operation &fwd_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name()); + auto decomp_interface_impl = + fwd_op_info.GetInterfaceImpl(); + if (decomp_interface_impl == nullptr) return false; + return true; + }); +} + PYBIND11_MODULE(libpaddle, m) { BindImperative(&m); BindEager(&m); @@ -1987,8 +2026,6 @@ All parameter, weight, gradient are variables in Paddle. .def(py::init(), py::arg("type")) .def("micro_batch_id", &framework::interpreter::Job::MicroBatchId) .def("type", &framework::interpreter::Job::Type) - .def("set_col_attr_for_fetch_op", - &framework::interpreter::Job::SetColAttrForFetchOp) .def("set_micro_batch_id", &framework::interpreter::Job::SetMicroBatchId) .def("set_skip_gc_vars", &framework::interpreter::Job::SetSkipGcVars); @@ -1996,7 +2033,8 @@ All parameter, weight, gradient are variables in Paddle. .def( py::init< const std::vector> &, - const std::unordered_map + const std::unordered_map> &>(), py::arg("job_list"), py::arg("type_to_program")) @@ -2008,7 +2046,10 @@ All parameter, weight, gradient are variables in Paddle. py::arg("job_list"), py::arg("type_to_ir_program")) .def("job_list", &framework::interpreter::Plan::JobList) + .def("job_types", &framework::interpreter::Plan::JobTypes) .def("micro_batch_num", &framework::interpreter::Plan::MicroBatchNum) + .def("set_ir_program", &framework::interpreter::Plan::SetIrProgram) + .def("ir_program", &framework::interpreter::Plan::IrProgram) .def("program", &framework::interpreter::Plan::Program); m.def("init_gflags", framework::InitGflags); @@ -2938,8 +2979,9 @@ All parameter, weight, gradient are variables in Paddle. GetAllWorkerInfos(&m); #endif - BindPIR(&m); + BindPir(&m); BindVjp(&m); + BindDecomp(&m); } } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index 5b6efa9e1dba9..7205333bb688c 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -1038,7 +1038,13 @@ void BindTensor(pybind11::module &m) { // NOLINT [](DistTensor &self) { return self.value(); }, py::return_value_policy::reference) .def("numel", - [](DistTensor &self) -> int64_t { return self.value().numel(); }); + [](DistTensor &self) -> int64_t { return self.value().numel(); }) + .def("_share_data_with", [](DistTensor &self, const DistTensor &src) { + self.unsafe_set_dims(src.dims()); + self.unsafe_set_dist_attr(src.dist_attr()); + self.unsafe_mutable_value()->ShareDataWith(src.value()); + return self; + }); #endif py::class_(m, "SelectedRows") diff --git a/paddle/phi/CMakeLists.txt b/paddle/phi/CMakeLists.txt index cfbf8fec0adfd..d36ccd67a4512 100644 --- a/paddle/phi/CMakeLists.txt +++ b/paddle/phi/CMakeLists.txt @@ -14,7 +14,6 @@ set(infermeta_srcs CACHE INTERNAL "" FORCE) # paddle experimental common components add_subdirectory(common) - # phi (low level) api headers: include # phi (high level) api add_subdirectory(api) diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index c774cafcfd26a..484ea06944653 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -120,15 +120,15 @@ class PADDLE_API CustomOpKernelContext { std::vector InputsBetween(size_t start, size_t end) const; Tensor& MutableInputAt(size_t idx); std::vector* AllMutableInput(); - paddle::optional OptionalInputAt(size_t idx); + paddle::optional OptionalInputAt(size_t idx) const; paddle::optional> OptionalInputsBetween(size_t start, - size_t end); + size_t end) const; const std::vector& Attrs() const; - const std::vector>& InputRange(); - const std::vector>& OutputRange(); + const std::vector>& InputRange() const; + const std::vector>& OutputRange() const; Tensor* MutableOutputAt(size_t idx); std::vector MutableOutputBetween(size_t start, size_t end); - std::vector OutputsBetween(size_t start, size_t end); + std::vector OutputsBetween(size_t start, size_t end) const; std::vector* AllMutableOutput(); template @@ -151,8 +151,8 @@ class PADDLE_API CustomOpKernelContext { const std::unordered_map& inplace_map); void AssignInplaceOutputs(); std::vector* AllMutablePlainOutput(); - std::unordered_map GetInplaceIndexMap(); - std::unordered_map GetInplaceReverseIndexMap(); + std::unordered_map GetInplaceIndexMap() const; + std::unordered_map GetInplaceReverseIndexMap() const; private: // TODO(chenweihang): replaced be SmallVector diff --git a/paddle/phi/api/include/tensor_utils.h b/paddle/phi/api/include/tensor_utils.h index 56ed9ae12feb4..0f5f1f1f8744e 100644 --- a/paddle/phi/api/include/tensor_utils.h +++ b/paddle/phi/api/include/tensor_utils.h @@ -17,6 +17,10 @@ limitations under the License. */ #include #include "paddle/phi/api/include/tensor.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#endif namespace paddle { @@ -50,4 +54,22 @@ PADDLE_API Tensor from_blob(void* data, const phi::Place& place = phi::Place(), const Deleter& deleter = nullptr); +#ifdef PADDLE_WITH_DISTRIBUTE +/** + * @brief Reshard a DistTensor by given DistAttr. + * + * @note Input of `Reshard` should be a `paddle::Tensor` whose impl is + * shared_ptr of DistTensor. According to the given DistAttr, input will be + * reshard to wanted distributed state. And it will return shared_ptr of a new + * DistTensor as outptut. + * + * @param input The input tensor to be resharded. + * @param dist_attr The dist_attr to be resharded. + * @return Shared_ptr of a new DistTensor + */ +// TODO(GhostScreaming): All APIs should call this unified function later. +PADDLE_API std::shared_ptr reshard( + const paddle::Tensor& input, + const phi::distributed::TensorDistAttr& dist_attr); +#endif } // namespace paddle diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 840f761482684..1d0ae74a92a49 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -28,7 +28,10 @@ limitations under the License. */ #include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/nullary.h" #include "paddle/phi/infermeta/unary.h" - +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif namespace paddle { namespace experimental { @@ -57,7 +60,8 @@ Tensor add_n_impl(const std::vector& x) { bool is_sr_kernel = true; for (auto& input : x) { - if (phi::DenseTensor::classof(input.impl().get())) { + if (phi::DenseTensor::classof(input.impl().get()) || + phi::distributed::DistTensor::classof(input.impl().get())) { is_sr_kernel = false; break; } @@ -98,6 +102,82 @@ Tensor add_n_impl(const std::vector& x) { (*kernel_fn)(*dev_ctx, input_x, kernel_out); } else { +#ifdef PADDLE_WITH_DISTRIBUTE + bool run_auto_parallel = AllInputsAreDistTensor(x); + bool rank_is_in_current_mesh = true; + if (run_auto_parallel) { + auto mesh = + std::static_pointer_cast(x[0].impl()) + ->dist_attr() + .process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + std::vector input_x(x.size()); + for (size_t i = 0; i < input_x.size(); ++i) { + input_x[i] = x[i].impl().get(); + } + + auto meta_dist_input_x = MakeDistMetaTensor(input_x); + auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic( + meta_dist_input_x); + + auto dist_out = SetKernelDistOutput(&api_output); + auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) { + *dense_out = phi::DenseTensor( + std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + } + + phi::MetaTensor meta_dist_out(dist_out); + auto x_meta_vec = MakeMetaTensor(input_x); + std::vector x_metas(x_meta_vec.size()); + for (size_t i = 0; i < x_meta_vec.size(); ++i) { + x_metas[i] = &x_meta_vec[i]; + } + phi::AddNInferMeta(x_metas, &meta_dist_out); + if (rank_is_in_current_mesh) { + auto dist_input_x = + ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]); + dist_input_x = PrepareDataForDistTensor( + dist_input_x, + GetKernelInputArgDef(kernel.InputAt(0), kernel_backend), + {}, + kernel_result.is_stride_kernel); + std::vector input_x(dist_input_x.size()); + for (size_t i = 0; i < dist_input_x.size(); ++i) { + input_x[i] = dist_input_x[i]->unsafe_mutable_value(); + } + + auto x_meta_vec = MakeMetaTensor(input_x); + std::vector x_metas(x_meta_vec.size()); + for (size_t i = 0; i < x_meta_vec.size(); ++i) { + x_metas[i] = &x_meta_vec[i]; + } + phi::MetaTensor meta_dense_out(dense_out); + phi::AddNInferMeta(x_metas, &meta_dense_out); + + using kernel_signature = + void (*)(const phi::DeviceContext&, + const std::vector&, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, input_x, dense_out); + } + PADDLE_ENFORCE_EQ(paddle::holds_alternative< + std::vector>( + spmd_info.first[0]), + true, + phi::errors::PreconditionNotMet( + "Arg must be a vector of TensorDistAttr")); + + auto current_process_mesh = + paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); + SetReplicatedDistAttrForOutput(dist_out, current_process_mesh); + return api_output; + } +#endif std::vector input_x(x.size()); std::vector> temp_dense_tensots; temp_dense_tensots.reserve(x.size()); diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 71257dc588dac..c31a3b1fec235 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -536,10 +536,19 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor( return phi::distributed::DistMetaTensor(tensor); } +std::vector MakeDistMetaTensor( + const std::vector& tensors) { + std::vector meta_tensors; + meta_tensors.reserve(tensors.size()); + for (const auto* t : tensors) { + meta_tensors.emplace_back(*t); + } + return meta_tensors; +} + phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { if (out) { - // TODO(chenweihang): now all dist case are nullptr if (out->impl() == nullptr) { auto dist_t = std::make_shared(phi::DDim(), dist_attr); @@ -550,11 +559,52 @@ phi::distributed::DistTensor* SetKernelDistOutput( return nullptr; } +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + return SetKernelDistOutput(out, paddle::get<0>(dist_attr)); +} + +std::shared_ptr CreateKernelDistOutput( + Tensor* out, + bool set_dist_output_as_tensor_impl, + const phi::distributed::ArgDistAttr& dist_attr) { + if (out) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + auto dist_output = std::make_shared( + phi::DDim(), paddle::get<0>(dist_attr)); + if (set_dist_output_as_tensor_impl) { + VLOG(3) << "CreateKernelDistOutput function set generated output " + "dist_tensor as Tensor's impl"; + if (out->is_dist_tensor()) { + VLOG(3) + << "out is DistTensor, set its DistAttr to generated DistOutput."; + dist_output->unsafe_set_dist_attr( + std::static_pointer_cast(out->impl()) + ->dist_attr()); + } + out->set_impl(dist_output); + } + return dist_output; + } + return nullptr; +} + std::shared_ptr CreateKernelDistOutput( - Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) { if (out) { - return std::make_shared(phi::DDim(), - dist_attr); + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + return std::make_shared( + phi::DDim(), paddle::get<0>(dist_attr)); } return nullptr; } @@ -617,6 +667,7 @@ void SetReplicatedDistAttrForOutput( phi::distributed::DistTensor* out, const phi::distributed::ProcessMesh& process_mesh) { if (out) { + // For inplace output, we also need to set replicated dist attr auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(out->dims())); dist_attr.set_process_mesh(process_mesh); diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index a57d951ce738f..1a29277b5154f 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -139,16 +140,26 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, phi::distributed::DistMetaTensor MakeDistMetaTensor( const phi::TensorBase& tensor); +std::vector MakeDistMetaTensor( + const std::vector& tensors); + phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr = phi::distributed::TensorDistAttr()); +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr); + std::shared_ptr CreateKernelDistOutput( Tensor* out, - const phi::distributed::TensorDistAttr& dist_attr = + bool set_dist_output_as_tensor_impl, + const phi::distributed::ArgDistAttr& dist_attr = phi::distributed::TensorDistAttr()); +std::shared_ptr CreateKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr); + std::vector SetKernelDistOutput( std::vector out); diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index 8066147025117..ee1e21a58e2f1 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -61,7 +61,7 @@ const phi::DeviceContext* DeviceContextPool::Get(const Place& place) { } phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { - return const_cast(Get(place)); + return const_cast(Get(place)); // NOLINT } } // namespace experimental @@ -72,7 +72,7 @@ namespace paddle { PADDLE_API phi::Allocator* GetAllocator(const phi::Place& place) { const phi::DeviceContext* dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - return const_cast(&dev_ctx->GetAllocator()); + return const_cast(&dev_ctx->GetAllocator()); // NOLINT } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 8ba76b64f5f7a..97508f7920716 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -23,8 +23,8 @@ limitations under the License. */ #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" @@ -623,42 +623,93 @@ std::string ReshardDebugInfo( std::shared_ptr ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr) { + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); + auto tensor_in = tensor.impl(); + const auto& tensor_dist_attr = paddle::get<0>(dist_attr); if (tensor_in) { phi::distributed::DistTensor* dist_tensor = static_cast(tensor_in.get()); - if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { - VLOG(6) << "ApiIn to KernelIn - " - << ReshardDebugInfo(*dist_tensor, dist_attr); - auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, - dist_attr); - return func->Eval(dev_ctx, *dist_tensor, dist_attr); + if (ReshardIsNeeded(dist_tensor->dist_attr(), tensor_dist_attr)) { + VLOG(6) << "ApiIn to Replicated KernelIn - " + << ReshardDebugInfo(*dist_tensor, tensor_dist_attr); + auto* func = phi::distributed::ChooseProperReshardFunction( + *dist_tensor, tensor_dist_attr); + return func->Eval(dev_ctx, *dist_tensor, tensor_dist_attr); } return std::static_pointer_cast(tensor_in); } return nullptr; } -std::shared_ptr -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr) { - auto tensor_in = tensor.impl(); - if (tensor_in) { - phi::distributed::DistTensor* dist_tensor = - static_cast(tensor_in.get()); - if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { - VLOG(6) << "ApiIn to Replicated KernelIn - " - << ReshardDebugInfo(*dist_tensor, dist_attr); - auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, - dist_attr); - return func->Eval(dev_ctx, *dist_tensor, dist_attr); +std::vector> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attrs) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative>( + dist_attrs), + true, + phi::errors::PreconditionNotMet( + "Arg must be a vector of TensorDistAttr")); + const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs); + + PADDLE_ENFORCE_EQ(tensors.size(), + tensor_dist_attrs.size(), + phi::errors::InvalidArgument( + "Tensor's size should be equal to dist_attrs' size.")); + + std::vector> out; + for (size_t i = 0; i < tensors.size(); i++) { + auto tensor_in = tensors[i].impl(); + auto dist_attr = tensor_dist_attrs[i]; + if (tensor_in) { + phi::distributed::DistTensor* dist_tensor = + static_cast(tensor_in.get()); + if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { + VLOG(6) << "Vector ApiIn to Replicated KernelIn - " + << ReshardDebugInfo(*dist_tensor, dist_attr); + auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, + dist_attr); + out.push_back(func->Eval(dev_ctx, *dist_tensor, dist_attr)); + } + out.push_back( + std::static_pointer_cast(tensor_in)); + } else { + out.push_back(nullptr); } - return std::static_pointer_cast(tensor_in); } - return nullptr; + return out; +} + +paddle::optional> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::ArgDistAttr& dist_attr) { + if (tensor) { + VLOG(6) << "Optional ApiIn to Replicated KernelIn."; + return paddle::make_optional>( + ReshardApiInputToKernelInput(dev_ctx, *tensor, dist_attr)); + } + return paddle::none; +} + +paddle::optional>> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional>& tensors, + const phi::distributed::ArgDistAttr& dist_attrs) { + if (tensors) { + VLOG(6) << "Optional ApiIn to Replicated KernelIn."; + return paddle::make_optional< + std::vector>>( + ReshardApiInputToKernelInput(dev_ctx, *tensors, dist_attrs)); + } + return paddle::none; } void ReshardOutputPartialAxisToReplicated( @@ -666,6 +717,12 @@ void ReshardOutputPartialAxisToReplicated( if (out_tensor->dist_attr().is_partial()) { auto dist_attr = out_tensor->dist_attr(); dist_attr.clean_partial_status(); + if (!IsCurRankInMesh(out_tensor->dist_attr().process_mesh())) { + VLOG(6) << "DistTensor is not in mesh, just clear its partial status and " + "skip reshard it to replicated."; + out_tensor->unsafe_set_dist_attr(dist_attr); + return; + } VLOG(6) << "FwdAPI Output P2R - " << ReshardDebugInfo(*out_tensor, dist_attr); auto* func = @@ -706,19 +763,7 @@ void ReshardKernelOutputToApiOutput( } std::shared_ptr PrepareDataForDistTensor( - const Tensor& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { - return PrepareDataForDistTensor( - std::static_pointer_cast(input.impl()), - target_args_def, - transform_flag, - is_stride_kernel); -} - -std::shared_ptr PrepareDataForDistTensor( - const std::shared_ptr& input, + std::shared_ptr input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { @@ -753,13 +798,13 @@ std::shared_ptr PrepareDataForDistTensor( } std::vector> -PrepareDataForDistTensor(const std::vector& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { +PrepareDataForDistTensor( + std::vector> input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { std::vector> out; - for (auto& x : input) { - const auto& tensor_in = x.impl(); + for (auto tensor_in : input) { if (tensor_in) { phi::distributed::DistTensor* dist_tensor = static_cast(tensor_in.get()); @@ -795,26 +840,37 @@ PrepareDataForDistTensor(const std::vector& input, return out; } -paddle::optional PrepareDataForDistTensor( - const paddle::optional& input, +paddle::optional> +PrepareDataForDistTensor( + paddle::optional> input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { if (input) { - return {*PrepareDataForDistTensor( - *input, target_args_def, transform_flag, is_stride_kernel)}; + VLOG(6) << "PrepareDataForDistTensor for optional return transformed dist " + "tensor"; + return paddle::make_optional>( + PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel)); } return paddle::none; } paddle::optional>> -PrepareDataForDistTensor(const paddle::optional>& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { +PrepareDataForDistTensor( + paddle::optional>> + input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { if (input) { - return PrepareDataForDistTensor( - *input, target_args_def, transform_flag, is_stride_kernel); + VLOG(6) << "PrepareDataForDistTensor for optional vector return " + "transformed dist " + "tensor"; + return paddle::make_optional< + std::vector>>( + PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel)); } return paddle::none; } diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 25c0e4137aa7f..2eba71c7295c8 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -178,13 +179,23 @@ inline bool NeedTransformPlace(const phi::Place& src_place, std::shared_ptr ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr); + const phi::distributed::ArgDistAttr& dist_attr); -std::shared_ptr -ReshardApiInputToReplicatedKernelInput( +std::vector> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const std::vector& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + +paddle::optional> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + +paddle::optional>> +ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr); + const paddle::optional>& tensors, + const phi::distributed::ArgDistAttr& dist_attr); void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); @@ -195,34 +206,32 @@ void ReshardKernelOutputToApiOutput( Tensor* dst_tensor); std::shared_ptr PrepareDataForDistTensor( - const Tensor& input, + std::shared_ptr input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); -std::shared_ptr PrepareDataForDistTensor( - const std::shared_ptr& input, +std::vector> +PrepareDataForDistTensor( + std::vector> input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); -std::vector> -PrepareDataForDistTensor(const std::vector& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel); - -paddle::optional PrepareDataForDistTensor( - const paddle::optional& input, +paddle::optional> +PrepareDataForDistTensor( + paddle::optional> input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); paddle::optional>> -PrepareDataForDistTensor(const paddle::optional>& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel); +PrepareDataForDistTensor( + paddle::optional>> + input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel); } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index da8b9125a71dd..14334aa7c42a6 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -63,10 +64,12 @@ PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { "happens when handling inplace optional inputs & outputs."; return; } - PADDLE_ENFORCE_EQ(src.is_dense_tensor() && dst->is_dense_tensor(), - true, - phi::errors::Unavailable( - "Now only supported DenseTensor in Custom Operator.")); + PADDLE_ENFORCE_EQ( + ((src.is_dense_tensor() && dst->is_dense_tensor()) || + (src.is_dist_tensor() && dst->is_dist_tensor())), + true, + phi::errors::Unavailable( + "Now only supported DenseTensor and DistTensor in Custom Operator.")); PADDLE_ENFORCE_EQ( src.initialized(), true, @@ -76,9 +79,19 @@ PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { true, phi::errors::Unavailable( "The Custom OpKernel origin output is not defined.")); - auto& dense_src = static_cast(*src.impl()); - auto* dense_dst = static_cast(dst->impl().get()); - *dense_dst = dense_src; + if (src.is_dense_tensor()) { + auto& dense_src = static_cast(*src.impl()); + auto* dense_dst = static_cast(dst->impl().get()); + *dense_dst = dense_src; + } else { + auto* dense_src = + static_cast(src.impl().get()) + ->unsafe_mutable_value(); + auto* dense_dst = + static_cast(dst->impl().get()) + ->unsafe_mutable_value(); + *dense_dst = *dense_src; + } } ////////////////////// Kernel Context ////////////////////// @@ -149,7 +162,8 @@ std::vector* CustomOpKernelContext::AllMutableInput() { return &inputs_; } -paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { +paddle::optional CustomOpKernelContext::OptionalInputAt( + size_t idx) const { if (!inputs_.at(idx).is_initialized()) { return paddle::none; } @@ -157,7 +171,7 @@ paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { } paddle::optional> -CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) { +CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) const { std::vector rlt; for (size_t i = start; i < end; ++i) { if (!inputs_.at(i).is_initialized()) { @@ -181,7 +195,7 @@ std::vector CustomOpKernelContext::MutableOutputBetween(size_t start, } std::vector CustomOpKernelContext::OutputsBetween(size_t start, - size_t end) { + size_t end) const { std::vector rlt; for (size_t i = start; i < end; ++i) { rlt.emplace_back(outputs_.at(i)); @@ -203,12 +217,12 @@ const std::pair& CustomOpKernelContext::OutputRangeAt( } const std::vector>& -CustomOpKernelContext::InputRange() { +CustomOpKernelContext::InputRange() const { return input_range_; } const std::vector>& -CustomOpKernelContext::OutputRange() { +CustomOpKernelContext::OutputRange() const { return output_range_; } @@ -293,12 +307,13 @@ std::vector* CustomOpKernelContext::AllMutablePlainOutput() { return &plain_outputs_; } -std::unordered_map CustomOpKernelContext::GetInplaceIndexMap() { +std::unordered_map CustomOpKernelContext::GetInplaceIndexMap() + const { return inplace_idx_map_; } std::unordered_map -CustomOpKernelContext::GetInplaceReverseIndexMap() { +CustomOpKernelContext::GetInplaceReverseIndexMap() const { return inplace_reverse_idx_map_; } ////////////////////// Op Meta Info ////////////////////// diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index f50347fd6678a..206d5082e62dd 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -113,9 +113,13 @@ std::vector Tensor::shape() const { const phi::DDim &Tensor::strides() const { if (is_dense_tensor()) { return static_cast(impl_.get())->strides(); + } else if (is_dist_tensor()) { + return static_cast(impl_.get()) + ->value() + .strides(); } else { PADDLE_THROW(phi::errors::Unimplemented( - "Only support strides operation on DenseTensor now.")); + "Only support strides operation on DenseTensor and DistTensor now.")); } } @@ -433,9 +437,16 @@ void Tensor::bump_inplace_version() { auto &inplace_version_counter = static_cast(impl_.get())->InplaceVersionCounter(); inplace_version_counter.Bump(); + } else if (is_dist_tensor()) { + auto &inplace_version_counter = + static_cast(impl_.get()) + ->unsafe_mutable_value() + ->InplaceVersionCounter(); + inplace_version_counter.Bump(); } else { - PADDLE_THROW(phi::errors::Unimplemented( - "bump_inplace_version is only supported on DenseTensor now.")); + PADDLE_THROW( + phi::errors::Unimplemented("bump_inplace_version is only supported on " + "DenseTensor and DistTensor now.")); } } @@ -444,9 +455,15 @@ uint32_t Tensor::current_inplace_version() { auto &inplace_version_counter = static_cast(impl_.get())->InplaceVersionCounter(); return inplace_version_counter.CurrentVersion(); + } else if (is_dist_tensor()) { + auto &inplace_version_counter = + static_cast(impl_.get()) + ->unsafe_mutable_value() + ->InplaceVersionCounter(); + return inplace_version_counter.CurrentVersion(); } else { - LOG_FIRST_N(WARNING, 1) - << "current_inplace_version is only supported on DenseTensor now."; + LOG_FIRST_N(WARNING, 1) << "current_inplace_version is only supported on " + "DenseTensor DistTensor now."; } return 0; } @@ -457,6 +474,12 @@ void Tensor::reset_inplace_version(bool set_to_zero) { auto &inplace_version_counter = static_cast(impl_.get())->InplaceVersionCounter(); inplace_version_counter.SetInplaceVersionToZero(); + } else if (is_dist_tensor()) { + auto &inplace_version_counter = + static_cast(impl_.get()) + ->unsafe_mutable_value() + ->InplaceVersionCounter(); + return inplace_version_counter.SetInplaceVersionToZero(); } } } diff --git a/paddle/phi/api/lib/tensor_copy.cc b/paddle/phi/api/lib/tensor_copy.cc index ac3f80bed0c03..4e95157008995 100644 --- a/paddle/phi/api/lib/tensor_copy.cc +++ b/paddle/phi/api/lib/tensor_copy.cc @@ -24,7 +24,11 @@ limitations under the License. */ #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" - +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/data_transform.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif namespace paddle { namespace experimental { @@ -40,7 +44,45 @@ void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { auto& pool = paddle::experimental::DeviceContextPool::Instance(); auto* dev_ctx = pool.GetMutable( target_place.GetType() == place.GetType() ? place : target_place); +#ifdef PADDLE_WITH_DISTRIBUTE + bool run_auto_parallel = AllInputsAreDistTensor(src); + bool rank_is_in_current_mesh = false; + if (run_auto_parallel) { + auto mesh = + std::static_pointer_cast(src.impl()) + ->dist_attr() + .process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); + + auto dist_out = SetKernelDistOutput(dst, meta_dist_input_x.dist_attr()); + auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) { + *dense_out = + phi::DenseTensor(std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + } + + phi::MetaTensor meta_dist_out(dist_out); + phi::UnchangedInferMeta(MakeMetaTensor(*(src.impl())), &meta_dist_out); + + if (rank_is_in_current_mesh) { + auto dist_input_x = + static_cast(src.impl().get()); + + auto input_x = &dist_input_x->value(); + + phi::MetaTensor meta_dense_out(dense_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out); + phi::Copy(*dev_ctx, *input_x, place, blocking, dense_out); + } + VLOG(6) << "copy finished. "; + return; + } +#endif auto dense_x = TensorToDenseTensor(src); auto kernel_out = SetKernelOutput(dst); diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 74ee1e380dcc4..7058796f769d4 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -27,7 +27,11 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" // clang-format off - +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/api/lib/data_transform.h" +#endif namespace paddle { namespace experimental { // declare cast api @@ -87,9 +91,7 @@ void Tensor::copy_(const Tensor &src, VLOG(8) << "Src is empty, skip copy"; return; } - // Prepare copy kernel key and outputs - auto kernel_key_set = ParseKernelKeyByInputArgs(src); - KernelType kernel_type = ParseKernelTypeByInputArgs(src); + VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name(); if (initialized()) { PADDLE_ENFORCE_EQ(dtype(), @@ -114,6 +116,12 @@ void Tensor::copy_(const Tensor &src, "Copy cannot be performed!", target_place, place())); + } + + // Prepare copy kernel key and outputs + auto kernel_key_set = ParseKernelKeyByInputArgs(src); + KernelType kernel_type = ParseKernelTypeByInputArgs(src); + if (initialized()) { kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place())); } else { @@ -129,6 +137,58 @@ void Tensor::copy_(const Tensor &src, place.GetType() == target_place.GetType() ? target_place : place); if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { +#ifdef PADDLE_WITH_DISTRIBUTE + bool run_auto_parallel = AllInputsAreDistTensor(src); + bool rank_is_in_current_mesh = false; + if (run_auto_parallel) { + auto mesh = std::static_pointer_cast( + src.impl())->dist_attr().process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); + + if (this->initialized()) { + auto this_dist_attr = + std::static_pointer_cast( + this->impl())->dist_attr(); + PADDLE_ENFORCE_EQ((meta_dist_input_x.dist_attr() == this_dist_attr + || this_dist_attr.empty()), + true, + phi::errors::PreconditionNotMet( + "DistAttr is different of dst " + "tensor and args %s, which " + "current tensor holds %s " + "Copy cannot be performed!", + meta_dist_input_x.dist_attr(), + this_dist_attr)); + } + + auto dist_out = SetKernelDistOutput(this, meta_dist_input_x.dist_attr()); + auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) { + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, + 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + } + + phi::MetaTensor meta_dist_out(dist_out); + phi::UnchangedInferMeta(MakeMetaTensor(*(src.impl_)), &meta_dist_out); + + if (rank_is_in_current_mesh) { + auto dist_input_x = static_cast( + src.impl().get());; + + auto input_x = &dist_input_x->value(); + + phi::MetaTensor meta_dense_out(dense_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out); + + phi::Copy(*dev_ctx, *input_x, target_place, blocking, dense_out); + } + return; + } +#endif SetKernelOutput(this); phi::MetaTensor meta_out(impl_.get()); phi::UnchangedInferMeta( diff --git a/paddle/phi/api/lib/tensor_utils.cc b/paddle/phi/api/lib/tensor_utils.cc index b8d25e4f22b10..4d5711ecb4078 100644 --- a/paddle/phi/api/lib/tensor_utils.cc +++ b/paddle/phi/api/lib/tensor_utils.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/api/include/tensor_utils.h" +#include "glog/logging.h" #include "paddle/phi/api/lib/api_registry.h" #include "paddle/phi/core/dense_tensor.h" @@ -105,4 +106,38 @@ PADDLE_API Tensor from_blob(void* data, return Tensor(std::make_shared(alloc, meta)); } +#ifdef PADDLE_WITH_DISTRIBUTE +PD_REGISTER_API(reshard) + +PADDLE_API std::shared_ptr reshard( + const paddle::Tensor& input, + const phi::distributed::TensorDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ(input.is_dist_tensor(), + true, + phi::errors::InvalidArgument( + "The input tensor of ReshardFunction should be " + "``phi::distributed::DistTensor``. " + "However it's %s", + typeid(input.impl().get()).name())); + auto dev_ctx = phi::distributed::GetDistTensorDeviceContext( + static_cast(input.impl().get())); + auto input_tensor_impl = input.impl(); + std::shared_ptr dist_out_ptr = nullptr; + if (input_tensor_impl) { + phi::distributed::DistTensor* dist_tensor = + static_cast(input_tensor_impl.get()); + if (dist_tensor->dist_attr() != dist_attr) { + VLOG(6) << "reshard func, reshard tensor from " + << dist_tensor->dist_attr() << " to " << dist_attr; + auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, + dist_attr); + dist_out_ptr = func->Eval(dev_ctx, *dist_tensor, dist_attr); + } else { + dist_out_ptr = std::static_pointer_cast( + input_tensor_impl); + } + } + return dist_out_ptr; +} +#endif } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 7be497318443a..157d34e28aaca 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -432,6 +432,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : cos_grad backward : cos_double_grad @@ -494,24 +495,26 @@ data_type : out_grad - backward_op : cummax_grad - forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices) - args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype) + forward : cummax(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta param: [x] kernel : func : cummax_grad + data_type : out_grad - backward_op : cummin_grad - forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices) - args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype) + forward : cummin(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta param: [x] kernel : func : cummin_grad + data_type : out_grad - backward_op : cumprod_grad forward : cumprod (Tensor x, int dim) -> Tensor(out) @@ -708,6 +711,7 @@ infer_meta : func : UnchangedInferMeta param : [out] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : exp_grad inplace : (out_grad -> x_grad) @@ -938,6 +942,8 @@ func : gather_nd_grad composite : gather_nd_grad(x, index, out_grad, x_grad) no_need_buffer : x + data_transform : + skip_transform : index - backward_op : gaussian_inplace_grad forward : gaussian_inplace(Tensor x, float mean=0, float std=1.0, int seed=0) -> Tensor(out) @@ -1119,6 +1125,8 @@ kernel : func : index_put_grad data_type : out_grad + data_transform : + skip_transform : indices - backward_op : index_sample_grad forward : index_sample (Tensor x, Tensor index) -> Tensor(out) @@ -1817,6 +1825,7 @@ infer_meta : func : UnchangedInferMeta param : [out] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : relu_grad backward: relu_double_grad @@ -1906,6 +1915,7 @@ infer_meta : func : UnchangedInferMeta param : [out] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : rsqrt_grad backward : rsqrt_double_grad @@ -2061,6 +2071,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : silu_grad backward : silu_double_grad @@ -2087,6 +2098,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : sin_grad backward : sin_double_grad @@ -2234,6 +2246,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : square_grad backward : square_double_grad diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index b54307861b367..45186294ce979 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -62,14 +62,23 @@ optional : bias, x_max - op : conv2d_xpu - args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, int act_type, float act_param, DataType out_dtype) + args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, Tensor scale_max, Tensor out_max_in, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, int act_type, float act_param, DataType out_dtype) output : Tensor(out), Tensor(out_max) infer_meta : func : Conv2dXPUInferMeta kernel : func : conv2d_xpu data_type : x - optional : bias, branch, branch_max ,x_max + optional : bias, branch, branch_max ,x_max, scale_max, out_max_in + +- op : dequantize_xpu + args : (Tensor x, DataType out_dtype, float scale = 1.0f) + output : Tensor(y) + infer_meta : + func : DeQuantizeXPUInferMeta + kernel : + func : dequantize_xpu + data_type: x - op : embedding_with_eltwise_add_xpu args : (Tensor[] ids, Tensor[] tables, Tensor mask, int64_t padding_idx) @@ -101,14 +110,14 @@ data_type : x - op : fc_xpu - args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype) + args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, Tensor scale_max, Tensor out_max_in, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype) output : Tensor(out), Tensor(out_max) infer_meta : func : FcXPUInferMeta kernel : func : fc_xpu data_type : x - optional : bias, x_max + optional : bias, x_max, scale_max, out_max_in - op : fused_bias_act args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0) @@ -207,6 +216,38 @@ func : fused_scale_bias_relu_conv_bnstats data_type : x +- op : fusion_gru + args : (Tensor x, Tensor h0, Tensor weight_x, Tensor weight_h, Tensor bias, str activation = "tanh", str gate_activation = "sigmoid", bool is_reverse = false, bool use_seq = true, bool origin_mode = false, bool use_mkldnn = false, str mkldnn_data_type = "float32", float scale_data = 1.0f, float shift_data = 0.0f, float[] scale_weights = {1.0f}, bool force_fp32_output = false) + output : Tensor(reordered_h0), Tensor(xx), Tensor(batched_input), Tensor(batched_out), Tensor(hidden) + infer_meta : + func : FusionGRUInferMeta + kernel : + func : fusion_gru + data_type : x + optional : h0, bias + intermediate : reordered_h0, xx, batched_input, batched_out + +- op : fusion_seqconv_eltadd_relu + args : (Tensor x, Tensor filter, Tensor bias, int context_length, int context_start = 0, int context_stride = 1) + output : Tensor(out), Tensor(col_mat) + infer_meta : + func : FusionSeqConvEltAddReluInferMeta + kernel : + func : fusion_seqconv_eltadd_relu + data_type : x + intermediate : col_mat + +- op : fusion_seqexpand_concat_fc + args : (Tensor[] x, Tensor fc_weight, Tensor fc_bias, str fc_activation="identity") + output : Tensor(out), Tensor(fc_out) + infer_meta : + func : FusionSeqExpandConcatFCInferMeta + kernel : + func : fusion_seqexpand_concat_fc + data_type : x + optional : fc_bias + intermediate : fc_out + - op : fusion_transpose_flatten_concat args : (Tensor[] x, int[] trans_axis, int flatten_axis, int concat_axis) output : Tensor(out) @@ -254,6 +295,15 @@ data_type : input optional : bias_qk +- op : quantize_xpu + args : (Tensor x, DataType out_dtype, float scale = 1.0f) + output : Tensor(y) + infer_meta : + func : QuantizeXPUInferMeta + kernel : + func : quantize_xpu + data_type : x + - op : squeeze_excitation_block args : (Tensor x, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] act_type, float[] act_param, int[] filter_dims) output : Tensor(out) diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 5e7cff9213171..86d79f6543efb 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -175,7 +175,7 @@ def parse_input_and_attr(self, api_name, args_config, optional_vars=[]): 'Scalar(int)': 'const Scalar&', 'Scalar(int64_t)': 'const Scalar&', 'Scalar(float)': 'const Scalar&', - 'Scalar(dobule)': 'const Scalar&', + 'Scalar(double)': 'const Scalar&', 'Scalar[]': 'const std::vector&', 'int': 'int', 'int32_t': 'int32_t', @@ -764,7 +764,21 @@ def gene_optional_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});""" +{code_indent} // inplace vector of tensors should also be transferred to CPU when kernel has fallen back +{code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}; +{code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}_vec; +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_vec){{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::optional>({PREFIX_TENSOR_NAME}{input_name}_vec->size()); +{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}_vec->size(); ++i) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); +{code_indent} }} +{code_indent} }} +{code_indent} }} +{code_indent} else {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name}); +{code_indent} }}""" ) else: input_name_tensor_map[input_name].append( @@ -773,7 +787,7 @@ def gene_optional_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); {code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}; {code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_vec){{ {code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::optional>({PREFIX_TENSOR_NAME}{input_name}_vec->size()); @@ -802,7 +816,19 @@ def gene_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});""" +{code_indent} // inplace vector of tensors should also be transferred to CPU when kernel has fallen back +{code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name}; +{code_indent} std::unique_ptr> {PREFIX_TENSOR_NAME}{input_name}_vec; +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} {PREFIX_TENSOR_NAME}{input_name}.resize({PREFIX_TENSOR_NAME}{input_name}_vec->size()); +{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); +{code_indent} }} +{code_indent} }} +{code_indent} else {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name}); +{code_indent} }}""" ) else: input_name_tensor_map[input_name].append( @@ -811,7 +837,7 @@ def gene_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); {code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size()); {code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{ {code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); @@ -1243,7 +1269,9 @@ def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False): {code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type); {code_indent} }} {code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel; -{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); +{code_indent} // add actual_kernel_backend to select actual kernel backend after a potential falling-back to CPU +{code_indent} Backend actual_kernel_backend = kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend; +{code_indent} auto* dev_ctx = GetDeviceContextByBackend(actual_kernel_backend); {input_tensors} {output_create} {pre_save_stride} diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index fcfcd17922759..27f329b80c607 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -274,7 +274,10 @@ def gene_output( output_create = ( output_create + f""" -{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});""" +{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code}); +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{code_indent} TransDataBackend(kernel_out_{i}, actual_kernel_backend, kernel_out_{i}); +{code_indent} }}""" ) else: @@ -379,6 +382,7 @@ def source_include(header_file_path): #ifdef PADDLE_WITH_DISTRIBUTE #include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #endif PD_DECLARE_bool(conv2d_disable_cudnn); @@ -405,6 +409,9 @@ def declare_extension_api(): return """ namespace paddle { PD_DECLARE_API(from_blob); +#ifdef PADDLE_WITH_DISTRIBUTE +PD_DECLARE_API(reshard); +#endif } // namespace paddle """ diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 3bd51c35e5d15..f20dd50e61099 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse +import re import yaml from api_base import PREFIX_TENSOR_NAME @@ -43,32 +44,62 @@ # TODO(chenweihang): add view support later MAIN_DIST_BRANCH_TEMPLATE = """ // Auto Parallel condition - if ({}) {{ + if (run_auto_parallel) {{ // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){} // 2. Create API Output & Prepare Dist and Dense Output{} // 3. Infer DistTensor's Global Shape{}\n - // 4. Select Kernel{} - // 5. Reshard Input{}\n - // 6. PrepareData (DataTransform & Prepare Dense Input){} - // 7. Infer Local DenseTensor Meta{} - // 8. DenseTensor Kernel Call{} + if (rank_is_in_current_mesh) {{ + // 4. Select Kernel{} + // 5. Reshard Input{}\n + // 6. PrepareData (DataTransform & Prepare Dense Input){} + // 7. Infer Local DenseTensor Meta{} + // 8. DenseTensor Kernel Call{} + }}\n // 9. Reshard Partial Output to Replicated (Temporary){}\n - // 10. Return + // 10. Set Output Dist Attr For Default Impl{}\n + // 11. Return {} }} """ +# TODO(GhostScreaming): Support no-input operators. +# 1. Non computation rank clip +GET_MESH_TEMPLATE = """ + auto mesh = std::static_pointer_cast({}impl())->dist_attr().process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);""" + # Auto Parallel condition -AUTO_PARALLEL_COND_TEMPLATE = """AllInputsAreDistTensor({})""" +AUTO_PARALLEL_COND_TEMPLATE = """ + bool run_auto_parallel = AllInputsAreDistTensor({input_args}); + bool rank_is_in_current_mesh = true; + if (run_auto_parallel) {{{mesh} + }} + if (rank_is_in_current_mesh) {{{kernel_code} + }} +""" # 1. InferSPMD SINGLE_DIST_META_IN_TEMPLATE = """ - auto meta_dist_input_{} = MakeDistMetaTensor(*{}.impl());""" + auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());""" +VECTOR_DIST_META_IN_TEMPLATE = """ + std::vector meta_dist_input_{name}; + for(auto& e : {name}) {{ + meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl())); + }}""" +OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """ + auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();""" +OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE = """ + std::vector meta_dist_input_{name}; + if ({name}) {{ + for(auto& e : *{name}) {{ + meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl())); + }} + }}""" INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::{}({}); """ GENERAL_INFER_SPMD_TEMPLATE = """ - auto spmd_info = phi::distributed::VariadicReplicatedInferSpmd({}); + auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic({}); """ UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE = """ // API `{}` does not support InferSpmd now @@ -84,24 +115,49 @@ SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ auto dist_out = SetKernelDistOutput(&api_output); auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) {{ + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ auto dist_out_{idx} = SetKernelDistOutput({out}); - auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); + auto dense_out_{idx} = dist_out_{idx} ? dist_out_{idx}->unsafe_mutable_value() : nullptr; + if (!rank_is_in_current_mesh) {{ + *dense_out_{idx} = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ SINGLE_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]); auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) {{ + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ auto dist_out_{idx} = SetKernelDistOutput({out}, spmd_info.second[{idx}]); auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) {{ + *dense_out_{idx} = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput({}, &api_output); std::vector dense_out(dist_out.size()); for (size_t i = 0; i < dist_out.size(); ++i) {{ - dense_out[i] = const_cast(&dist_out[i]->value()); + dense_out[i] = const_cast(&dist_out[i]->value()); + if (!rank_is_in_current_mesh) {{ + *dense_out[i] = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} }} """ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """ @@ -109,6 +165,11 @@ std::vector dense_out_{out_name}(dist_out_{out_name}.size()); for (size_t i = 0; i < dist_out_{out_name}.size(); ++i) {{ dense_out_{out_name}[i] = const_cast(&dist_out_{out_name}[i]->value()); + if (!rank_is_in_current_mesh) {{ + *dense_out_{out_name}[i] = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} }} """ MULTI_VECTOR_INPLACE_AND_OPTIONAL_OUT_CREATION_TEMPLATE = """ @@ -168,66 +229,63 @@ INFER_GLOBAL_SHAPE_TEMPLATE = """ phi::{}({}{}); """ -# Dist Branch will not generated in the API that doesn't have input tensor. -SET_SINGLE_OUT_REPLICATED_DIST_ATTR = """ - SetReplicatedDistAttrForOutput({}, spmd_info.first[0].process_mesh());""" -SET_VECTOR_OUT_REPLICATED_DIST_ATTR = """ - auto current_process_mesh = spmd_info.first[0].process_mesh(); - for (size_t i = 0; i < dist_out.size(); ++i) {{ - SetReplicatedDistAttrForOutput(dist_out[i], current_process_mesh); - }} -""" # 4. Select Kernel KERNEL_SELECTION_TEMPLATE = """ - VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "{}", {{kernel_backend, kernel_layout, kernel_data_type}}); - const auto& kernel = kernel_result.kernel; - VLOG(6) << "{} kernel: " << kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); + VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "{}", {{kernel_backend, kernel_layout, kernel_data_type}}); + const auto& kernel = kernel_result.kernel; + VLOG(6) << "{} kernel: " << kernel; + dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); """ # 5. Reshard Input -SINGLE_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" -SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" +# Both Tensor, std::vector, paddle::optional and +# paddle::optional> use the same template +INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{name} = ReshardApiInputToKernelInput(dev_ctx, {name}, spmd_info.first[{idx}]);""" +GENERAL_INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{name} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {name}, spmd_info.first[{idx}]);""" UNSUPPORTED_RESHARD_INPUT_COMMENT_TEMPLATE = """ - // API `{}` does not need to support ReshardInput at this time + // API `{}` does not need to support ReshardInput at this time """ # 6. PrepareData SINGLE_PREPARE_DATA_TEMPLATE = """ - dist_input_{arg} = PrepareDataForDistTensor(dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); - auto input_{arg} = &dist_input_{arg}->value(); + dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto input_{name} = &dist_input_{name}->value(); """ SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); - auto input_{arg} = &dist_input_{arg}->value(); + auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto input_{name} = &dist_input_{name}->value(); """ VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); - std::vector dense_input_{name}_vec; - for (auto tmp : dist_input_{name}_vec) {{ - dense_input_{name}_vec.emplace_back(&tmp->value()); - }} - std::vector dense_input_{name}_meta_vec = MakeMetaTensor(dense_input_{name}_vec); - std::vector dense_input_{name}_meta_ptr_vec(dense_input_{name}_meta_vec.size()); - for (size_t i = 0; i < dense_input_{name}_meta_ptr_vec.size(); ++i) {{ - dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i]; - }} + auto dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + std::vector dense_input_{name}_vec; + for (auto tmp : dist_input_{name}_vec) {{ + dense_input_{name}_vec.emplace_back(&tmp->value()); + }} + std::vector dense_input_{name}_meta_vec = MakeMetaTensor(dense_input_{name}_vec); + std::vector dense_input_{name}_meta_ptr_vec(dense_input_{name}_meta_vec.size()); + for (size_t i = 0; i < dense_input_{name}_meta_ptr_vec.size(); ++i) {{ + dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i]; + }} """ OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); - paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; + dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional((*dist_input_{name})->value()) : paddle::none; +""" +OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ + auto dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; """ OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); - std::vector dense_input_{name}_vec; - if ({name}) {{ - for (auto tmp : *dist_input_{name}_vec) {{ - dense_input_{name}_vec.emplace_back(&tmp->value()); + auto dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + std::vector dense_input_{name}_vec; + if ({name}) {{ + for (auto tmp : *dist_input_{name}_vec) {{ + dense_input_{name}_vec.emplace_back(&tmp->value()); }} }} paddle::optional> input_{name}(dense_input_{name}_vec); @@ -239,6 +297,7 @@ paddle::optional> dense_input_{name}_meta_ptr_vec = {name} ? paddle::make_optional>(dense_input_{name}_meta_ptr_vec_tmp) : paddle::none; """ + INFER_META_SINGLE_INPUT_TEMPLATE = """ auto dist_input_{} = {}.impl(); auto input_{} = &(static_cast(dist_input_{}.get())->value()); @@ -257,16 +316,16 @@ OPTIONAL_SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(input_{}), """ OPTIONAL_VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """ SINGLE_META_OUT_DECL_TEMPLATE = """ - phi::MetaTensor meta_{}({});""" + phi::MetaTensor meta_{}({});""" VECTOR_META_OUT_DECL_TEMPLATE = """ - std::vector {name}_meta_vec = MakeMetaTensor({name}); - std::vector {name}_meta_ptr_vec({name}_meta_vec.size()); - for (size_t i = 0; i < {name}_meta_vec.size(); ++i) {{ - {name}_meta_ptr_vec[i] = &{name}_meta_vec[i]; - }} + std::vector {name}_meta_vec = MakeMetaTensor({name}); + std::vector {name}_meta_ptr_vec({name}_meta_vec.size()); + for (size_t i = 0; i < {name}_meta_vec.size(); ++i) {{ + {name}_meta_ptr_vec[i] = &{name}_meta_vec[i]; + }} """ INFER_META_TEMPLATE = """ - phi::{}({}{}); + phi::{}({}{}); """ # 8. DenseTensor Kernel Call @@ -278,10 +337,11 @@ TUPLE_OUTPUT_NAME_TEMPLATE = """ """ KERNEL_CALL_TEMPLATE = """ - using kernel_signature = {}; - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)({}, {}); + using kernel_signature = {}; + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)({}, {}); """ + # TODO(GhostScreaming): Some operators generate shape info in runtime, # bincount. As a result, dist_output's global shape is set uncorrectly, # because it's generated in InferMeta function. A temporally solution is @@ -303,12 +363,28 @@ # 9. Reshard Partial Output to Replicated RESHARD_P2R_SINGLE_OUTPUT_TEMPLATE = """ + dev_ctx = phi::distributed::GetDistTensorDeviceContext(dist_out); ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out);""" RESHARD_P2R_MULTI_SINGLE_OUTPUT_TEMPLATE = """ - ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out_{});""" + dev_ctx = phi::distributed::GetDistTensorDeviceContext(dist_out_{idx}); + ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out_{idx});""" UNSUPPORTED_RESHARD_OUTPUT_COMMENT_TEMPLATE = """ - // API `{}` does not need to support ReshardOutput now + // API `{}` does not need to support ReshardOutput now.""" + +# 10. Set Output DistAttr for Default impl +# Dist Branch will not generated in the API that doesn't have input tensor. +CURRENT_PROCESS_MESH_TEMPLATE = """ + auto current_process_mesh = paddle::holds_alternative(spmd_info.first[0]) ? + paddle::get<0>(spmd_info.first[0]).process_mesh() : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh();""" +SET_SINGLE_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ + SetReplicatedDistAttrForOutput({}, current_process_mesh);""" +SET_VECTOR_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ + for (size_t i = 0; i < {name}.size(); ++i) {{ + SetReplicatedDistAttrForOutput({name}[i], current_process_mesh); + }} """ +NONEED_TO_SET_DIST_ATTR_COMMENT_TEMPLATE = """ + // API `{}` does not need to set DistAttr for output.""" # BaseAPI members: # inputs: @@ -399,13 +475,224 @@ def vector_output_size_assertion_check(self): self.outputs['out_size_expr'] is not None ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." - def generate_if_condition_code(self) -> str: + def generate_non_computation_rank_clip_code(self) -> str: + if len(self.inputs['names']) > 0: + mesh = "" + # All inputs have same mesh + if ( + self.inputs['input_info'][self.inputs['names'][0]] + == "const Tensor&" + ): + mesh = GET_MESH_TEMPLATE.format( + "{}.".format(self.inputs['names'][0]) + ) + elif ( + self.inputs['input_info'][self.inputs['names'][0]] + == "const paddle::optional&" + ): + mesh = GET_MESH_TEMPLATE.format( + "{}->".format(self.inputs['names'][0]) + ) + elif ( + self.inputs['input_info'][self.inputs['names'][0]] + == "const std::vector&" + ): + mesh = GET_MESH_TEMPLATE.format( + "{}[0].".format(self.inputs['names'][0]) + ) + elif ( + self.inputs['input_info'][self.inputs['names'][0]] + == "const paddle::optional>&" + ): + mesh = GET_MESH_TEMPLATE.format( + "{}->at(0).".format(self.inputs['names'][0]) + ) + return mesh + else: + return "" + + # Backward API Override this method + def gene_kernel_backend_select(self): + backend_select_code = "" + if self.kernel['backend'] is not None: + if '>' in self.kernel['backend']: + vars_list = self.kernel['backend'].split('>') + assert ( + len(vars_list) == 2 + ), f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}." + assert (vars_list[0].strip() in self.attrs['names']) and ( + self.attrs['attr_info'][vars_list[0].strip()][0] + == 'const Place&' + ), f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type." + backend_select_code = f""" + kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); +""" + + else: + backend_args = [ + ele.strip() for ele in self.kernel['backend'].split(',') + ] + backend_select_code = f""" + kernel_backend = ParseBackend({", ".join(backend_args)}); +""" + + return backend_select_code + + # Overload api_base.py gene_kernel_select function. + def gene_kernel_select(self) -> str: + api = self.api + input_names = self.inputs['names'] + attrs = self.attrs + kernel = self.kernel + + kernel_key_item_init = """ + Backend kernel_backend = Backend::UNDEFINED; + DataLayout kernel_layout = DataLayout::UNDEFINED; + DataType kernel_data_type = DataType::UNDEFINED; +""" + + # Check the tensor options + attr_backend_count = 0 + attr_layout_count = 0 + attr_data_type_count = 0 + for attr_name in attrs['names']: + if attrs['attr_info'][attr_name][0] == 'const Place&': + assert ( + kernel['backend'] is not None + ), f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually." + attr_backend_count = attr_backend_count + 1 + if attrs['attr_info'][attr_name][0] == 'DataLayout': + assert ( + kernel['layout'] is not None + ), f"{api} api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually." + attr_layout_count = attr_layout_count + 1 + if attrs['attr_info'][attr_name][0] == 'DataType': + assert ( + kernel['data_type'] is not None + ), f"{api} api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually." + attr_data_type_count = attr_data_type_count + 1 + + # preprocess kernel configures + kernel_select_code = self.gene_kernel_backend_select() + + if kernel['layout'] is not None: + if '>' in kernel['layout']: + vars_list = kernel['layout'].split('>') + assert ( + len(vars_list) == 2 + ), f"{api} api: The number of params to set layout with '>' only allows 2, but received {len(vars_list)}." + assert ( + vars_list[0].strip() in attrs['names'] + and attrs['attr_info'][vars_list[0].strip()][0] + == 'DataLayout' + ), f"{api} api: When use '>' to set kernel layout, the first param should be a attribute with DataLayout type." + kernel_select_code = ( + kernel_select_code + + f""" + kernel_layout = ParseLayoutWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); +""" + ) + + else: + vars_list = kernel['layout'].split(',') + assert ( + len(vars_list) == 1 + ), f"{api} api: The number of params to set layout must be 1, but received {len(vars_list)}." + kernel_select_code = ( + kernel_select_code + + f""" + kernel_layout = ParseLayout({vars_list[0].strip()}); +""" + ) + + if kernel['data_type'] is not None: + + def process_data_type_args(args_item): + args_item = args_item.strip() + complex_match_result = re.match( + r"complex\((?P\w+)\)", args_item + ) + if complex_match_result: + return f"phi::dtype::ToComplex(ParseDataType({complex_match_result.group('param_name')}))" + else: + return f"ParseDataType({args_item})" + + if '>' in kernel['data_type']: + vars_list = kernel['data_type'].split('>') + assert ( + len(vars_list) == 2 + ), f"{api} api: The number of params to set data_type with '>' only allows 2, but received {len(vars_list)}." + assert ( + vars_list[0].strip() in attrs['names'] + and attrs['attr_info'][vars_list[0].strip()][0] + == 'DataType' + ), f"{api} api: When use '>' to set kernel data_type, the first param should be a attribute with DataType type." + kernel_select_code = ( + kernel_select_code + + f""" + kernel_data_type = ParseDataTypeWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); +""" + ) + + else: + vars_list = kernel['data_type'].split(',') + assert ( + len(vars_list) == 1 + ), f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}." + kernel_select_code = ( + kernel_select_code + + f""" + kernel_data_type = {process_data_type_args(vars_list[0])}; +""" + ) + + if len(input_names) == 0: + assert ( + attr_backend_count > 0 and attr_data_type_count > 0 + ), f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'." + + kernel_select_args = "" + for input_name in input_names: + kernel_select_args = kernel_select_args + input_name + ", " + + if len(kernel_select_args) > 2: + kernel_select_args = kernel_select_args[:-2] + + # kernel_select_code = kernel_key_item_init + kernel_select_code + + if len(input_names) > 0: + kernel_select_code = ( + kernel_select_code + + f""" + if (kernel_backend == Backend::UNDEFINED + || kernel_layout == DataLayout::UNDEFINED + || kernel_data_type == DataType::UNDEFINED ) {{ + auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args}); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + if (kernel_backend == Backend::UNDEFINED) {{ + kernel_backend = kernel_key.backend(); + }} + if (kernel_layout == DataLayout::UNDEFINED) {{ + kernel_layout = kernel_key.layout(); + }} + if (kernel_data_type == DataType::UNDEFINED) {{ + kernel_data_type = kernel_key.dtype(); + }} + }}""" + ) + input_args = "" for input_name in self.inputs['names']: input_args = input_args + input_name + ", " if len(input_args) > 2: input_args = input_args[:-2] - return AUTO_PARALLEL_COND_TEMPLATE.format(input_args) + mesh = self.generate_non_computation_rank_clip_code() + + if_condition_code = AUTO_PARALLEL_COND_TEMPLATE.format( + input_args=input_args, mesh=mesh, kernel_code=kernel_select_code + ) + + return kernel_key_item_init + if_condition_code def generate_specialized_infer_spmd_code(self) -> str: input_names = self.inputs['names'] @@ -423,9 +710,18 @@ def generate_specialized_infer_spmd_code(self) -> str: if param in input_names: if self.inputs['input_info'][param] == "const Tensor&": input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format( - param, param + name=param ) input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const std::vector&" + ): + input_decl_code += VECTOR_DIST_META_IN_TEMPLATE.format( + name=param + ) + input_args_code += "meta_dist_input_" + param + ", " + else: raise ValueError( f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." @@ -464,22 +760,33 @@ def generate_general_infer_spmd_code(self) -> str: if param in input_names: if self.inputs['input_info'][param] == "const Tensor&": input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format( - param, param + name=param ) input_args_code += "meta_dist_input_" + param + ", " elif ( self.inputs['input_info'][param] - == "const std::vector&" - or self.inputs['input_info'][param] == "const paddle::optional&" - or self.inputs['input_info'][param] + ): + input_decl_code += ( + OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE.format(name=param) + ) + input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const std::vector&" + ): + input_decl_code += VECTOR_DIST_META_IN_TEMPLATE.format( + name=param + ) + input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] == "const paddle::optional>&" ): - # TODO(chenweihang): support other input type later, - # now only support single tensor input api - input_decl_code = "" - input_args_code = "" - break + input_decl_code += ( + OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE.format(name=param) + ) + input_args_code += "meta_dist_input_" + param + ", " else: raise ValueError( f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." @@ -510,6 +817,7 @@ def generate_output_creation_code(self) -> str: output_num = len(self.outputs['types']) return_type = self.get_return_type_with_intermediate(self.inplace_flag) output_creation_code = "" + output_creation_code += "\n phi::DeviceContext* dev_ctx = nullptr;" if output_num == 1: # api output generate if self.need_to_generate_code_for_inplace_impl(0): @@ -679,17 +987,12 @@ def generate_infer_global_shape_code(self) -> str: # 3. get meta tensor output args output_decl_code = "" output_args_code = "" - set_out_dist_attr_code = "" for i, out_name in enumerate(self.dist_output_args): if self.outputs['types'][i] == 'std::vector': output_decl_code += VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE.format( name=out_name ) output_args_code += f"{out_name}_meta_ptr_vec, " - if self.generate_general_infer_spmd is True: - set_out_dist_attr_code += ( - SET_VECTOR_OUT_REPLICATED_DIST_ATTR - ) else: output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format( out_name, out_name @@ -700,10 +1003,6 @@ def generate_infer_global_shape_code(self) -> str: output_args_code += ( f"{out_name} ? &meta_{out_name} : nullptr, " ) - if self.generate_general_infer_spmd is True: - set_out_dist_attr_code += ( - SET_SINGLE_OUT_REPLICATED_DIST_ATTR.format(out_name) - ) output_args_code = output_args_code[:-2] return ( @@ -712,7 +1011,6 @@ def generate_infer_global_shape_code(self) -> str: + INFER_GLOBAL_SHAPE_TEMPLATE.format( infer_meta_func_code, input_args_code, output_args_code ) - + set_out_dist_attr_code ) def generate_kernel_selection_code(self) -> str: @@ -733,19 +1031,15 @@ def generate_reshard_input_code(self) -> str: for i, param in enumerate(kernel_params): if param in input_names: - if self.inputs['input_info'][param] == "const Tensor&": - if self.generate_general_infer_spmd is True: - input_reshard_code += ( - SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( - arg=param, idx=i - ) - ) - else: - input_reshard_code += ( - SINGLE_INPUT_RESHARD_TEMPLATE.format( - arg=param, idx=i - ) - ) + if self.inputs['input_info'][param] in [ + "const Tensor&", + "const std::vector&", + "const paddle::optional&", + "const paddle::optional>&", + ]: + input_reshard_code += INPUT_RESHARD_TEMPLATE.format( + name=param, idx=i + ) else: raise ValueError( f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." @@ -774,15 +1068,15 @@ def generate_single_dense_input( if self.generate_infer_spmd is True: input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format( - arg=input_name, + name=input_name, idx=kernel_param.index(input_name), - flag=trans_flag, + trans_flag=trans_flag, ) else: input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format( arg=input_name, idx=kernel_param.index(input_name), - flag=trans_flag, + trans_flag=trans_flag, ) return input_tensor_code @@ -798,10 +1092,9 @@ def generate_vector_dense_input( kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format( name=input_name, - index=kernel_param.index(input_name), + idx=kernel_param.index(input_name), trans_flag=trans_flag, ) @@ -819,11 +1112,20 @@ def generate_optional_single_dense_input( if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format( - name=input_name, - index=kernel_param.index(input_name), - trans_flag=trans_flag, - ) + if self.generate_infer_spmd is True: + input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format( + name=input_name, + idx=kernel_param.index(input_name), + trans_flag=trans_flag, + ) + else: + input_tensor_code += ( + OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format( + name=input_name, + idx=kernel_param.index(input_name), + trans_flag=trans_flag, + ) + ) return input_tensor_code @@ -838,10 +1140,9 @@ def generate_optional_vector_dense_input( kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE.format( name=input_name, - index=kernel_param.index(input_name), + idx=kernel_param.index(input_name), trans_flag=trans_flag, ) @@ -1075,7 +1376,9 @@ def generate_reshard_partial_out_to_replicated_code(self) -> str: for i, out_type in enumerate(self.outputs['types']): if out_type == 'Tensor': reshard_p2r_code += ( - RESHARD_P2R_MULTI_SINGLE_OUTPUT_TEMPLATE.format(i) + RESHARD_P2R_MULTI_SINGLE_OUTPUT_TEMPLATE.format( + idx=i + ) ) else: self.vector_output_size_assertion_check() @@ -1090,6 +1393,29 @@ def generate_reshard_partial_out_to_replicated_code(self) -> str: return reshard_p2r_code + def generate_output_dist_attr_setting(self) -> str: + set_out_dist_attr_code = "" + if self.generate_general_infer_spmd is True: + set_out_dist_attr_code += CURRENT_PROCESS_MESH_TEMPLATE + for i, out_name in enumerate(self.dist_output_args): + if self.outputs['types'][i] == 'std::vector': + set_out_dist_attr_code += ( + SET_VECTOR_OUT_REPLICATED_DIST_ATTR_TEMPLATE.format( + name=out_name + ) + ) + else: + set_out_dist_attr_code += ( + SET_SINGLE_OUT_REPLICATED_DIST_ATTR_TEMPLATE.format( + out_name + ) + ) + else: + set_out_dist_attr_code = ( + NONEED_TO_SET_DIST_ATTR_COMMENT_TEMPLATE.format(self.api) + ) + return set_out_dist_attr_code + def generate_return_code(self) -> str: return self.gene_return_code() @@ -1098,7 +1424,6 @@ def generate_auto_paralel_branch(self) -> str: if len(self.inputs['names']) == 0: return "" return MAIN_DIST_BRANCH_TEMPLATE.format( - self.generate_if_condition_code(), self.generate_infer_spmd_code(), self.generate_output_creation_code(), self.generate_infer_global_shape_code(), @@ -1108,6 +1433,7 @@ def generate_auto_paralel_branch(self) -> str: self.generate_infer_meta_code(), self.generate_kernel_call_code(), self.generate_reshard_partial_out_to_replicated_code(), + self.generate_output_dist_attr_setting(), self.generate_return_code(), ) diff --git a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py index b29e186f06d38..04b3af4ec48cf 100644 --- a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py @@ -24,17 +24,20 @@ MAIN_DIST_BRANCH_TEMPLATE = """ // Auto Parallel condition - if ({}) {{ + if (run_auto_parallel) {{ // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){} // 2. Create Temporary Output & Prepare Dist and Dense Output{} // 3. Infer DistTensor's Global Shape{}\n - // 4. Select Kernel{} - // 5. Reshard Input{}\n - // 6. PrepareData (DataTransform & Prepare Dense Input){} - // 7. Infer Local DenseTensor Meta{} - // 8. DenseTensor Kernel Call{} - // 9. Reshard Output{}\n - // 10. Return + // 4. Set Output Dist Attr For Default Impl{}\n + if (rank_is_in_current_mesh){{ + // 5. Select Kernel{} + // 6. Reshard Input{}\n + // 7. PrepareData (DataTransform & Prepare Dense Input){} + // 8. Infer Local DenseTensor Meta{} + // 9. DenseTensor Kernel Call{} + }} + // 10. Reshard Partial Output to Replicated (Temporary){}\n + // 11. Return {} }} """ @@ -46,28 +49,36 @@ """ SINGLE_OUT_CREATION_TEMPLATE_WITH_SPMD = """ std::shared_ptr shared_dist_out = - CreateKernelDistOutput({}, spmd_info.second[0]); + CreateKernelDistOutput({}, !rank_is_in_current_mesh, spmd_info.second[0]); phi::distributed::DistTensor* dist_out = shared_dist_out.get(); phi::DenseTensor* dense_out = dist_out->unsafe_mutable_value(); + if (dense_out && !rank_is_in_current_mesh && !dist_out->defined()) {{ + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ SINGLE_OUT_CREATION_TEMPLATE = """ std::shared_ptr shared_dist_out = - CreateKernelDistOutput({}); + CreateKernelDistOutput({}, !rank_is_in_current_mesh); phi::distributed::DistTensor* dist_out = shared_dist_out.get(); phi::DenseTensor* dense_out = dist_out->unsafe_mutable_value(); -""" -VECTOR_OUT_CREATION_TEMPLATE = """ - auto dist_out = SetKernelDistOutput({name}); - std::vector dense_out(dist_out.size()); - for (size_t i=0; i(&dist_out[i]->value()); + if (dense_out && !rank_is_in_current_mesh && !dist_out->defined()) {{ + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); }} """ VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput({name}); std::vector dense_out(dist_out.size()); - for (size_t i = 0; i < dist_out.size(); i++) {{ - dense_out[i] = const_cast(&dist_out[i]->value()); + for (size_t i=0; iunsafe_mutable_value(); + if (dense_out[i] && !rank_is_in_current_mesh && !dist_out[i]->defined()) {{ + *dense_out[i] = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} }} """ INPLACE_OUT_CREATION_TEMPLATE = """ @@ -75,33 +86,53 @@ """ MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ auto dist_out_{idx} = SetKernelDistOutput({name}); - auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); + auto dense_out_{idx} = dist_out_{idx} ? dist_out_{idx}->unsafe_mutable_value() : nullptr; + if (dense_out_{idx} && !rank_is_in_current_mesh && dist_out_{idx}->defined()) {{ + *dense_out_{idx} = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ MULTI_SINGLE_OUT_CREATION_TEMPLATE_WITH_SPMD = """ std::shared_ptr shared_dist_out_{idx} = - CreateKernelDistOutput({name}, spmd_info.second[{idx}]); + CreateKernelDistOutput({name}, !rank_is_in_current_mesh, spmd_info.second[{idx}]); phi::distributed::DistTensor* dist_out_{idx} = shared_dist_out_{idx}.get(); phi::DenseTensor* dense_out_{idx} = dist_out_{idx} ? dist_out_{idx}->unsafe_mutable_value() : nullptr; + if (dense_out_{idx} && !rank_is_in_current_mesh && dist_out_{idx}->defined()) {{ + *dense_out_{idx} = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ std::shared_ptr shared_dist_out_{idx} = - CreateKernelDistOutput({name}); + CreateKernelDistOutput({name}, !rank_is_in_current_mesh); phi::distributed::DistTensor* dist_out_{idx} = shared_dist_out_{idx}.get(); phi::DenseTensor* dense_out_{idx} = dist_out_{idx} ? dist_out_{idx}->unsafe_mutable_value() : nullptr; + if (dense_out_{idx} && !rank_is_in_current_mesh && !dist_out_{idx}->defined()) {{ + *dense_out_{idx} = phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} """ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out_{i} = SetKernelDistOutput({name}); std::vector dense_out_{i}(dist_out_{i}.size()); for (size_t i = 0; i < dist_out_{i}.size(); i++) {{ - dense_out_{i}[i] = const_cast(&dist_out_{i}[i]->value()); + dense_out_{i}[i] = const_cast(&dist_out_{i}[i]->value()); + if (dense_out_{i}[i] && !rank_is_in_current_mesh && !dist_out_{i}[i]->defined()) {{ + *dense_out_{i}[i]= phi::DenseTensor( + std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + }} }} """ # 9. Reshard Output RESHARD_SINGLE_OUTPUT_TEMPLATE = """ - ReshardKernelOutputToApiOutput(dev_ctx, shared_dist_out, {});""" + ReshardKernelOutputToApiOutput(dev_ctx, shared_dist_out, {});""" RESHARD_MULTI_SINGLE_OUTPUT_TEMPLATE = """ - ReshardKernelOutputToApiOutput(dev_ctx, shared_dist_out_{}, {});""" + ReshardKernelOutputToApiOutput(dev_ctx, shared_dist_out_{}, {});""" class DistBackwardAPI(DistForwardAPI, BackwardAPI): @@ -114,6 +145,7 @@ def generate_output_creation_code(self) -> str: # backward api only need to generate kernel outputs output_num = len(self.outputs['types']) output_creation_code = "" + output_creation_code += "\n phi::DeviceContext* dev_ctx = nullptr;" if output_num == 1: self.dist_output_args.append('dist_out') self.dense_output_args.append('dense_out') @@ -260,10 +292,10 @@ def generate_auto_paralel_branch(self) -> str: if len(self.inputs['names']) == 0: return "" return MAIN_DIST_BRANCH_TEMPLATE.format( - self.generate_if_condition_code(), self.generate_infer_spmd_code(), self.generate_output_creation_code(), self.generate_infer_global_shape_code(), + self.generate_output_dist_attr_setting(), self.generate_kernel_selection_code(), self.generate_reshard_input_code(), self.generate_prepare_data_code(), @@ -308,6 +340,7 @@ def source_include(header_file_path, fw_header_file_path): #ifdef PADDLE_WITH_DISTRIBUTE #include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #endif PD_DECLARE_bool(conv2d_disable_cudnn); diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 73e508434697c..e5e3f9fb86c53 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -86,7 +86,7 @@ kernel : func : batch_norm_double_grad data_type : x - optional : out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad + optional : scale, out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad inplace : (grad_out -> grad_out_grad) - backward_op : batch_norm_grad @@ -99,7 +99,7 @@ kernel : func : batch_norm_grad data_type : out_grad - optional : mean_out, variance_out, reserve_space + optional : scale, bias, mean_out, variance_out, reserve_space composite: batch_norm_grad(x, scale, bias, mean_out, variance_out, saved_mean, saved_variance, reserve_space, out_grad, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics) backward : batch_norm_double_grad @@ -193,6 +193,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param : [x, y] + spmd_rule : ElementwiseBinaryGradInferSpmd kernel : func : divide_grad composite : divide_grad(x, y, out, out_grad, axis, x_grad, y_grad) @@ -226,6 +227,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param: [x, y] + spmd_rule : ElementwiseBinaryGradInferSpmd composite : elementwise_pow_grad(x, y, out_grad, x_grad, y_grad) kernel : func : elementwise_pow_grad @@ -246,8 +248,8 @@ invoke : zeros_like(out_grad) - backward_op : frobenius_norm_grad - forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all) + forward : frobenius_norm(Tensor x, IntArray axis, bool keep_dim, bool reduce_all) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis, bool keep_dim, bool reduce_all) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -363,6 +365,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param: [x, y] + spmd_rule: ElementwiseBinaryGradInferSpmd kernel : func : maximum_grad composite : maximum_grad(x, y, out_grad, axis, x_grad, y_grad) @@ -380,6 +383,7 @@ infer_meta : func : UnchangedInferMeta param: [x] + spmd_rule : ReductionGradInferSpmd kernel : func : mean_grad backward : mean_double_grad @@ -407,8 +411,8 @@ composite : minimum_grad(x, y, out_grad, axis, x_grad, y_grad) - backward_op : mish_grad - forward : mish (Tensor x, float threshold) -> Tensor(out) - args : (Tensor x, Tensor out_grad, float threshold) + forward : mish (Tensor x, float lambda) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float lambda) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -438,6 +442,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param : [x, y] + spmd_rule : ElementwiseBinaryGradInferSpmd kernel : func : multiply_grad composite: multiply_grad(x, y, out_grad, axis, x_grad, y_grad) @@ -702,6 +707,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ReductionGradInferSpmd kernel : func : sum_grad composite : sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 01acb338c987b..060ab96f28694 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -67,10 +67,10 @@ args : (Tensor start, Tensor end, Tensor step, DataType dtype, Place place={}) output : Tensor(out) infer_meta : - func : ArangeInferMeta + func : ArangeTensorInferMeta param : [start, end, step] kernel : - func : arange + func : arange_tensor param : [start, end, step] data_type : dtype backend : place @@ -122,7 +122,7 @@ data_type : x view : (mean -> mean_out), (variance -> variance_out) backward : batch_norm_grad - optional : reserve_space + optional : scale, bias, reserve_space - op : c_allgather args : (Tensor x, int ring_id, int nranks, bool use_calc_stream) @@ -214,7 +214,7 @@ inplace : (x -> out) - op : c_sync_comm_stream - args : (Tensor x) + args : (Tensor x, int ring_id) output : Tensor(out) infer_meta : func : UnchangedInferMeta @@ -317,6 +317,7 @@ output : Tensor(out) infer_meta : func : ElementwiseInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : divide inplace: (x -> out) @@ -348,6 +349,7 @@ output : Tensor(out) infer_meta : func : ElementwiseInferMeta + spmd_rule: ElementwiseBinaryInferSpmd kernel : func : elementwise_pow backward : elementwise_pow_grad @@ -415,6 +417,7 @@ output : Tensor(out) infer_meta : func : CompareInferMeta + spmd_rule: ElementwiseBinaryInferSpmd kernel : func : equal inplace: (x -> out) @@ -452,10 +455,10 @@ inplace: (x -> out) - op : frobenius_norm - args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) + args : (Tensor x, IntArray axis, bool keep_dim, bool reduce_all) output : Tensor(out) infer_meta : - func : ReduceInferMetaBase + func : ReduceIntArrayAxisInferMetaBase kernel : func : frobenius_norm backward : frobenius_norm_grad @@ -717,6 +720,7 @@ output : Tensor(out) infer_meta : func : ElementwiseInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : maximum backward : maximum_grad @@ -726,6 +730,7 @@ output : Tensor(out) infer_meta : func : ReduceIntArrayAxisInferMeta + spmd_rule : ReductionMeanInferSpmdDynamic kernel : func : mean backward : mean_grad @@ -781,6 +786,7 @@ output : Tensor infer_meta : func : ElementwiseInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : multiply {dense, dense -> dense}, multiply_sr {selected_rows, dense -> selected_rows} @@ -801,6 +807,7 @@ output : Tensor(out) infer_meta : func : CompareInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : not_equal inplace: (x -> out) @@ -903,6 +910,7 @@ func : RepeatInterleaveInferMeta kernel : func : repeat_interleave + data_type : x backward: repeat_interleave_grad - op : repeat_interleave_with_tensor_index @@ -1015,6 +1023,7 @@ output : Tensor(out) infer_meta : func : SumInferMeta + spmd_rule : ReductionSumInferSpmdDynamic kernel : func : sum data_type : x diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index d7d52a0041d8a..c5900d451d191 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -987,8 +987,8 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] -- op : exponential_ - backward : exponential__grad +- op : exponential_ (exponential) + backward : exponential__grad (exponential_grad) inputs : x : X outputs : @@ -1263,8 +1263,6 @@ scale : Scale outputs : out : Out - attrs : - epsilon : epsilon - op : fused_fc_elementwise_layernorm inputs : @@ -1278,11 +1276,6 @@ out : Out mean : Mean variance : Variance - attrs : - x_num_col_dims : x_num_col_dims - activation_type : activation_type - epsilon : epsilon - begin_norm_axis : begin_norm_axis - op : fused_feedforward backward: fused_feedforward_grad @@ -1316,19 +1309,67 @@ dropout1_out: Dropout1Out dropout2_out: Dropout2Out +- op : fused_gemm_epilogue + inputs: + {x : X, y : Y, bias : Bias} + outputs : + {out : Out, reserve_space: ReserveSpace} + +- op : fused_gemm_epilogue_grad + inputs: + {x : X, y : Y, reserve_space: ReserveSpace, out_grad : DOut} + outputs : + {x_grad : DX, y_grad : DY, bias_grad : DBias} + - op : fused_transpose extra : attrs : [str data_format = "AnyLayout"] -- op : fusion_transpose_flatten_concat +- op : fusion_gru inputs : x : X + h0 : H0 + weight_x : WeightX + weight_h : WeightH + bias : Bias + outputs : + reordered_h0 : ReorderedH0 + xx : XX + batched_input : BatchedInput + batched_out : BatchedOut + hidden : Hidden + attrs : + scale_data : Scale_data + shift_data : Shift_data + scale_weights : Scale_weights + +- op : fusion_seqconv_eltadd_relu + inputs : + x : X + filter : Filter + bias : Bias outputs : out : Out + col_mat : ColMat attrs : - trans_axis : trans_axis - flatten_axis : flatten_axis - concat_axis : concat_axis + context_length : contextLength + context_start : contextStart + context_stride : contextStride + +- op : fusion_seqexpand_concat_fc + inputs : + x : X + fc_weight : FCWeight + fc_bias : FCBias + outputs : + out : Out + fc_out : FCOut + +- op : fusion_transpose_flatten_concat + inputs : + x : X + outputs : + out : Out - op : gather backward : gather_grad @@ -1839,6 +1880,10 @@ - op : matmul_with_flatten (mul) backward : matmul_with_flatten_grad (mul_grad) + inputs : + {x : X, y : Y} + outputs : + out : Out extra : attrs : [bool use_mkldnn = false, float scale_x = 1.0f, 'float[] scale_y = {1.0f}', float scale_out = 1.0f, bool force_fp32_output = false] @@ -2428,6 +2473,24 @@ attrs : repeats : Repeats +- op : repeat_interleave + backward : repeat_interleave_grad + inputs : + x : X + outputs : + out : Out + attrs : + {repeats : Repeats, axis : dim} + +- op : repeat_interleave_with_tensor_index + backward : repeat_interleave_with_tensor_index_grad + inputs : + {x : X, repeats: RepeatTensor} + outputs: + out : Out + attrs: + axis : dim + - op : reshape (reshape2) backward : reshape_grad (reshape2_grad) inputs: @@ -3263,6 +3326,12 @@ attrs: pivot : pivots +- op: memcpy + inputs: + x: X + outputs: + out: Out + - op: memcpy_d2h inputs : x : X @@ -3336,6 +3405,16 @@ outputs : out : Out +- op: sparse_momentum + inputs : + {param: Param, grad: Grad, velocity: Velocity, index: Index, axis: Axis, learning_rate: LearningRate,master_param: MasterParam} + outputs : + {param_out: ParamOut, velocity_out: VelocityOut, master_param_out: MasterParamOut} + scalar: + axis: + datatype : int + tensor_name : Axis + - op: squared_l2_norm backward: squared_l2_norm_grad inputs : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index b42a6c99a2ce7..5bf57114402ee 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -327,6 +327,7 @@ output : Tensor(out) infer_meta : func : ElementwiseInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : bitwise_and backend : x @@ -337,6 +338,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : bitwise_not backend : x @@ -503,6 +505,7 @@ infer_meta : func : ConcatInferMeta param : [x, axis] + spmd_rule : ConcatInferSpmdDynamic kernel : func : concat data_type : x @@ -550,6 +553,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : cos inplace: (x -> out) @@ -598,21 +602,23 @@ backward : cross_entropy_with_softmax_grad - op : cummax - args : (Tensor x, int axis=-1, int dtype=3) + args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64) output : Tensor(out), Tensor(indices) infer_meta : func : CumWithIndicesInferMeta kernel : func : cummax + data_type : x backward : cummax_grad - op : cummin - args : (Tensor x, int axis=-1, int dtype=3) + args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64) output : Tensor(out), Tensor(indices) infer_meta : func : CumWithIndicesInferMeta kernel : func : cummin + data_type : x backward : cummin_grad - op : cumprod @@ -819,6 +825,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : exp inplace : (x -> out) @@ -1012,10 +1019,10 @@ backward : frame_grad - op : full_int_array - args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) + args : (int64_t[] value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) output: Tensor(out) infer_meta : - func : CreateIntArrayInferMeta + func : CreateVecShapeInferMeta param : [value, dtype] kernel : func : full_int_array @@ -1035,13 +1042,15 @@ - op : gather_nd args : (Tensor x, Tensor index) - output : Tensor + output : Tensor(out) infer_meta : func : GatherNdInferMeta kernel : func : gather_nd data_type : x backward : gather_nd_grad + data_transform : + skip_transform : index - op : gather_tree args : (Tensor ids, Tensor parents) @@ -1241,6 +1250,8 @@ data_type : x inplace : (x -> out) backward : index_put_grad + data_transform : + skip_transform : indices - op : index_sample args : (Tensor x, Tensor index) @@ -1306,7 +1317,7 @@ func : is_empty - op : isclose - args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false) + args : (Tensor x, Tensor y, Scalar(double) rtol=1e-5, Scalar(double) atol=1e-8, bool equal_nan=false) output : Tensor(out) infer_meta : func : ValueCompareInferMeta @@ -2073,6 +2084,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : relu inplace : (x -> out) @@ -2179,6 +2191,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : rsqrt inplace : (x -> out) @@ -2359,6 +2372,7 @@ output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : silu backward : silu_grad @@ -2368,6 +2382,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : sin inplace: (x -> out) @@ -2458,6 +2473,7 @@ output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : square {dense -> dense}, square_sr {selected_rows -> selected_rows} @@ -2667,7 +2683,7 @@ backward: uniform_inplace_grad - op : unique_consecutive - args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, int dtype = 5) + args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, DataType dtype = DataType::FLOAT32) output : Tensor(out), Tensor(index), Tensor(counts) infer_meta : func : UniqueConsecutiveInferMeta diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 9f8def740385b..f48cd394bac35 100755 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -72,9 +72,9 @@ args : (Tensor start, Tensor end, Tensor step) output : Tensor(out) infer_meta : - func : ArangeInferMeta + func : ArangeTensorInferMeta kernel : - func : arange + func : arange_tensor data_transform : skip_transform : start, end, step @@ -256,7 +256,7 @@ args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) output : Tensor(out) infer_meta : - func : ReduceInferMetaBase + func : ReduceIntArrayAxisInferMetaBase kernel : func : frobenius_norm param : [x, axis, keepdim, reduce_all] @@ -500,6 +500,15 @@ data_type : x backward : prod_grad +- op : quant_linear + args: (Tensor x, Tensor w, Tensor bias, int in_num_col_dims = 1, str activation_type = "", bool padding_weights = false, float scale_in = 1.0f, float[] scale_weights = {1.0f}, int quant_round_type = 1, float quant_max_bound = 127.0f, float quant_min_bound = -127.0f) + output: Tensor(out) + optional: bias + infer_meta: + func: QuantLinearInferMeta + kernel: + func: quant_linear + - op : randint args : (int low, int high, IntArray shape = {}, DataType dtype = DataType::INT64, int seed = 0) output : Tensor(out) diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 748c80c0859c5..1e57fb736b7c2 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -183,7 +183,7 @@ void Device::BlasAXPBY(const stream::Stream& stream, phi::CppTypeToDataType::Type(), numel, alpha, - reinterpret_cast(const_cast(x)), + reinterpret_cast(const_cast(x)), // NOLINT beta, reinterpret_cast(y)); } diff --git a/paddle/phi/backends/dynload/cuda_driver.cc b/paddle/phi/backends/dynload/cuda_driver.cc index 2bd0a7bfea5c1..d9fd89a0c65a6 100644 --- a/paddle/phi/backends/dynload/cuda_driver.cc +++ b/paddle/phi/backends/dynload/cuda_driver.cc @@ -24,6 +24,7 @@ void* cuda_dso_handle = nullptr; #if CUDA_VERSION >= 10020 CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); +CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP); #endif CUDA_ROUTINE_EACH(DEFINE_WRAP); diff --git a/paddle/phi/backends/dynload/cuda_driver.h b/paddle/phi/backends/dynload/cuda_driver.h index f743a33a1866f..ba771afe09023 100644 --- a/paddle/phi/backends/dynload/cuda_driver.h +++ b/paddle/phi/backends/dynload/cuda_driver.h @@ -72,7 +72,13 @@ extern bool HasCUDADriver(); __macro(cuMemRelease); \ __macro(cuMemAddressFree) +#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \ + __macro(cuGraphNodeGetType); \ + __macro(cuGraphKernelNodeGetParams); \ + __macro(cuGraphExecKernelNodeSetParams) + CUDA_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); +CUDA_ROUTINE_EACH_CUDA_GRAPH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); #endif CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); diff --git a/paddle/phi/backends/dynload/nccl.h b/paddle/phi/backends/dynload/nccl.h index 6c73c562caa69..91b6f5dcd58dc 100644 --- a/paddle/phi/backends/dynload/nccl.h +++ b/paddle/phi/backends/dynload/nccl.h @@ -44,6 +44,7 @@ extern void* nccl_dso_handle; __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ __macro(ncclCommCuDevice); \ @@ -55,6 +56,7 @@ extern void* nccl_dso_handle; __macro(ncclGroupEnd); \ __macro(ncclReduce); \ __macro(ncclReduceScatter); \ + __macro(ncclCommGetAsyncError); \ __macro(ncclGetErrorString); NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) diff --git a/paddle/phi/backends/dynload/rccl.h b/paddle/phi/backends/dynload/rccl.h index 9232d387d2d19..e1018a3f253fa 100644 --- a/paddle/phi/backends/dynload/rccl.h +++ b/paddle/phi/backends/dynload/rccl.h @@ -44,6 +44,7 @@ extern void* rccl_dso_handle; __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ __macro(ncclCommCuDevice); \ @@ -55,6 +56,7 @@ extern void* rccl_dso_handle; __macro(ncclGroupEnd); \ __macro(ncclReduce); \ __macro(ncclReduceScatter); \ + __macro(ncclCommGetAsyncError); \ __macro(ncclGetErrorString); RCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc index 9268b85f29679..479a88f8ae1ff 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -19,6 +19,13 @@ #include #include +#if CUDA_VERSION < 11000 +cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr, + const void *symbolPtr) { + return cudaSuccess; +} +#endif + namespace phi { namespace backends { namespace gpu { @@ -204,46 +211,8 @@ void CUDAGraph::EndSegmentCapture() { return; } - auto sorted_nodes = ToposortCUDAGraph(graph); - capturing_graph_->pre_hooks_.emplace_back(); - std::unordered_set visited; - VLOG(10) << "SetSeedFunc number : " - << capturing_graph_->set_seed_funcs_.size(); - for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) { - bool found = false; - for (auto node : sorted_nodes) { - if (visited.count(node) > 0) continue; - cudaGraphNodeType type; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type)); - if (type == cudaGraphNodeTypeKernel) { - cudaKernelNodeParams params; - auto err = cudaGraphKernelNodeGetParams(node, ¶ms); - if (err == cudaErrorInvalidDeviceFunction) { - continue; - } else { - PADDLE_ENFORCE_GPU_SUCCESS(err); - } - CUDAKernelParams kernel_params(¶ms); - if (set_seed_func(&kernel_params, true)) { - capturing_graph_->pre_hooks_.back().push_back( - [set_seed_func, node, params](cudaGraphExec_t exec_graph) { - CUDAKernelParams kernel_params(¶ms); - set_seed_func(&kernel_params, false); - PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams( - exec_graph, node, ¶ms)); - }); - visited.insert(node); - found = true; - break; - } - } - } - PADDLE_ENFORCE_EQ(found, - true, - phi::errors::InvalidArgument( - "Cannot find the corresponding random CUDA kernel.")); - } - capturing_graph_->set_seed_funcs_.clear(); + capturing_graph_->pre_hooks_.emplace_back( + CUDAGraphNodeLauncher::Instance().GetParameterSettersForExecGraph(graph)); cudaGraphExec_t exec_graph; PADDLE_ENFORCE_GPU_SUCCESS( @@ -308,6 +277,82 @@ void CUDAGraph::PrintToDotFiles(const std::string &dirname, #endif } +#if CUDA_VERSION >= 11000 +void CUDAGraphNodeLauncher::KernelNodeLaunch( + cudaFunction_t cudaFunc, + parameterSetter_t parameterSetter, + cudaKernelCallback_t cudakernelCallback) { + if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { + unsigned int id = GenerateIndentifier(); + + parameterSetters[cudaFunc][id] = parameterSetter; + cudakernelCallback(id); + + } else { + cudakernelCallback(0); + } +} + +std::vector +CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { + size_t num_nodes; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes)); + std::vector nodes(num_nodes); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); + + std::vector> hooks; + for (auto node : nodes) { + CUgraphNode cuNode = node; + CUgraphNodeType pType; + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphNodeGetType(cuNode, &pType)); + if (pType == CU_GRAPH_NODE_TYPE_KERNEL) { + CUDA_KERNEL_NODE_PARAMS cuParams; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cuGraphKernelNodeGetParams(cuNode, &cuParams)); + CUDAKernelParams kernel_params(cuParams.kernelParams); + auto kernel = + parameterSetters.find(static_cast(cuParams.func)); + + // There exists a parameter setter + if (kernel != parameterSetters.end()) { + auto launchSequence = kernel->second; + unsigned int id = kernel_params.As(0); + auto parameterSetter = launchSequence.find(id); + if (parameterSetter != launchSequence.end()) { + auto setter = parameterSetter->second; + hooks.emplace_back([setter, cuNode, cuParams]( + cudaGraphExec_t exec_graph) { + CUDAKernelParams kernel_params(cuParams.kernelParams); + setter(kernel_params); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphExecKernelNodeSetParams( + static_cast(exec_graph), cuNode, &cuParams)); + }); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("Error: does not find launch id")); + } + } + } + } + + return hooks; +} +#else +void CUDAGraphNodeLauncher::KernelNodeLaunch( + cudaFunction_t cudaFunc, + parameterSetter_t parameterSetter, + cudaKernelCallback_t cudakernelCallback) { + cudakernelCallback(0); +} + +std::vector +CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { + PADDLE_THROW(phi::errors::Unimplemented( + "CUDAGraphNodeLauncher is only supported when CUDA version >= 11.0")); +} +#endif + } // namespace gpu } // namespace backends } // namespace phi diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index 2f61e031f1128..7b5644128c7cd 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -21,14 +21,13 @@ #include #include #include +#include #include -#include "cuda.h" // NOLINT -#include "cuda_runtime.h" // NOLINT - #include "glog/logging.h" #include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/device_code.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/place.h" @@ -37,6 +36,13 @@ #include "paddle/phi/core/macros.h" #include "paddle/utils/optional.h" +#if CUDA_VERSION < 11000 +// For CUDA versions less than 11.0, use a dummy type for cudaFunction_t. +using cudaFunction_t = void *; +cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr, + const void *symbolPtr); +#endif + namespace phi { namespace backends { namespace gpu { @@ -88,18 +94,91 @@ class CUDAGraphContextManager { class CUDAKernelParams { public: - explicit CUDAKernelParams(const cudaKernelNodeParams *params) - : params_(params) {} - - const void *func() const { return params_->func; } + explicit CUDAKernelParams(void **params) : kernelParams(params) {} template T &As(size_t idx) const { - return *reinterpret_cast(params_->kernelParams[idx]); + return *reinterpret_cast(kernelParams[idx]); } + void **getParams() const { return kernelParams; } + private: - const cudaKernelNodeParams *params_; + void **kernelParams; +}; + +using cudaGraphExecuterSetter_t = std::function; + +// ** class CUDAGraphNodeLauncher +// +// This class offers a interface for launching CUDA kernels in CUDA Graph, we +// utilize the `cudaGraphExecKernelNodeSetParams` function for parameter setup. +// Launching kernels via this class ensures proper management. +// +// NOTE: It's essential that the first parameter for any kernel launched +// through this class is an `unsigned int` identifier. This identifier plays a +// crucial role in linking the CUDA kernel to its corresponding CUDA graph +// node. We tag each kernel launch with a unique identifier to maintain +// structured linkage with its CUDA graph node. +// +// NOTE: This class use a singleton design pattern ensures there's only a +// single global instance accessible via the `Instance()` method. +class CUDAGraphNodeLauncher { + public: + // [Parameter Setter Callback] + // Sets the kernel's parameters BEFORE activating the CUDA graph. It enables + // dynamic determination and setup of kernel arguments. + // + // parameterSetter_t parameterSetter = [saved_state](CUDAKernelParams + // ¶m){ + // // Code to compute and the parameter values from the saved_state + // // ... + // param.As(idx) = calculated_value; + // }; + using parameterSetter_t = std::function; + + // [CUDA Kernel Callback] + // Acts as the launcher for the kernel. It accepts an `unsigned int` + // identifier and uses it for the kernel launch. + // + // cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) { + // kernel<<<>>>(id, ...); // Launching the kernel with id + // }; + using cudaKernelCallback_t = std::function; + + // [Retrieving CUDA Function] + // The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t` + // reference of the kernel from the kernel pointer. + // + // cudaFunction_t cudaFunc; + // PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel)); + // + // [Kernel Launch] + // With the callbacks defined and the CUDA function obtained, the kernel can + // be launched using the `KernelNodeLaunch` method. + void KernelNodeLaunch(cudaFunction_t cudaFunc, + parameterSetter_t parameterSetter, + cudaKernelCallback_t cudakernelCallback); + + std::vector GetParameterSettersForExecGraph( + cudaGraph_t graph); + + parameterSetter_t GetParameterSetter(const CUDAKernelParams ¶ms); + + static CUDAGraphNodeLauncher &Instance() { + static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher; + return *launcher; + } + + private: + CUDAGraphNodeLauncher() : id(0) {} + DISABLE_COPY_AND_ASSIGN(CUDAGraphNodeLauncher); + + unsigned int GenerateIndentifier() { return id++; } + + unsigned int id; + std::unordered_map> + parameterSetters; }; #if CUDA_VERSION >= 10010 @@ -244,7 +323,9 @@ class CUDAGraph { std::mutex mtx_; std::vector set_seed_funcs_; - std::vector>> pre_hooks_; + // we collect all callbacks as a sequence of 'prehooks', i.e. these functions + // are called prior to the execution of the cudagraph. + std::vector> pre_hooks_; std::mutex func_mtx_; bool is_first_run_{true}; @@ -288,54 +369,6 @@ class CUDAGraphCaptureModeGuard { }; #endif -template -static bool IsBitwiseEqual(const T &x, const T &y) { - return std::memcmp(&x, &y, sizeof(T)) == 0; -} - -template -struct IsSameKernelHelper; - -template -struct IsSameKernelHelper { - private: - using FuncArgsTuple = decltype(std::make_tuple(std::declval()...)); - - template - struct Impl { - static bool Compare(const CUDAKernelParams ¶ms, const TupleT &args) { - using CompareT = typename std::tuple_element::type; - if (!IsBitwiseEqual(params.As(IDX), - std::get(args))) { - return false; - } - - constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size::value); - return Impl::Compare(params, args); - } - }; - - template - struct Impl { - static bool Compare(const CUDAKernelParams ¶ms, const TupleT &args) { - return true; - } - }; - - public: - template - static bool Compare(const CUDAKernelParams ¶ms, Args... args) { - constexpr auto kNumArgs = sizeof...(FuncArgs); - static_assert(kNumArgs == sizeof...(Args), "Argument number not match"); - - auto args_tuple = std::make_tuple(args...); - using TupleT = typename std::decay::type; - return Impl::Compare(params, args_tuple); - } -}; - } // namespace gpu } // namespace backends } // namespace phi diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h index a7f863729288a..de5c1503f35db 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h @@ -27,61 +27,6 @@ namespace phi { namespace backends { namespace gpu { -#ifdef PADDLE_WITH_CUDA -#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \ - __kernel_func, \ - __grid, \ - __block, \ - __sm_size, \ - __stream, \ - __seed_inc, \ - __seed_expr, \ - __offset_expr, \ - ...) \ - do { \ - if (::phi::backends::gpu::CUDAGraph::IsThisThreadCapturing() && \ - (__cond)) { \ - using __Helper = \ - ::phi::backends::gpu::IsSameKernelHelper; \ - auto *dev_ctx = ::phi::DeviceContextPool::Instance().GetByPlace( \ - ::phi::backends::gpu::CUDAGraph::CapturingPlace()); \ - auto __set_seed_func = \ - [=](::phi::backends::gpu::CUDAKernelParams *__params, \ - bool __check_only) -> bool { \ - if (__check_only) { \ - return __params->func() == &__kernel_func && \ - __Helper::Compare(*__params, __VA_ARGS__); \ - } \ - auto &KERNEL_PARAMS = *__params; \ - uint64_t __seed, __offset; \ - ::phi::funcs::GetSeedDataAndIncrement( \ - *dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset); \ - __seed_expr = static_cast(__seed); \ - __offset_expr = static_cast(__offset); \ - return true; \ - }; \ - ::phi::backends::gpu::CUDAGraph::RecordRandomKernelInfo( \ - __set_seed_func); \ - } \ - __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \ - } while (0) -#else -#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \ - __kernel_func, \ - __grid, \ - __block, \ - __sm_size, \ - __stream, \ - __seed_inc, \ - __seed_expr, \ - __offset_expr, \ - ...) \ - do { \ - __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \ - } while (0) -#endif - inline bool IsCUDAGraphCapturing() { #ifdef PADDLE_WITH_CUDA return CUDAGraph::IsCapturing(); diff --git a/paddle/phi/backends/gpu/cuda/cudnn_helper.h b/paddle/phi/backends/gpu/cuda/cudnn_helper.h index 651a4247a12df..74db3fc75bcd1 100644 --- a/paddle/phi/backends/gpu/cuda/cudnn_helper.h +++ b/paddle/phi/backends/gpu/cuda/cudnn_helper.h @@ -33,8 +33,12 @@ namespace phi { namespace backends { namespace gpu { +#define CUDNN_VERSION_COMPUTE(major, minor, patch) \ + ((major) <= 8 ? (major)*1000 + (minor)*100 + (patch) \ + : (major)*10000 + (minor)*100 + (patch)) + #define CUDNN_VERSION_MIN(major, minor, patch) \ - (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) + (CUDNN_VERSION >= CUDNN_VERSION_COMPUTE(major, minor, patch)) enum class DataLayout { // Not use kNHWC, diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index 7905320728bda..f87e3b3d80539 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -919,17 +919,17 @@ ncclComm_t GPUContext::nccl_comm() const { return impl_->GetNcclComm(); } void GPUContext::set_nccl_comm(ncclComm_t comm) { impl_->SetNcclComm(comm); } void GPUContext::Init() { - impl_->allocator_ = const_cast(&this->GetAllocator()); + impl_->allocator_ = const_cast(&this->GetAllocator()); // NOLINT impl_->Init(); } void GPUContext::SetStream(gpuStream_t stream) { - impl_->allocator_ = const_cast(&this->GetAllocator()); + impl_->allocator_ = const_cast(&this->GetAllocator()); // NOLINT impl_->SetStream(stream); } void GPUContext::SetCUDAStream(CUDAStream* stream, bool clear) { - impl_->allocator_ = const_cast(&this->GetAllocator()); + impl_->allocator_ = const_cast(&this->GetAllocator()); // NOLINT impl_->SetCUDAStream(stream, clear); } @@ -1006,7 +1006,7 @@ void GPUContext::PartialInitWithoutAllocator(int stream_priority) { } void GPUContext::PartialInitWithAllocator() { - impl_->allocator_ = const_cast(&this->GetAllocator()); + impl_->allocator_ = const_cast(&this->GetAllocator()); // NOLINT impl_->PartialInitWithAllocator(); } diff --git a/paddle/phi/backends/gpu/gpu_primitives.h b/paddle/phi/backends/gpu/gpu_primitives.h index ca47ac53228f8..bcf5220f65545 100644 --- a/paddle/phi/backends/gpu/gpu_primitives.h +++ b/paddle/phi/backends/gpu/gpu_primitives.h @@ -45,6 +45,91 @@ constexpr int PADDLE_CUDA_NUM_THREADS = 512; USE_CUDA_ATOMIC(Add, float); USE_CUDA_ATOMIC(Add, int); USE_CUDA_ATOMIC(Add, unsigned int); + +CUDA_ATOMIC_WRAPPER(Add, bool) { + size_t offset = reinterpret_cast(address) & 3; + uint32_t *address_as_ui = + reinterpret_cast(reinterpret_cast(address) - offset); + uint32_t old = *address_as_ui; + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + + do { + assumed = old; + old_byte = (old >> shift) & 0xff; + newval = static_cast(val + static_cast(old_byte)); + newval = (old & ~(0x000000ff << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); + + return static_cast(old & 0xff); +} + +CUDA_ATOMIC_WRAPPER(Add, uint8_t) { + size_t offset = reinterpret_cast(address) & 3; + uint32_t *address_as_ui = + reinterpret_cast(reinterpret_cast(address) - offset); + uint32_t old = *address_as_ui; + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + + do { + assumed = old; + old_byte = (old >> shift) & 0xff; + newval = static_cast(val + static_cast(old_byte)); + newval = (old & ~(0x000000ff << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); + + return static_cast(old & 0xff); +} + +CUDA_ATOMIC_WRAPPER(Add, int8_t) { + size_t offset = reinterpret_cast(address) & 3; + uint32_t *address_as_ui = + reinterpret_cast(reinterpret_cast(address) - offset); + uint32_t old = *address_as_ui; + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + + do { + assumed = old; + old_byte = (old >> shift) & 0xff; + newval = static_cast(val + static_cast(old_byte)); + newval = (old & ~(0x000000ff << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); + + return static_cast(old & 0xff); +} + +CUDA_ATOMIC_WRAPPER(Add, int16_t) { + size_t offset = reinterpret_cast(address) & 2; + uint32_t *address_as_ui = + reinterpret_cast(reinterpret_cast(address) - offset); + bool is_32_align = offset; + uint32_t old = *address_as_ui; + uint32_t old_bytes; + uint32_t newval; + uint32_t assumed; + + do { + assumed = old; + old_bytes = is_32_align ? old >> 16 : old & 0xffff; + newval = static_cast(val + static_cast(old_bytes)); + newval = is_32_align ? (old & 0xffff) | (newval << 16) + : (old & 0xffff0000) | newval; + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); + + return static_cast(old & 0xffff); +} // CUDA API uses unsigned long long int, we cannot use uint64_t here. // It because unsigned long long int is not necessarily uint64_t USE_CUDA_ATOMIC(Add, unsigned long long int); // NOLINT diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index a447df94cb4dc..a29b5e110922a 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -146,19 +146,40 @@ void InitGpuProperties(Place place, } #else size_t cudnn_dso_ver = dynload::cudnnGetVersion(); + auto get_cudnn_major = [](auto version) { + if (version < 9000) { + return version / 1000; + } + // CUDNN changes the CUDNN_VERSION rules after 9.0 + return version / 10000; + }; + auto get_cudnn_minor = [](auto version) { + if (version < 9000) { + return (version % 1000) / 100; + } + // CUDNN changes the CUDNN_VERSION rules after 9.0 + return (version % 10000) / 100; + }; + LOG_FIRST_N(WARNING, 1) << "device: " << static_cast(place.device) - << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." - << (cudnn_dso_ver % 1000) / 100 << "."; + << ", cuDNN Version: " + << get_cudnn_major(cudnn_dso_ver) << "." + << get_cudnn_minor(cudnn_dso_ver) << "."; // Check CUDA/CUDNN version compatiblity auto local_cuda_version = (*driver_version / 1000) * 10 + (*driver_version % 100) / 10; auto compile_cuda_version = (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10; + + // Compute cuDNN major + auto local_cudnn_major = get_cudnn_major(cudnn_dso_ver); + size_t compile_cudnn_major = CUDNN_MAJOR; + #if defined(__linux__) PADDLE_ENFORCE_EQ( (local_cuda_version / 10 < compile_cuda_version / 10) && - (cudnn_dso_ver / 1000 < CUDNN_VERSION / 1000), + (local_cudnn_major < compile_cudnn_major), false, phi::errors::InvalidArgument( "The installed Paddle is compiled with CUDA%d/cuDNN%d," @@ -167,9 +188,9 @@ void InitGpuProperties(Place place, "Please recompile or reinstall Paddle with compatible CUDA/cuDNN " "version.", compile_cuda_version / 10, - CUDNN_VERSION / 1000, + compile_cudnn_major, local_cuda_version / 10, - cudnn_dso_ver / 1000)); + local_cudnn_major)); #endif if (local_cuda_version < compile_cuda_version) { LOG_FIRST_N(WARNING, 1) @@ -269,15 +290,17 @@ void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) { PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle)); PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream)); #else - auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100; - auto compile_cudnn_version = CUDNN_VERSION / 100; - if (local_cudnn_version < static_cast(compile_cudnn_version)) { + auto version = phi::dynload::cudnnGetVersion(); + auto local_cudnn_major = + (version < 9000) ? version / 1000 : version / 10000; + auto local_cudnn_minor = + (version < 9000) ? (version % 1000) / 100 : (version % 10000) / 100; + if (version < static_cast(CUDNN_VERSION)) { LOG_FIRST_N(WARNING, 1) << "WARNING: device: " << place.device - << ". The installed Paddle is compiled with CUDNN " - << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 - << ", but CUDNN version in your machine is " - << local_cudnn_version / 10 << "." << local_cudnn_version % 10 + << ". The installed Paddle is compiled with CUDNN " << CUDNN_MAJOR + << "." << CUDNN_MINOR << ", but CUDNN version in your machine is " + << local_cudnn_major << "." << local_cudnn_minor << ", which may cause serious incompatible bug. " << "Please recompile or reinstall Paddle with compatible CUDNN " "version."; diff --git a/paddle/phi/backends/gpu/rocm/miopen_helper.h b/paddle/phi/backends/gpu/rocm/miopen_helper.h index b8ce6e22e939b..f7815e2ed851e 100644 --- a/paddle/phi/backends/gpu/rocm/miopen_helper.h +++ b/paddle/phi/backends/gpu/rocm/miopen_helper.h @@ -61,8 +61,12 @@ inline const char* miopenGetErrorString(miopenStatus_t status) { } // no use, but will have compiling error if not defined +#define CUDNN_VERSION_COMPUTE(major, minor, patch) \ + ((major) <= 8 ? (major)*1000 + (minor)*100 + (patch) \ + : (major)*10000 + (minor)*100 + (patch)) + #define CUDNN_VERSION_MIN(major, minor, patch) \ - (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) + (CUDNN_VERSION >= CUDNN_VERSION_COMPUTE(major, minor, patch)) enum class DataLayout { // Not use kNHWC, diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 6252bbc54c933..0069cb453e236 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -1327,6 +1327,8 @@ class BatchNormOneDNNHandler Place cpu_place, const DenseTensor* x, const float epsilon, + const bool use_scale, + const bool use_bias, const bool fuse_with_relu, const bool global_stats, const bool test_mode) @@ -1335,8 +1337,9 @@ class BatchNormOneDNNHandler dnnl::batch_normalization_backward>(engine, cpu_place) { // Flags are added by bitwise OR operation - auto flags = dnnl::normalization_flags::use_scale | - dnnl::normalization_flags::use_shift; + auto flags = dnnl::normalization_flags::none; + if (use_scale) flags |= dnnl::normalization_flags::use_scale; + if (use_bias) flags |= dnnl::normalization_flags::use_shift; if (global_stats) flags |= dnnl::normalization_flags::use_global_stats; if (fuse_with_relu && test_mode) flags |= dnnl::normalization_flags::fuse_norm_relu; @@ -1354,39 +1357,31 @@ class BatchNormOneDNNHandler Place cpu_place, const float epsilon, const DenseTensor* in_x, - const DenseTensor* scale, + const bool use_scale, + const bool use_bias, const DenseTensor* out_grad) : OneDNNHandlerNoCachingT(engine, cpu_place) { - auto scale_tz = vectorize(scale->dims()); - PADDLE_ENFORCE_EQ( - scale_tz.size(), - 1, - errors::InvalidArgument( - "Dims of scale tensor must be 1, but received scale's size is %d", - scale_tz.size())); + auto flags = dnnl::normalization_flags::none; + if (use_scale) flags |= dnnl::normalization_flags::use_scale; + if (use_bias) flags |= dnnl::normalization_flags::use_shift; - this->AcquireForwardPrimitiveDescriptor( - dnnl::prop_kind::forward_training, - in_x->mem_desc(), - in_x->mem_desc(), - epsilon, - dnnl::normalization_flags::use_scale | - dnnl::normalization_flags::use_shift); - this->AcquireBackwardPrimitiveDescriptor( - dnnl::prop_kind::backward, - out_grad->mem_desc(), - out_grad->mem_desc(), - in_x->mem_desc(), - epsilon, - dnnl::normalization_flags::use_scale | - dnnl::normalization_flags::use_shift); + this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, + in_x->mem_desc(), + in_x->mem_desc(), + epsilon, + flags); + this->AcquireBackwardPrimitiveDescriptor(dnnl::prop_kind::backward, + out_grad->mem_desc(), + out_grad->mem_desc(), + in_x->mem_desc(), + epsilon, + flags); } - std::tuple, std::shared_ptr> - AcquireScaleShiftMemory(const DenseTensor* scale, const DenseTensor* shift) { + std::shared_ptr AcquireScaleMemory(const DenseTensor* scale) { auto scale_tz = vectorize(scale->dims()); PADDLE_ENFORCE_EQ( scale_tz.size(), @@ -1397,20 +1392,37 @@ class BatchNormOneDNNHandler auto scale_memory = this->AcquireMemoryFromPrimitive( this->fwd_pd_->weights_desc(), to_void_cast(scale->data())); + + return scale_memory; + } + + std::shared_ptr AcquireShiftMemory(const DenseTensor* shift) { + auto shift_tz = vectorize(shift->dims()); + PADDLE_ENFORCE_EQ( + shift_tz.size(), + 1, + errors::InvalidArgument( + "Dims of bias tensor must be 1, but received bias's size is %d", + shift_tz.size())); + auto shift_memory = this->AcquireMemoryFromPrimitive( this->fwd_pd_->weights_desc(), to_void_cast(shift->data())); - return std::make_tuple(scale_memory, shift_memory); + return shift_memory; } - std::tuple, std::shared_ptr> - AcquireDiffScaleShiftMemory(T* diff_scale_data, T* diff_shift_data) { + std::shared_ptr AcquireDiffScaleMemory(T* diff_scale_data) { auto diff_scale_memory = this->AcquireMemoryFromPrimitive( this->bwd_pd_->diff_weights_desc(), diff_scale_data); + + return diff_scale_memory; + } + + std::shared_ptr AcquireDiffShiftMemory(T* diff_shift_data) { auto diff_shift_memory = this->AcquireMemoryFromPrimitive( this->bwd_pd_->diff_weights_desc(), diff_shift_data); - return std::make_tuple(diff_scale_memory, diff_shift_memory); + return diff_shift_memory; } std::shared_ptr AcquireMeanMemory(const DenseTensor* mean) { @@ -1519,8 +1531,11 @@ class PoolingOneDNNHandler } if (adaptive) { - ComputeAdaptivePoolParameters( - src_tz, &copied_kernel_size, &copied_strides); + ComputeAdaptivePoolParameters(src_tz, + onednn_paddings[0], + onednn_paddings[1], + &copied_kernel_size, + &copied_strides); } bool is_test = dev_ctx.HasDnnAttr("is_test") @@ -1612,8 +1627,11 @@ class PoolingOneDNNHandler } if (adaptive) { - ComputeAdaptivePoolParameters( - diff_src_tz, &copied_kernel_size, &copied_strides); + ComputeAdaptivePoolParameters(src_tz, + onednn_paddings[0], + onednn_paddings[1], + &copied_kernel_size, + &copied_strides); } memory::dims dilation = {0, 0}; @@ -1672,23 +1690,45 @@ class PoolingOneDNNHandler return mem_p; } - static void ComputeAdaptivePoolParameters(const std::vector& src_tz, - std::vector* kernel_size, - std::vector* strides) { + static void ComputeAdaptivePoolParameters( + const std::vector& src_tz, + const std::vector& padding_l, + const std::vector& padding_r, + std::vector* kernel_size, + std::vector* strides) { // https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling auto IH = static_cast(src_tz[src_tz.size() - 2]); auto IW = static_cast(src_tz[src_tz.size() - 1]); auto OH = static_cast(kernel_size->at(0)); auto OW = static_cast(kernel_size->at(1)); - strides->at(0) = - static_cast(floor((IH * 2.0) / OH) - floor(IH / OH)); - strides->at(1) = - static_cast(floor((IW * 2.0) / OW) - floor(IW / OW)); - kernel_size->at(0) = - static_cast(ceil((IH * 2.0) / OH) - floor(IH / OH)); - kernel_size->at(1) = - static_cast(ceil((IW * 2.0) / OW) - floor(IW / OW)); + /* + The previous calculation formula is given by OneDNN rfc, but in some odd + cases(mod(I/O)>=O/2) there will be problems with the calculation results. + Now change the formula to the general calculation formula of + AdaptivePool when in mod(I/O)>=O/2 case: + stride=floor(input_size/output_size) + kernel_size=input_size-(output_size-1)*stride + */ + int mod_H = IH - floor(IH / OH) * OH; + int mod_W = IW - floor(IW / OW) * OW; + if (2 * mod_H < OH && 2 * mod_W < OW) { + strides->at(0) = + static_cast(floor((IH * 2.0) / OH) - floor(IH / OH)); + strides->at(1) = + static_cast(floor((IW * 2.0) / OW) - floor(IW / OW)); + kernel_size->at(0) = + static_cast(ceil((IH * 2.0) / OH) - floor(IH / OH)); + kernel_size->at(1) = + static_cast(ceil((IW * 2.0) / OW) - floor(IW / OW)); + } else { + strides->at(0) = static_cast(floor(IH / OH)); + strides->at(1) = static_cast(floor(IW / OW)); + kernel_size->at(0) = static_cast( + IH + padding_l[0] + padding_r[0] - floor((OH - 1) * strides->at(0))); + kernel_size->at(1) = static_cast( + IW + padding_l[1] + padding_r[1] - floor((OW - 1) * strides->at(1))); + } } private: diff --git a/paddle/phi/backends/xpu/xpu1_op_list.cc b/paddle/phi/backends/xpu/xpu1_op_list.cc index f99805c095992..52c9661e9f55a 100644 --- a/paddle/phi/backends/xpu/xpu1_op_list.cc +++ b/paddle/phi/backends/xpu/xpu1_op_list.cc @@ -21,7 +21,10 @@ namespace xpu { XPUOpMap& get_kl1_ops() { // KL1支持的op,通过op_name, data_type static XPUOpMap s_xpu1_kernels{ - {"abs", XPUKernelSet({phi::DataType::FLOAT32})}, + {"abs", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})}, {"adam", XPUKernelSet({phi::DataType::FLOAT32})}, {"adamw", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -34,6 +37,11 @@ XPUOpMap& get_kl1_ops() { phi::DataType::INT32, phi::DataType::INT64, phi::DataType::BOOL})}, + {"assign_value", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL})}, {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"bilinear_interp", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -48,13 +56,20 @@ XPUOpMap& get_kl1_ops() { {"cast", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64, + phi::DataType::BOOL, phi::DataType::INT32})}, {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"coalesce_tensor", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64, phi::DataType::INT32})}, - {"concat", XPUKernelSet({phi::DataType::FLOAT32})}, + {"concat", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::INT64, + phi::DataType::INT32})}, {"concat_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"conv2d", XPUKernelSet({phi::DataType::FLOAT32})}, {"conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -67,20 +82,39 @@ XPUOpMap& get_kl1_ops() { {"dropout_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_add", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32})}, {"elementwise_add_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_div_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_floordiv", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_floordiv", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_max", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_min", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_sub", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"embedding_with_eltwise_add_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"equal", XPUKernelSet({phi::DataType::INT64})}, @@ -115,14 +149,26 @@ XPUOpMap& get_kl1_ops() { phi::DataType::INT32, phi::DataType::INT64, })}, + {"greater_than", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32})}, {"hard_switch_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_switch", XPUKernelSet({phi::DataType::FLOAT32})}, + {"index_select", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})}, {"lamb", XPUKernelSet({phi::DataType::FLOAT32})}, {"layer_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"layer_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"less_than", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32})}, {"load", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT8, @@ -206,7 +252,10 @@ XPUOpMap& get_kl1_ops() { {"rnn", XPUKernelSet({phi::DataType::FLOAT32})}, {"roi_align_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"roi_align", XPUKernelSet({phi::DataType::FLOAT32})}, - {"scale", XPUKernelSet({phi::DataType::FLOAT32})}, + {"scale", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32})}, {"sgd", XPUKernelSet({phi::DataType::FLOAT32})}, {"shape", XPUKernelSet({phi::DataType::FLOAT64, @@ -218,7 +267,11 @@ XPUOpMap& get_kl1_ops() { {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, {"sign", XPUKernelSet({phi::DataType::FLOAT32})}, {"slice_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"slice", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16})}, {"softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})}, {"softmax_with_cross_entropy_grad", @@ -306,6 +359,10 @@ XPUOpMap& get_kl1_ops() { phi::DataType::UINT8, phi::DataType::FLOAT32})}, {"where_index", XPUKernelSet({phi::DataType::BOOL})}, + {"where", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT32})}, // AddMore }; diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 74a8cf0bc1150..356ed02444d87 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -26,7 +26,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"add_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"abs", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::INT64})}, {"abs_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -139,6 +144,7 @@ XPUOpMap& get_kl2_ops() { {"cast", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT64, phi::DataType::BOOL, phi::DataType::INT8, @@ -176,7 +182,9 @@ XPUOpMap& get_kl2_ops() { {"conv1d_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"conv2d_xpu", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT8})}, {"conv3d_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"conv3d", @@ -210,6 +218,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"dequantize_xpu", + XPUKernelSet({phi::DataType::INT16, phi::DataType::INT8})}, {"diag_v2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -243,15 +253,24 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64, phi::DataType::INT32})}, {"elementwise_floordiv", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_max", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_min_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_min", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_mul", @@ -317,7 +336,9 @@ XPUOpMap& get_kl2_ops() { {"fast_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fc_xpu", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT8})}, {"fill", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, @@ -425,7 +446,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64, phi::DataType::BOOL})}, {"gaussian_random", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"gelu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -617,6 +640,8 @@ XPUOpMap& get_kl2_ops() { {"prelu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})}, + {"quantize_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64, @@ -640,7 +665,10 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})}, - {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_prod", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_sum", XPUKernelSet({phi::DataType::FLOAT16, @@ -653,6 +681,8 @@ XPUOpMap& get_kl2_ops() { {"relu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"relu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"repeat_interleave", + XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, {"reshape2_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::FLOAT16, @@ -700,6 +730,11 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, phi::DataType::INT64})}, + {"scatter_nd_add_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"sampling_id", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, {"set_value", @@ -748,10 +783,12 @@ XPUOpMap& get_kl2_ops() { {"slice_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32})}, {"slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"softmax", @@ -777,14 +814,19 @@ XPUOpMap& get_kl2_ops() { {"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"split_with_num", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -797,6 +839,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze2", XPUKernelSet({phi::DataType::FLOAT64, @@ -806,6 +850,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, @@ -814,6 +859,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, @@ -822,6 +869,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"stack", XPUKernelSet({phi::DataType::FLOAT32, @@ -833,17 +882,20 @@ XPUOpMap& get_kl2_ops() { {"strided_slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT16, phi::DataType::INT32, phi::DataType::INT64})}, {"strided_slice_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT16, phi::DataType::INT32})}, {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"swish_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"take_along_axis", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"tanh_grad", @@ -854,6 +906,7 @@ XPUOpMap& get_kl2_ops() { {"transfer_dtype", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT64, phi::DataType::BOOL, phi::DataType::UINT8, @@ -894,24 +947,28 @@ XPUOpMap& get_kl2_ops() { {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, @@ -922,7 +979,10 @@ XPUOpMap& get_kl2_ops() { {"update_loss_scaling", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"unbind", XPUKernelSet({phi::DataType::FLOAT32})}, - {"uniform_random", XPUKernelSet({phi::DataType::FLOAT32})}, + {"uniform_random", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unique", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, @@ -935,7 +995,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze2", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -944,7 +1005,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -952,7 +1014,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -960,8 +1024,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, + phi::DataType::BFLOAT16})}, {"unstack", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, @@ -977,7 +1042,8 @@ XPUOpMap& get_kl2_ops() { {"where_index", XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::INT64})}, {"where_grad", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64, @@ -987,7 +1053,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"sin", XPUKernelSet({phi::DataType::FLOAT32})}, {"sin_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"cos", XPUKernelSet({phi::DataType::FLOAT32})}, diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 29a8549395894..0ba008a680d7b 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -138,6 +138,7 @@ XPUOpMap& get_kl3_ops() { {"cast", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT64, phi::DataType::BOOL, phi::DataType::INT8, @@ -415,7 +416,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT64, phi::DataType::BOOL})}, {"gaussian_random", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"gelu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -626,7 +629,10 @@ XPUOpMap& get_kl3_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})}, - {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_prod", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_sum", XPUKernelSet({phi::DataType::FLOAT16, @@ -684,6 +690,11 @@ XPUOpMap& get_kl3_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, phi::DataType::INT64})}, + {"scatter_nd_add_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"sampling_id", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, {"set_value", @@ -731,10 +742,12 @@ XPUOpMap& get_kl3_ops() { {"slice_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32})}, {"slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"softmax", @@ -759,14 +772,21 @@ XPUOpMap& get_kl3_ops() { {"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"split_with_num", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt", + XPUKernelSet({ + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, + })}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -779,6 +799,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze2", XPUKernelSet({phi::DataType::FLOAT64, @@ -788,6 +810,7 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, @@ -796,6 +819,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, @@ -804,6 +829,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"stack", XPUKernelSet({phi::DataType::FLOAT32, @@ -815,17 +842,25 @@ XPUOpMap& get_kl3_ops() { {"strided_slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT16, phi::DataType::INT32, phi::DataType::INT64})}, {"strided_slice_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT16, phi::DataType::INT32})}, {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"swish", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, + {"swish_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"take_along_axis", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"tanh_grad", @@ -836,6 +871,7 @@ XPUOpMap& get_kl3_ops() { {"transfer_dtype", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT64, phi::DataType::BOOL, phi::DataType::UINT8, @@ -876,24 +912,28 @@ XPUOpMap& get_kl3_ops() { {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, @@ -904,7 +944,10 @@ XPUOpMap& get_kl3_ops() { {"update_loss_scaling", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"unbind", XPUKernelSet({phi::DataType::FLOAT32})}, - {"uniform_random", XPUKernelSet({phi::DataType::FLOAT32})}, + {"uniform_random", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unique", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, @@ -917,7 +960,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze2", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -926,7 +970,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -934,7 +979,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -942,8 +989,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, + phi::DataType::BFLOAT16})}, {"unstack", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, @@ -959,7 +1007,8 @@ XPUOpMap& get_kl3_ops() { {"where_index", XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::INT64})}, {"where_grad", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64, @@ -969,7 +1018,8 @@ XPUOpMap& get_kl3_ops() { XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"sin", XPUKernelSet({phi::DataType::FLOAT32})}, {"sin_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"cos", XPUKernelSet({phi::DataType::FLOAT32})}, diff --git a/paddle/phi/core/ddim.h b/paddle/phi/core/ddim.h index 57ad4d09ef463..be11b4c9596cd 100644 --- a/paddle/phi/core/ddim.h +++ b/paddle/phi/core/ddim.h @@ -227,7 +227,7 @@ std::vector vectorize(const DDim& ddim) { return result; } -int64_t product(const DDim& ddim); +TEST_API int64_t product(const DDim& ddim); bool contain_unknown_dim(const DDim& ddim); diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc index 7f636cd2a4831..3804802e84260 100644 --- a/paddle/phi/core/device_context.cc +++ b/paddle/phi/core/device_context.cc @@ -176,8 +176,10 @@ struct DeviceContext::Impl { allocator = cuda_graph_allocator_; } #endif - return tensor->AllocateFrom( - const_cast(allocator), dtype, requested_size, fake_alloc); + return tensor->AllocateFrom(const_cast(allocator), + dtype, + requested_size, + fake_alloc); // NOLINT } template @@ -218,8 +220,10 @@ struct DeviceContext::Impl { (fake_alloc || tensor->numel() == 0) && requested_size == 0 ? host_zero_allocator_ : host_allocator_; - return tensor->AllocateFrom( - const_cast(allocator), dtype, requested_size, fake_alloc); + return tensor->AllocateFrom(const_cast(allocator), + dtype, + requested_size, + fake_alloc); // NOLINT } template diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index 12c59059c7c32..8e58ab4bf840e 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -5,7 +5,9 @@ add_subdirectory(auto_parallel) set(DISTRIBUTED_COMMON_SRCS comm_context_manager.cc) if(WITH_NCCL OR WITH_RCCL) - list(APPEND DISTRIBUTED_COMMON_SRCS nccl_comm_context.cc) + list(APPEND DISTRIBUTED_COMMON_SRCS comm_task_manager.cc) + list(APPEND DISTRIBUTED_COMMON_SRCS nccl_comm_context.cc nccl_comm_task.cc + nccl_tools.cc) endif() if(WITH_GLOO) diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index 92e69e0dc7657..15e73c3d50215 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -7,15 +7,8 @@ collect_srcs( process_mesh.cc dist_attr.cc dist_mapper.cc - reshard_utils.cc dist_tensor.cc dist_meta_tensor.cc - inferspmd_utils.cc - reshard_function.cc - r_to_s_reshard_function.cc - s_to_r_reshard_function.cc - r_to_p_reshard_function.cc - p_to_r_reshard_function.cc - s_to_s_reshard_function.cc - nd_mesh_reshard_function.cc - same_status_reshard_function.cc) + inferspmd_utils.cc) + +add_subdirectory(reshard) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc index 46e58cc9b373e..052a6d457ca8b 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc @@ -349,7 +349,8 @@ std::string TensorDistAttr::partial_status_string() const { } bool TensorDistAttr::empty() const { - return process_mesh_.empty() || dims_mapping_.empty(); + // dims_mapping is empty when the tensor is 0-dim, but it is also be valid. + return process_mesh_.empty(); } std::vector> TensorDistAttr::to_placement() @@ -398,7 +399,7 @@ bool TensorDistAttr::is_replicated(int64_t mesh_axis) const { bool TensorDistAttr::is_shard(int64_t mesh_axis, int64_t tensor_axis) const { auto placement = to_placement(); if (mesh_axis == -1) { - return std::all_of(placement.begin(), + return std::any_of(placement.begin(), placement.end(), [tensor_axis](std::shared_ptr status) { return status->is_shard(tensor_axis); diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index f051592b7bf7e..6689750d24ad9 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -32,6 +32,8 @@ limitations under the License. */ namespace phi { namespace distributed { +constexpr int kReplicateDim = -1; + class PlacementStatus { public: virtual ~PlacementStatus() = default; diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc index dc5d6c20e62b3..1e3164de81865 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc @@ -22,11 +22,11 @@ namespace distributed { phi::DDim DistMetaTensor::dims() const { // member values in tensor_ have higher priority than those in DistMetaTensor if (tensor_ != nullptr) { - PADDLE_ENFORCE_EQ(this->is_dist(), - true, - phi::errors::InvalidArgument( - "The current MetaTensor doesn't contains " - "DistTensor when call `dist_attr` method.")); + PADDLE_ENFORCE_EQ( + this->is_dist(), + true, + phi::errors::InvalidArgument("The current MetaTensor doesn't contains " + "DistTensor when call `dims` method.")); return MetaTensor::dims(); } else { return dims_; diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h index efbf38d28f9f0..30757c5a1cdaa 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h @@ -22,6 +22,8 @@ namespace distributed { class DistMetaTensor : public MetaTensor { public: + DistMetaTensor() : MetaTensor() {} + // supporting implicit construction is easier to use DistMetaTensor(TensorBase* tensor) // NOLINT : MetaTensor(tensor) {} diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index 8e3e6405f4d29..a2e3ac5123a44 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -16,8 +16,8 @@ #include "glog/logging.h" #include "paddle/phi/backends/context_pool.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/core/distributed/store/store_utils.h" namespace phi { diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h index 4781b5d872001..2d444decf640a 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -125,6 +125,28 @@ struct InferSpmdFnImpl { } }; + // direct vector + template + struct InferSpmdFnCallHelper&, Tail...> { + template + static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferSpmd's Input should appear before Attributes."); + // TODO(liuzhenhai): parse input list as vector directly + const std::pair range = ctx.InputRangeAt(in_idx); + std::vector tmp_arg = + ctx.InputsBetween(range.first, range.second); + std::vector arg; + std::transform(tmp_arg.begin(), + tmp_arg.end(), + std::back_inserter(arg), + [](const DistMetaTensor* arg_ptr) { return *arg_ptr; }); + return InferSpmdFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + #define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(attr_type) \ template \ struct InferSpmdFnCallHelper { \ diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt new file mode 100644 index 0000000000000..e5902375cfec6 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt @@ -0,0 +1,12 @@ +collect_srcs( + core_srcs + SRCS + reshard_utils.cc + reshard_function.cc + r_to_s_reshard_function.cc + s_to_r_reshard_function.cc + r_to_p_reshard_function.cc + p_to_r_reshard_function.cc + s_to_s_reshard_function.cc + nd_mesh_reshard_function.cc + same_status_reshard_function.cc) diff --git a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc similarity index 91% rename from paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc index 9d5d8f43f7670..28d71ff93b49c 100644 --- a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.h" #include "glog/logging.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" -#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h" namespace phi { namespace distributed { @@ -73,15 +73,14 @@ int64_t FindFirstDiffShardAxis(const TensorDistAttr& in_dist_attr, bool SameNdMeshReshardFunction::IsSuitable( const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; - - flag &= (in.dist_attr().process_mesh() == out_dist_attr.process_mesh()); - flag &= (out_dist_attr.process_mesh().ndim() > 1); + RESHARD_SHORTCUT_IF_FALSE(in.dist_attr().process_mesh() == + out_dist_attr.process_mesh()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.process_mesh().ndim() > 1); // check the input and output dims_mapping is not equal - flag &= in.dist_attr() != out_dist_attr; + RESHARD_SHORTCUT_IF_FALSE(in.dist_attr() != out_dist_attr); - return flag; + return true; } void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, @@ -121,7 +120,8 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, // 1.3 Calculate the input one dim dist attr TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims())); in_one_dim_dist_attr.set_process_mesh(sub_mesh); - in_one_dim_dist_attr.set_partial_status(std::vector{0}); + in_one_dim_dist_attr.set_partial_status(std::vector{0}, + kv.second); // 1.4 Calculate the output one dim dist attr TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims())); diff --git a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.h similarity index 93% rename from paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.h index e47cb46138f7b..169c51899717e 100644 --- a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" namespace phi { namespace distributed { diff --git a/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc similarity index 58% rename from paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc index f9aaa6f8adf7f..ce4d571306cba 100644 --- a/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc @@ -12,33 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h" #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/kernels/all_reduce_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { namespace distributed { bool PToRReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; - - flag &= in.dist_attr().is_partial(); - flag &= out_dist_attr.is_replicated(); + RESHARD_SHORTCUT_IF_FALSE(in.dist_attr().is_partial()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated()); const auto& in_process_mesh = in.dist_attr().process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void PToRReshardFunction::Eval(DeviceContext* dev_ctx, @@ -50,9 +50,18 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); const auto& in_partial_status = in_dist_attr.partial_status(); + auto in_reduce_type = in_partial_status.at(0); + bool reduce_mean = false; auto dtype = in.dtype(); - int64_t reduce_type = static_cast(in_partial_status.at(0)); + if (in_reduce_type == ReduceType::kRedAvg) { + in_reduce_type = ReduceType::kRedSum; + reduce_mean = true; + } + int64_t reduce_type = static_cast(in_reduce_type); + VLOG(3) << "Transfer from partial to replicated status with reduce type " + << reduce_type; + RESHARD_FUNCTOR_WITH_COMM(dev_ctx, AllReduce, dtype, @@ -61,6 +70,24 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, reduce_type, GetMutableTensor(out)); + if (reduce_mean) { + VLOG(3) << "Do reduce mean after all reduce sum"; + DenseTensor tensor_of_num_process; + IntArray shape({1}); + RESHARD_FUNCTOR(dev_ctx, + Full, + in.dtype(), + shape, + static_cast(in_process_ids.size()), + &tensor_of_num_process); + RESHARD_FUNCTOR(dev_ctx, + Divide, + dtype, + out->value(), + tensor_of_num_process, + GetMutableTensor(out)); + } + SetDistProps(out, in.dims(), out_dist_attr); } diff --git a/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h similarity index 93% rename from paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h index 3895b9246f3c6..c1b0c3cd01060 100644 --- a/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" namespace phi { namespace distributed { diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc similarity index 81% rename from paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc index 77569c1ecfbac..93bc88f7888a0 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h" #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/kernels/assign_kernel.h" #include "paddle/phi/kernels/full_kernel.h" @@ -27,20 +27,19 @@ namespace distributed { bool RToPReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= in_dist_attr.is_replicated(); - flag &= out_dist_attr.is_partial(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_partial()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h similarity index 93% rename from paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h index af3bdb41d78c1..3014cdc550e6c 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" namespace phi { namespace distributed { diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc similarity index 53% rename from paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc index bc6cb393a15b8..f4651f0619999 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h" #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h" #include "paddle/phi/kernels/split_kernel.h" namespace phi { @@ -25,20 +26,19 @@ namespace distributed { bool RToSReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= in_dist_attr.is_replicated(); - flag &= out_dist_attr.is_shard(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, @@ -86,7 +86,57 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, SetDistProps(out, in.dims(), out_dist_attr); } +bool RToSReshardFunctionCrossMesh::IsSuitable( + const DistTensor& in, const TensorDistAttr& out_dist_attr) { + const auto& in_dist_attr = in.dist_attr(); + + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard()); + + const auto& in_process_mesh = in_dist_attr.process_mesh(); + const auto& out_process_mesh = out_dist_attr.process_mesh(); + + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() == + out_process_mesh.shape()); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh); + + return true; +} + +void RToSReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) { + VLOG(3) << "Call RToSReshardFunctionCrossMesh Eval"; + const auto& in_dist_attr = in.dist_attr(); + + DistTensor tmp_result; + TensorDistAttr in_dist_attr_shard = in_dist_attr; + in_dist_attr_shard.set_dims_mapping(out_dist_attr.dims_mapping()); + RToSReshardFunction r_to_s_func; + PADDLE_ENFORCE( + r_to_s_func.IsSuitable(in, in_dist_attr_shard), + phi::errors::InvalidArgument( + "Invoke the r to s reshard function is not valid from %s to %s.", + tmp_result.dist_attr(), + out_dist_attr)); + r_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result); + + // Step 2: Same status from the input mesh to output mesh + SameStatusReshardFunction same_status_func; + PADDLE_ENFORCE( + same_status_func.IsSuitable(tmp_result, out_dist_attr), + phi::errors::InvalidArgument("Invoke the same status reshard function " + "is not valid from %s to %s.", + tmp_result.dist_attr(), + out_dist_attr)); + same_status_func.Eval(dev_ctx, tmp_result, out_dist_attr, out); +} + REGISTER_RESHARD_FUNC(RToSReshardFunction); +REGISTER_RESHARD_FUNC(RToSReshardFunctionCrossMesh); } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h similarity index 71% rename from paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h index 3a86ff0cfa074..4ca086525b0d2 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h @@ -14,16 +14,24 @@ #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" namespace phi { namespace distributed { class RToSReshardFunction final : public ReshardFunction { public: - RToSReshardFunction() = default; - ~RToSReshardFunction() = default; + bool IsSuitable(const DistTensor& in, + const TensorDistAttr& out_dist_attr) override; + void Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) override; +}; + +class RToSReshardFunctionCrossMesh final : public ReshardFunction { + public: bool IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) override; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc similarity index 97% rename from paddle/phi/core/distributed/auto_parallel/reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc index 01824dd93bca1..04d47e4151d8a 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h similarity index 100% rename from paddle/phi/core/distributed/auto_parallel/reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc similarity index 77% rename from paddle/phi/core/distributed/auto_parallel/reshard_utils.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc index 57487ea7195bf..e7a1ec15da307 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "glog/logging.h" +#include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/store/store_utils.h" +#include "paddle/phi/core/enforce.h" namespace phi { namespace distributed { @@ -142,5 +145,40 @@ bool IsCurRankInMesh(const ProcessMesh& process_mesh) { process_ids.end()); } +// Only Input is DistTensor and current device id isn't in DistTensor's mesh +// will return true. +bool NeedComputationClipForPP( + const std::shared_ptr& tensor_impl) { + PADDLE_ENFORCE_EQ( + phi::distributed::DistTensor::classof(tensor_impl.get()), + true, + phi::errors::InvalidArgument( + "The input tensor of NeedComputationClipForPP should be " + "``phi::distributed::DistTensor``. " + "However it's %s", + typeid(tensor_impl.get()).name())); + return !IsCurRankInMesh( + std::static_pointer_cast(tensor_impl) + ->dist_attr() + .process_mesh()); +} + +Place GetDefaultPlace() { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (phi::backends::gpu::GetGPUDeviceCount() >= 0) { + return paddle::DefaultGPUPlace(); + } +#endif + return paddle::CPUPlace(); +} + +phi::DeviceContext* GetDistTensorDeviceContext( + phi::distributed::DistTensor* input) { + // TODO(GhostScreaming): pipeline parallel may create an undefined middle grad + // tensor. In such case, we need to get default place. + auto place = input && input->defined() ? input->place() : GetDefaultPlace(); + return phi::DeviceContextPool::Instance().Get(place); +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h similarity index 91% rename from paddle/phi/core/distributed/auto_parallel/reshard_utils.h rename to paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h index 652840976194f..15ecef53d0343 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h @@ -22,6 +22,9 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/visit_type.h" namespace phi { @@ -32,6 +35,14 @@ class ProcessMesh; bool IsCurRankInMesh(const ProcessMesh& process_mesh); +bool NeedComputationClipForPP( + const std::shared_ptr& tensor_impl); + +Place GetDefaultPlace(); + +phi::DeviceContext* GetDistTensorDeviceContext( + phi::distributed::DistTensor* input); + int64_t GetLocalRankInParticipate(const std::vector& process_ids, int64_t global_rank = -1); @@ -65,14 +76,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx, do { \ if (phi::CPUContext::classof(dev_ctx)) { \ VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU."; \ - PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \ + PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES( \ dtype, #fn_name, ([&] { \ fn_name(static_cast(*dev_ctx), \ __VA_ARGS__); \ })); \ } else if (phi::GPUContext::classof(dev_ctx)) { \ VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \ - PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \ + PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES( \ dtype, #fn_name, ([&] { \ fn_name(static_cast(*dev_ctx), \ __VA_ARGS__); \ @@ -143,5 +154,12 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx, } while (0) #endif +#define RESHARD_SHORTCUT_IF_FALSE(expr) \ + do { \ + if (!(expr)) { \ + return false; \ + } \ + } while (0) + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc similarity index 61% rename from paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc index db8a26088ae45..55c22fb034555 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h" #include "glog/logging.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/kernels/all_gather_kernel.h" #include "paddle/phi/kernels/concat_kernel.h" #include "paddle/phi/kernels/split_kernel.h" @@ -28,28 +28,28 @@ namespace distributed { bool SToRReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); const auto& in_dims_mapping = in_dist_attr.dims_mapping(); - flag &= in_dist_attr.is_shard(); - flag &= out_dist_attr.is_replicated(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); // Ensure the tensor is balanced split, or we need send/recv rather than // all_gather int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; int64_t num_of_process = in_process_mesh.size(); - flag &= (in.local_dims()[static_cast(split_axis)] * num_of_process == - in.dims()[static_cast(split_axis)]); + RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast(split_axis)] * + num_of_process == + in.dims()[static_cast(split_axis)]); - return flag; + return true; } void SToRReshardFunction::Eval(DeviceContext* dev_ctx, @@ -115,7 +115,52 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, } } +bool SToRReshardFunctionCrossMesh::IsSuitable( + const DistTensor& in, const TensorDistAttr& out_dist_attr) { + const auto& in_dist_attr = in.dist_attr(); + const auto& in_dims_mapping = in_dist_attr.dims_mapping(); + + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated()); + + const auto& in_process_mesh = in_dist_attr.process_mesh(); + const auto& out_process_mesh = out_dist_attr.process_mesh(); + + int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; + int64_t num_of_process = in_process_mesh.size(); + RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast(split_axis)] * + num_of_process == + in.dims()[static_cast(split_axis)]); + + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() == + out_process_mesh.shape()); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh); + + return true; +} + +void SToRReshardFunctionCrossMesh::Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) { + VLOG(3) << "Call SToRReshardFunctionCrossMesh Eval"; + const auto& out_process_mesh = out_dist_attr.process_mesh(); + + SameStatusReshardFunction same_status_func; + DistTensor tmp_result; + + TensorDistAttr tmp_dist_attr = in.dist_attr(); + tmp_dist_attr.set_process_mesh(out_process_mesh); + same_status_func.Eval(dev_ctx, in, tmp_dist_attr, &tmp_result); + + SToRReshardFunction s_to_r_func; + s_to_r_func.Eval(dev_ctx, tmp_result, out_dist_attr, out); +} + REGISTER_RESHARD_FUNC(SToRReshardFunction); +REGISTER_RESHARD_FUNC(SToRReshardFunctionCrossMesh); } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h similarity index 68% rename from paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h index 869b4ed9178de..ee4b65fade96e 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h @@ -13,7 +13,8 @@ // limitations under the License. #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h" namespace phi { namespace distributed { @@ -32,5 +33,16 @@ class SToRReshardFunction final : public ReshardFunction { DistTensor* out) override; }; +class SToRReshardFunctionCrossMesh final : public ReshardFunction { + public: + bool IsSuitable(const DistTensor& in, + const TensorDistAttr& out_dist_attr) override; + + void Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) override; +}; + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc similarity index 91% rename from paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc index 3aafe1dc7fbee..cf454926093dd 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h" #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/kernels/all_to_all_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" @@ -28,20 +28,19 @@ namespace distributed { bool SToSReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= in_dist_attr.is_shard(); - flag &= out_dist_attr.is_shard(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h similarity index 93% rename from paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h index b004ed919c192..383c7b522ad62 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" namespace phi { namespace distributed { diff --git a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc similarity index 77% rename from paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc rename to paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc index ea32163d67f62..5740e14ae833a 100644 --- a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h" #include #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/core/distributed/store/store_utils.h" #include "paddle/phi/kernels/p_recv_kernel.h" #include "paddle/phi/kernels/p_send_kernel.h" @@ -46,18 +46,20 @@ std::vector GetUnionProcessIds(std::vector in_process_ids, bool SameStatusReshardFunction::IsSuitable( const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= (in_dist_attr.dims_mapping() == out_dist_attr.dims_mapping()); - flag &= (in_dist_attr.partial_dims() == out_dist_attr.partial_dims()); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.dims_mapping() == + out_dist_attr.dims_mapping()); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.partial_dims() == + out_dist_attr.partial_dims()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh != out_process_mesh); - flag &= (in_process_mesh.shape() == out_process_mesh.shape()); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() == + out_process_mesh.shape()); - return flag; + return true; } void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, @@ -80,6 +82,19 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, // kernel execution. bool dynamic_shape = true; + // TODO(GhostScreaming): After cross-mesh reshard, current device may + // needs to execute next layer. When it construct next layer's backward + // graph, out->place() will be called such as in SetGradOutMeta method. As + // a result, out can't be undefined. Try to allocate a zero-memory value + // for out. Following send/recv will cover this empty DenseTensor + // construction. + VLOG(3) << "Same_status_reshard_function create an empty DenseTensor for " + "cross-mesh DistTensor."; + *(out->unsafe_mutable_value()) = + phi::DenseTensor(std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + in.value().meta()); + std::vector> p2p_pair; for (size_t i = 0; i < out_process_ids.size(); ++i) { p2p_pair.emplace_back( diff --git a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h similarity index 93% rename from paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h rename to paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h index 38c044e083a09..7abaec5e8f6c3 100644 --- a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" namespace phi { namespace distributed { diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 338ee4b4bad17..2a5b336f34e25 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -33,6 +33,7 @@ #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/distributed/nccl_tools.h" #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/core/distributed/xccl_comm_context.h" @@ -56,18 +57,19 @@ void CommContextManager::CreateNCCLCommContext( const std::string& unique_comm_key, int rank, int size, - const std::string& hash_key) { + const std::string& hash_key, + const P2POption* p2p_opt) { auto& comm_context_manager = CommContextManager::GetInstance(); if (comm_context_manager.Has(unique_comm_key)) { return; } ncclUniqueId nccl_id; - if (rank == 0) { + if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) { PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id)); } std::string unique_key = "NCCLCommContext/" + unique_comm_key + hash_key; - if (rank == 0) { + if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) { std::vector nccl_id_wrapper( reinterpret_cast(&nccl_id), reinterpret_cast(&nccl_id) + NCCL_UNIQUE_ID_BYTES); @@ -77,6 +79,14 @@ void CommContextManager::CreateNCCLCommContext( std::memcpy(&nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size()); } + if (p2p_opt) { + rank = p2p_opt->rank; + size = p2p_opt->num_ranks; + } + VLOG(3) << "init NCCLCommContext rank: " << rank << ", size: " << size + << ", unique_comm_key: " << unique_comm_key + << ", unique_key: " << unique_key + << ", nccl_id: " << SerializeNCCLUniqueId(nccl_id); auto nccl_comm_context = std::make_unique(rank, size, nccl_id); if (CommContextManager::device_id != -1) { diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 69e58a96e18e1..2229786db3855 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -30,6 +30,13 @@ namespace phi { namespace distributed { +struct P2POption { + bool is_p2p_op; + int p2p_rank; + int num_ranks; + int rank; +}; + class Store; class CommContextManager { @@ -62,7 +69,8 @@ class CommContextManager { const std::string& unique_comm_key, int rank, int size, - const std::string& hash_key = ""); + const std::string& hash_key = "", + const P2POption* opt = nullptr); #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/phi/core/distributed/comm_task.h b/paddle/phi/core/distributed/comm_task.h new file mode 100644 index 0000000000000..3673c7a9e21aa --- /dev/null +++ b/paddle/phi/core/distributed/comm_task.h @@ -0,0 +1,158 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/macros.h" + +#if defined(PADDLE_WITH_RCCL) +#include "paddle/phi/backends/dynload/rccl.h" +#endif +#if defined(PADDLE_WITH_NCCL) +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +namespace distributed { + +class Store; +class CommTask { + public: + CommTask(const std::string& backend = "", + const phi::Place& place = phi::Place(), + int rank = -1, + int size = 0, + int gid = 0, + uint64_t seq = 0, + int64_t numel = 0, + ncclComm_t nccl_comm = nullptr, + gpuStream_t nccl_stream = nullptr, + CommType comm_type = CommType::UNKNOWN) + : backend_(backend), + place_(place), + rank_(rank), + size_(size), + gid_(gid), + seq_(seq), + numel_(numel), + nccl_comm_(nccl_comm), + nccl_stream_(nccl_stream), + comm_type_(comm_type) { + const char* global_rank = std::getenv("PADDLE_TRAINER_ID"); + PADDLE_ENFORCE_NOT_NULL( + global_rank, + phi::errors::NotFound( + "The environment variable 'PADDLE_TRAINER_ID' cannot be found.")); + global_rank_ = std::atoi(global_rank); + } + virtual ~CommTask() = default; + + std::string UniqueKey() { + return "op:" + CommTypeToString(comm_type_) + + ",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_); + } + + std::string GetBackend() { return backend_; } + phi::Place GetPlace() { return place_; } + int GetGlobalRank() { return global_rank_; } + int GetRank() { return rank_; } + int GetSize() { return size_; } + int GetGid() { return gid_; } + int64_t GetNumel() { return numel_; } + uint64_t GetSeq() { return seq_; } + CommType GetCommType() { return comm_type_; } + bool GetTraceUpdated() { return start_trace_updated_; } + void SetTraceUpdated() { start_trace_updated_ = true; } + std::chrono::time_point GetStartTime() { + return start_time_; + } + std::shared_ptr GetStore() { return store_; } + void SetStore(std::shared_ptr store) { store_ = store; } + + ncclComm_t nccl_comm() { return nccl_comm_; } + gpuStream_t nccl_stream() { return nccl_stream_; } + + virtual std::string GetTraceMsg() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return ""; + } + virtual void StartRecord() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + virtual void EndRecord() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + + virtual std::string GetCommErrors() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return ""; + } + virtual bool IsStarted() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } + virtual bool IsTimeout() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } + virtual bool IsCompleted() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } + virtual void AbortComm() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + + protected: + std::string backend_; + phi::Place place_; + int global_rank_; + int rank_; + int size_; + int gid_; + uint64_t seq_{0}; + int64_t numel_; + ncclComm_t nccl_comm_; + gpuStream_t nccl_stream_; + CommType comm_type_; + bool start_trace_updated_{false}; + + bool completed_ = false; + bool aborted_{false}; + std::chrono::time_point start_time_; + std::shared_ptr store_; + + private: + DISABLE_COPY_AND_ASSIGN(CommTask); +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/comm_task_manager.cc b/paddle/phi/core/distributed/comm_task_manager.cc new file mode 100644 index 0000000000000..37083119b59f5 --- /dev/null +++ b/paddle/phi/core/distributed/comm_task_manager.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(PADDLE_WITH_GLOO) +#include + +#include "paddle/phi/core/distributed/gloo_comm_context.h" +#include "paddle/phi/core/distributed/gloo_utils.h" +#include "paddle/phi/core/distributed/store/gloo_store.h" +#endif + +#include "paddle/phi/core/distributed/comm_context_manager.h" + +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/enforce.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/phi/core/distributed/comm_task_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/distributed/trace_utils.h" +#endif + +namespace phi { +namespace distributed { + +std::thread CommTaskManager::comm_task_loop_thread_; +const int64_t CommTaskManager::loop_thread_sleep_millis = 10000; + +std::atomic CommTaskManager::terminated_; +std::mutex CommTaskManager::comm_task_list_mutex_; +std::condition_variable CommTaskManager::comm_task_list_cv_; +std::list> CommTaskManager::comm_task_list_; +std::unordered_map> + CommTaskManager::init_comm_task_map_; +std::unordered_map> + CommTaskManager::start_comm_task_map_; + +CommTaskManager::CommTaskManager() { + terminated_.store(false); + comm_task_loop_thread_ = std::thread(&CommTaskManager::CommTaskLoop, this); + LOG(INFO) << "CommTaskManager init success"; +} +CommTaskManager::~CommTaskManager() { + terminated_.store(true); + + if (comm_task_loop_thread_.joinable()) { + comm_task_loop_thread_.join(); + comm_task_list_cv_.notify_one(); + } + LOG(INFO) << "CommTaskManager destruct success."; +} + +void CommTaskManager::CommTaskEnqueue(std::shared_ptr comm_task) { + if (!terminated_.load()) { + std::lock_guard lock(comm_task_list_mutex_); + comm_task_list_.emplace_back(std::move(comm_task)); + } +} + +void CommTaskManager::CommTaskLoop() { + bool done = false; + while (!terminated_.load() || !done) { + std::unique_lock lock(comm_task_list_mutex_); + comm_task_list_cv_.wait_for( + lock, + std::chrono::milliseconds(loop_thread_sleep_millis), + [&]() -> bool { return terminated_.load(); }); + for (auto iter = comm_task_list_.begin(); iter != comm_task_list_.end();) { + auto task = *iter; + if (task->IsTimeout()) { + if (!task->IsStarted()) { + LOG(ERROR) << "Find timeout init but not start task: " + << task->GetTraceMsg() << ",comm:" << task->nccl_comm() + << ",stream:" << task->nccl_stream(); + std::string task_key = task->UniqueKey(); + init_comm_task_map_[task_key] = task; + } else if (!task->IsCompleted()) { + LOG(ERROR) << "Find timeout start but not finish task: " + << task->GetTraceMsg() << ",comm:" << task->nccl_comm() + << ",stream:" << task->nccl_stream(); + std::string task_key = task->UniqueKey(); + start_comm_task_map_[task_key] = task; + } + iter = comm_task_list_.erase(iter); + } else { + ++iter; + } + } + + for (auto iter = init_comm_task_map_.begin(); + iter != init_comm_task_map_.end();) { + auto task = iter->second; + if (task->IsStarted()) { + std::string task_key = task->UniqueKey(); + start_comm_task_map_[task_key] = task; + iter = init_comm_task_map_.erase(iter); + LOG(INFO) << "Start timeout task: " << task->GetTraceMsg(); + } else { + ++iter; + } + } + + for (auto iter = start_comm_task_map_.begin(); + iter != start_comm_task_map_.end();) { + auto task = iter->second; + if (task->IsCompleted()) { + iter = start_comm_task_map_.erase(iter); + LOG(INFO) << "Finish timeout task: " << task->GetTraceMsg(); + } else { + ++iter; + } + } + + if (comm_task_list_.empty() && init_comm_task_map_.empty() && + start_comm_task_map_.empty()) { + done = true; + } + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/comm_task_manager.h b/paddle/phi/core/distributed/comm_task_manager.h new file mode 100644 index 0000000000000..58be0026dd072 --- /dev/null +++ b/paddle/phi/core/distributed/comm_task_manager.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/phi/core/distributed/comm_context.h" +#include "paddle/phi/core/distributed/comm_task.h" +#include "paddle/phi/core/macros.h" + +namespace phi { +namespace distributed { + +enum ErrorHandlingMode { NoHandling = 0, TearDown = 1 }; + +class Store; + +class CommTaskManager { + public: + CommTaskManager(); + ~CommTaskManager(); + + public: + static CommTaskManager& GetInstance() { + static CommTaskManager instance; + return instance; + } + + void CommTaskEnqueue(std::shared_ptr comm_task); + + private: + void CommTaskLoop(); + + static std::thread comm_task_loop_thread_; + static const int64_t loop_thread_sleep_millis; + + static std::atomic terminated_; + + static std::mutex comm_task_list_mutex_; + static std::condition_variable comm_task_list_cv_; + static std::list> comm_task_list_; + // not start task + static std::unordered_map> + init_comm_task_map_; + // start but not finish task + static std::unordered_map> + start_comm_task_map_; + std::shared_ptr store_; + bool store_error_{false}; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc index bd49f0cff1708..d1d92c98fb0fd 100644 --- a/paddle/phi/core/distributed/nccl_comm_context.cc +++ b/paddle/phi/core/distributed/nccl_comm_context.cc @@ -19,6 +19,7 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/nccl_tools.h" #include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/utils/data_type.h" @@ -31,9 +32,9 @@ constexpr bool FLAGS_enable_nccl_dynamic_check = false; NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id) : CommContext(rank, size) { - PADDLE_ENFORCE_GPU_SUCCESS( + NCCL_CHECK( phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_)); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&nccl_version_)); + NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_)); } int NCCLCommContext::GetNcclVersion() { return nccl_version_; } @@ -76,14 +77,13 @@ void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, if (FLAGS_enable_nccl_dynamic_check) { NCCLDynamicCheck::CheckShape(*out_tensor, root, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclBroadcast(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - root, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclBroadcast(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + root, + nccl_comm_, + stream)); } void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor, @@ -100,13 +100,12 @@ void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclAllGather(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclAllGather(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + nccl_comm_, + stream)); } void NCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -123,14 +122,13 @@ void NCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclReduceScatter(in_tensor.data(), - out_tensor->data(), - out_tensor->numel(), - ToNCCLDataType(in_tensor.type()), - reduce_type, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclReduceScatter(in_tensor.data(), + out_tensor->data(), + out_tensor->numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + nccl_comm_, + stream)); } void NCCLCommContext::Send(const phi::DenseTensor& in_tensor, @@ -143,13 +141,12 @@ void NCCLCommContext::Send(const phi::DenseTensor& in_tensor, NCCLDynamicCheck::CheckShape(in_tensor, rank_, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclSend(in_tensor.data(), - count, - ToNCCLDataType(in_tensor.dtype()), - peer, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(), + count, + ToNCCLDataType(in_tensor.dtype()), + peer, + nccl_comm_, + stream)); VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims()) << " to " << peer; } @@ -163,13 +160,12 @@ void NCCLCommContext::Recv(phi::DenseTensor* out_tensor, NCCLDynamicCheck::CheckShape(*out_tensor, peer, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclRecv(out_tensor->data(), - count, - ToNCCLDataType(out_tensor->dtype()), - peer, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclRecv(out_tensor->data(), + count, + ToNCCLDataType(out_tensor->dtype()), + peer, + nccl_comm_, + stream)); VLOG(3) << "rank " << GetRank() << " recv " << phi::product(out_tensor->dims()) << " from " << peer; } @@ -189,14 +185,13 @@ void NCCLCommContext::AllReduce(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclAllReduce(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - reduce_type, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclAllReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + nccl_comm_, + stream)); } void NCCLCommContext::Reduce(phi::DenseTensor* out_tensor, @@ -215,15 +210,14 @@ void NCCLCommContext::Reduce(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclReduce(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - reduce_type, - root, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + root, + nccl_comm_, + stream)); } void NCCLCommContext::GroupStart() { diff --git a/paddle/phi/core/distributed/nccl_comm_task.cc b/paddle/phi/core/distributed/nccl_comm_task.cc new file mode 100644 index 0000000000000..f82f39c1954a3 --- /dev/null +++ b/paddle/phi/core/distributed/nccl_comm_task.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/nccl_comm_task.h" + +#include "gflags/gflags.h" +#include "glog/logging.h" + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/distributed/nccl_tools.h" +#include "paddle/phi/core/distributed/trace_utils.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { +namespace distributed { + +NCCLCommTask::NCCLCommTask(const phi::Place& place, + int rank, + int size, + int gid, + uint64_t seq, + int64_t numel, + bool sync_op, + bool use_calc_stream, + ncclComm_t nccl_comm, + gpuStream_t stream, + CommType comm_type, + int64_t timeout) + : CommTask("NCCL", + place, + rank, + size, + gid, + seq, + numel, + nccl_comm, + stream, + comm_type), + sync_op_(sync_op), + use_calc_stream_(use_calc_stream) { + start_trace_updated_ = false; + start_event_created_ = false; + end_event_created_ = false; + start_time_ = std::chrono::steady_clock::now(); + timeout_ = std::chrono::milliseconds(timeout); +} + +void NCCLCommTask::StartRecord() { + backends::gpu::GPUDeviceGuard guard(place_.device); + if (!start_event_created_) { +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventCreateWithFlags(&nccl_start_event_, cuda_event_flags_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventCreateWithFlags(&nccl_start_event_, hip_event_flags_)); +#endif + start_event_created_ = true; + } +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventRecord(nccl_start_event_, nccl_stream_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventRecord(nccl_start_event_, nccl_stream_)); +#endif +} +void NCCLCommTask::EndRecord() { + backends::gpu::GPUDeviceGuard guard(place_.device); + if (!end_event_created_) { +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventCreateWithFlags(&nccl_end_event_, cuda_event_flags_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventCreateWithFlags(&nccl_end_event_, hip_event_flags_)); +#endif + end_event_created_ = true; + } +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventRecord(nccl_end_event_, nccl_stream_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventRecord(nccl_end_event_, nccl_stream_)); +#endif +} + +bool NCCLCommTask::CudaEventQuery(gpuEvent_t event) { +#ifdef PADDLE_WITH_CUDA + cudaError_t ret = cudaEventQuery(event); + if (ret == cudaSuccess) { + return true; + } else if (ret != cudaErrorNotReady) { + CUDA_CHECK(ret); + } else { + // ignore and clear the error if not ready + CUDA_CHECK(cudaGetLastError()); + } +#else // PADDLE_WITH_HIP + hipError_t ret = hipEventQuery(event); + if (ret == hipSuccess) { + return true; + } else if (ret != hipErrorNotReady) { + HIP_CHECK(ret); + } else { + // ignore and clear the error if not ready + HIP_CHECK(hipGetLastError()); + } +#endif + return false; +} + +std::string GetNCCLErrorDetail(ncclResult_t result) { + std::string detail; + std::string last_error; +#ifdef ENABLE_NCCL_GET_LAST_ERROR + last_error = + ", Last error: " + std::string(phi::dynload::ncclGetLastError(NULL)); +#endif + switch (result) { + case ncclUnhandledCudaError: + detail = "ncclUnhandledCudaError: Call to CUDA function failed."; + break; + case ncclSystemError: + detail = + "ncclSystemError: System call (e.g. socket, malloc) or external " + "library call failed or device error. "; +#ifndef NCCL_REMOTE_ERROR + // Before ncclRemoteError was created, unexpected remote disconnect was + // categorized as ncclSystemError + detail += "It can be also caused by unexpected exit of a remote peer."; +#endif + break; + case ncclInternalError: + detail = "ncclInternalError: Internal check failed."; + break; + case ncclInvalidArgument: + detail = "ncclInvalidArgument: Invalid value for an argument."; + break; + case ncclInvalidUsage: + detail = + "ncclInvalidUsage: This usually reflects invalid usage of NCCL " + "library."; + break; +#ifdef NCCL_REMOTE_ERROR + case ncclRemoteError: + detail = + "ncclRemoteError: A call failed possibly due to a network error or a " + "remote process exiting prematurely."; + break; +#endif + default: + detail = "Unknown NCCL error!"; + } + return detail + last_error; +} + +std::string NCCLCommTask::GetCommErrors() { + std::unique_lock lock(mutex_); + if (!comm_error_.empty()) { + return comm_error_; + } + + ncclResult_t nccl_async_error; + NCCL_CHECK( + phi::dynload::ncclCommGetAsyncError(nccl_comm_, &nccl_async_error)); + if (nccl_async_error != ncclSuccess) { + comm_error_ = + "\n\t Find nccl comm error: " + GetNCCLErrorDetail(nccl_async_error); + } + return comm_error_; +} + +bool NCCLCommTask::IsStarted() { return CudaEventQuery(nccl_start_event_); } + +bool NCCLCommTask::IsCompleted() { return CudaEventQuery(nccl_end_event_); } + +bool NCCLCommTask::IsTimeout() { + auto current_timepoint = std::chrono::steady_clock::now(); + return std::chrono::duration_cast( + current_timepoint - start_time_) >= timeout_; +} + +void NCCLCommTask::AbortComm() { + std::unique_lock lock(mutex_); + if (aborted_) { + return; + } + NCCL_CHECK(phi::dynload::ncclCommAbort(nccl_comm_)); + + aborted_ = true; + nccl_comm_ = nullptr; + return; +} + +std::string NCCLCommTask::GetTraceMsg() { + auto current_timepoint = std::chrono::steady_clock::now(); + auto time_elapsed = std::chrono::duration_cast( + current_timepoint - start_time_); + return "op:" + CommTypeToString(comm_type_) + ",gid:" + std::to_string(gid_) + + ",seq:" + std::to_string(seq_) + + ",started:" + std::to_string(IsStarted()) + + ",completed:" + std::to_string(IsCompleted()) + + ",global_rank:" + std::to_string(global_rank_) + + ",local_rank:" + std::to_string(rank_) + + ",size:" + std::to_string(size_) + ",numel:" + std::to_string(numel_) + + ",sync_op:" + std::to_string(sync_op_) + + ",use_calc_stream:" + std::to_string(use_calc_stream_) + + ",timeout:" + std::to_string(timeout_.count()) + + ",is_timeout:" + std::to_string(IsTimeout()) + + ",time_elapsed:" + std::to_string(time_elapsed.count()); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_comm_task.h b/paddle/phi/core/distributed/nccl_comm_task.h new file mode 100644 index 0000000000000..9fe71670c2f88 --- /dev/null +++ b/paddle/phi/core/distributed/nccl_comm_task.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/core/distributed/comm_context.h" +#include "paddle/phi/core/distributed/comm_task.h" +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/macros.h" + +#if defined(PADDLE_WITH_RCCL) +#include "paddle/phi/backends/dynload/rccl.h" +#else +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +class DenseTensor; +namespace distributed { + +static int64_t DefaultTimeout = 30 * 60 * 1000; + +class NCCLCommTask : public CommTask { + public: + NCCLCommTask(const phi::Place& place = phi::Place(), + int rank = -1, + int size = 0, + int gid = 0, + uint64_t seq = 0, + int64_t numel = 0, + bool sync_op = true, + bool use_calc_stream = false, + ncclComm_t = nullptr, + gpuStream_t = nullptr, + CommType comm_type = CommType::UNKNOWN, + int64_t timeout = DefaultTimeout); + ~NCCLCommTask() = default; + + // check whether the nccl kernel started + bool IsStarted() override; + bool IsTimeout() override; + bool IsCompleted() override; + + std::string GetTraceMsg() override; + std::string GetCommErrors() override; + void AbortComm() override; + + void StartRecord(); + void EndRecord(); + + bool CudaEventQuery(gpuEvent_t event); + + protected: + std::mutex mutex_; + std::chrono::milliseconds timeout_; + +#ifdef PADDLE_WITH_CUDA + unsigned int cuda_event_flags_ = cudaEventDisableTiming; +#else // PADDLE_WITH_HIP + unsigned int hip_event_flags_ = hipEventDisableTiming; +#endif + + bool sync_op_; + bool use_calc_stream_; + + bool start_event_created_; + bool end_event_created_; + gpuEvent_t nccl_start_event_; + gpuEvent_t nccl_end_event_; + + std::string comm_error_; + + private: + DISABLE_COPY_AND_ASSIGN(NCCLCommTask); +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/fluid/distributed/collective/nccl_tools.cc b/paddle/phi/core/distributed/nccl_tools.cc similarity index 51% rename from paddle/fluid/distributed/collective/nccl_tools.cc rename to paddle/phi/core/distributed/nccl_tools.cc index 940c8d47ccb88..e419cfca905fa 100644 --- a/paddle/fluid/distributed/collective/nccl_tools.cc +++ b/paddle/phi/core/distributed/nccl_tools.cc @@ -12,14 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/distributed/collective/nccl_tools.h" +#include "paddle/phi/core/distributed/nccl_tools.h" #include #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" -namespace paddle { +#if NCCL_VERSION_CODE >= 21300 +#define ENABLE_NCCL_GET_LAST_ERROR +#define NCCL_REMOTE_ERROR +#endif + +namespace phi { namespace distributed { ncclRedOp_t ToNCCLRedType(ReduceOp reduction) { @@ -47,5 +52,43 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) { return oss.str(); } +std::string NCCLDTypeToString(ncclDataType_t dtype) { +#define PD_NCCL_DTYPE_TO_STR(__nccl_dtype, __str_dtype) \ + if (dtype == __nccl_dtype) return __str_dtype; + PD_NCCL_DTYPE_TO_STR(ncclFloat, "float32"); + PD_NCCL_DTYPE_TO_STR(ncclFloat32, "float32"); + PD_NCCL_DTYPE_TO_STR(ncclHalf, "float16"); + PD_NCCL_DTYPE_TO_STR(ncclFloat16, "float16"); +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 + PD_NCCL_DTYPE_TO_STR(ncclBfloat16, "bfloat16"); +#endif + PD_NCCL_DTYPE_TO_STR(ncclDouble, "float64"); + PD_NCCL_DTYPE_TO_STR(ncclFloat64, "float64"); + + PD_NCCL_DTYPE_TO_STR(ncclInt8, "int8"); + PD_NCCL_DTYPE_TO_STR(ncclChar, "int8"); + PD_NCCL_DTYPE_TO_STR(ncclUint8, "uint8"); + PD_NCCL_DTYPE_TO_STR(ncclInt32, "int32"); + PD_NCCL_DTYPE_TO_STR(ncclInt, "int32"); + PD_NCCL_DTYPE_TO_STR(ncclUint32, "uint32"); + PD_NCCL_DTYPE_TO_STR(ncclInt64, "int64"); + PD_NCCL_DTYPE_TO_STR(ncclUint64, "uint64"); + +#undef PD_NCCL_DTYPE_TO_STR + PADDLE_THROW(phi::errors::InvalidArgument( + "This datatype %d in nccl is not supported.", static_cast(dtype))); +} + +std::string NCCLRedTypeToString(ncclRedOp_t op) { + if (op == ncclSum) return "SUM"; + if (op == ncclProd) return "PROD"; + if (op == ncclMin) return "MIN"; + if (op == ncclMax) return "MAX"; +#if NCCL_VERSION_CODE >= 21000 + if (op == ncclAvg) return "AVG"; +#endif + return "UDF_" + std::to_string(op); +} + } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_tools.h b/paddle/phi/core/distributed/nccl_tools.h new file mode 100644 index 0000000000000..0ab380a417783 --- /dev/null +++ b/paddle/phi/core/distributed/nccl_tools.h @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/distributed/types.h" + +#ifdef PADDLE_WITH_RCCL +#include +#include "paddle/phi/backends/dynload/rccl.h" +#else +#include +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +namespace distributed { + +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + PADDLE_THROW( \ + phi::errors::External("Failed, NCCL error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + phi::dynload::ncclGetErrorString(r))); \ + } \ + } while (0) + +#ifdef PADDLE_WITH_NCCL +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t r = expr; \ + if (r != cudaSuccess) { \ + PADDLE_THROW(phi::errors::External("Failed, cuda error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(r))); \ + } \ + } while (0) +#else // PADDLE_WITH_RCCL +#define HIP_CHECK(expr) \ + do { \ + hipError_t r = expr; \ + if (r != hipSuccess) { \ + PADDLE_THROW(phi::errors::External("Failed, hip error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + hipGetErrorString(r))); \ + } \ + } while (0) +#endif + +ncclRedOp_t ToNCCLRedType(ReduceOp reduction); + +std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); + +std::string NCCLDTypeToString(ncclDataType_t dtype); + +std::string NCCLRedTypeToString(ncclRedOp_t op); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/store/store.cc b/paddle/phi/core/distributed/store/store.cc index 7e7db8895b99f..5987b694b4e51 100644 --- a/paddle/phi/core/distributed/store/store.cc +++ b/paddle/phi/core/distributed/store/store.cc @@ -28,6 +28,11 @@ std::vector Store::get(const std::string& key) { errors::InvalidArgument("Implement the get method in the subclass.")); } +bool Store::check(const std::string& key) { + PADDLE_THROW( + errors::InvalidArgument("Implement the get method in the subclass.")); +} + void Store::wait(const std::string& key) { PADDLE_THROW( errors::InvalidArgument("Implement the wait method in the subclass.")); diff --git a/paddle/phi/core/distributed/store/store.h b/paddle/phi/core/distributed/store/store.h index fa509586eefdf..4ecd4cb8b5d99 100644 --- a/paddle/phi/core/distributed/store/store.h +++ b/paddle/phi/core/distributed/store/store.h @@ -29,6 +29,7 @@ class Store { virtual int64_t add(const std::string& key, int64_t value); virtual std::vector get(const std::string& key); + virtual bool check(const std::string& key); virtual void wait(const std::string& key); virtual void set(const std::string& key, const std::vector& value); diff --git a/paddle/phi/core/distributed/store/tcp_store.cc b/paddle/phi/core/distributed/store/tcp_store.cc index 6fbe2aa6761e2..46af21fa94356 100644 --- a/paddle/phi/core/distributed/store/tcp_store.cc +++ b/paddle/phi/core/distributed/store/tcp_store.cc @@ -110,6 +110,19 @@ void MasterDaemon::_do_get(SocketType socket) { tcputils::send_vector(socket, value); } +void MasterDaemon::_do_check(SocketType socket) { + std::string key = tcputils::receive_string(socket); + VLOG(4) << "MasterDaemon::_do_check key(" << key << ") " + << GetSockName(socket); + + auto iter = _store.find(key); + if (iter != _store.end()) { + tcputils::send_value(socket, ReplyType::READY); + } else { + tcputils::send_value(socket, ReplyType::NOT_READY); + } +} + #ifndef _WIN32 void MasterDaemon::InitControlFd() { PADDLE_ENFORCE_NE( @@ -190,6 +203,9 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { case Command::GET: _do_get(fds[i].fd); break; + case Command::CHECK: + _do_check(fds[i].fd); + break; case Command::SET: _do_set(fds[i].fd); break; @@ -420,6 +436,17 @@ std::vector TCPStore::get(const std::string& key) { return _client->receive_vector(); } +bool TCPStore::check(const std::string& key) { + _client->send_command_for_key(Command::CHECK, _key_prefix + key); + VLOG(3) << "TCPStore check."; + auto response = _client->receive_value(); + if (response == ReplyType::READY) { + return true; + } else { + return false; + } +} + void TCPStore::wait(const std::string& key) { ReplyType reply; // NOLINT VLOG(7) << "TCPStore wait."; diff --git a/paddle/phi/core/distributed/store/tcp_store.h b/paddle/phi/core/distributed/store/tcp_store.h index 0f17bc9b58bd4..4cc3a1933bd5d 100644 --- a/paddle/phi/core/distributed/store/tcp_store.h +++ b/paddle/phi/core/distributed/store/tcp_store.h @@ -37,8 +37,8 @@ namespace phi { namespace distributed { -enum class ReplyType { WAITING, STOP_WAIT }; -enum class Command { ADD, GET, SET, WAIT, STOP }; +enum class ReplyType { WAITING, STOP_WAIT, READY, NOT_READY }; +enum class Command { ADD, GET, CHECK, SET, WAIT, STOP }; namespace detail { @@ -59,6 +59,7 @@ class MasterDaemon { void _do_add(SocketType socket); void _do_wait(SocketType socket); void _do_get(SocketType socket); + void _do_check(SocketType socket); void _do_set(SocketType socket); void _notify_waiting_sockets(const std::string&); SocketType _listen_socket; @@ -130,6 +131,7 @@ class TCPStore : public Store { int64_t add(const std::string& key, int64_t value) override; std::vector get(const std::string& key) override; + bool check(const std::string& key) override; void wait(const std::string& key) override; void set(const std::string& key, const std::vector& value) override; diff --git a/paddle/phi/core/distributed/trace_utils.h b/paddle/phi/core/distributed/trace_utils.h new file mode 100644 index 0000000000000..7a34055a987bc --- /dev/null +++ b/paddle/phi/core/distributed/trace_utils.h @@ -0,0 +1,187 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/utils/string/split.h" + +namespace phi { +namespace distributed { + +enum TraceEventType { + TraceEventStart, + TraceEventEnd, +}; + +using TraceMap = + std::map>>; + +inline std::string GetTraceStartKey(const std::string& backend, + int rank, + int gid) { + return backend + "_" + std::to_string(rank) + "_" + std::to_string(gid) + + "_trace_start"; +} + +inline std::string GetTraceEndKey(const std::string& backend, + int rank, + int gid) { + return backend + "_" + std::to_string(rank) + "_" + std::to_string(gid) + + "_trace_end"; +} + +inline std::string GetExceptionMsgFromExceptionPtr( + const std::exception_ptr& exception_ptr) { + if (exception_ptr == nullptr) { + return "No exception found"; + } + try { + std::rethrow_exception(exception_ptr); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown exception type"; + } +} + +inline bool UpdateTraceMsg(std::shared_ptr store, + const std::string& key, + uint64_t seq, + const std::string& comm_type) { + std::vector value(comm_type.size() + sizeof(seq) + 1); + memcpy(value.data(), &seq, sizeof(seq)); + memcpy(value.data() + sizeof(seq), comm_type.data(), comm_type.size()); + try { + store->set(key, value); + return true; + } catch (...) { + LOG(ERROR) << "Store is down while updating trace msg, with seq: " << seq + << ", key " << key; + return false; + } +} + +inline bool ParseTraceValue(std::shared_ptr store, + const std::string& key, + uint64_t* seq, + std::string* comm_type) { + try { + std::vector value = store->get(key); + memcpy(seq, value.data(), sizeof(*seq)); + std::string type_value( + reinterpret_cast(value.data() + sizeof(*seq))); + *comm_type = type_value; + return true; + } catch (...) { + LOG(ERROR) << "Store is down while parsing trace value, with key: " << key; + return false; + } +} + +inline std::string RanksToString(const std::vector& ranks) { + std::string result; + for (int rank : ranks) { + if (result.empty()) { + result += std::to_string(rank); + } else { + result += ", " + std::to_string(rank); + } + } + return result; +} + +inline std::string AnalyzeTraceMsg(const TraceMap& trace_map, int gid) { + uint64_t lag_seq = trace_map.begin()->first; + std::vector start_ranks; + std::vector end_ranks; + for (auto& p : trace_map.begin()->second) { + if (p.second.second == TraceEventStart) { + start_ranks.emplace_back(p.first); + } else { + end_ranks.emplace_back(p.first); + } + } + + std::string result = "\n\t The ranks that has desync problem are: "; + if (start_ranks.size()) { + result += "[" + RanksToString(start_ranks) + + "] joined but do not finish collective seq: " + + std::to_string(lag_seq) + " in group_id: " + std::to_string(gid); + } + if (end_ranks.size()) { + result += ", ranks [" + RanksToString(end_ranks) + + "] finished collective seq: " + std::to_string(lag_seq) + + ", but didnt join seq: " + std::to_string(lag_seq + 1) + + " in group_id: " + std::to_string(gid); + } + return result; +} + +inline std::string GenerateTraceMsg(std::shared_ptr store, + const std::string& backend, + int curr_rank, + int group_id, + int world_size) { + std::string result; + TraceMap trace_map; + + uint64_t curr_seq; + std::string curr_comm_type; + + for (int rank = 0; rank < world_size; ++rank) { + uint64_t seq_start = 0; + { + std::string trace_start_key = GetTraceStartKey(backend, rank, group_id); + if (!store->check(trace_start_key)) { + continue; + } + + std::string comm_type; + if (!ParseTraceValue(store, trace_start_key, &seq_start, &comm_type)) { + return result; + } + trace_map[seq_start].emplace(rank, + std::make_pair(comm_type, TraceEventStart)); + if (rank == curr_rank) { + curr_seq = seq_start; + curr_comm_type = std::move(comm_type); + } + } + { + std::string trace_end_key = GetTraceEndKey(backend, rank, group_id); + if (!store->check(trace_end_key)) { + continue; + } + + uint64_t seq = 0; + std::string comm_type; + if (!ParseTraceValue(store, trace_end_key, &seq, &comm_type)) { + return result; + } + if (seq == seq_start) { + trace_map[seq][rank].second = TraceEventEnd; + } + } + } + result += "\n\t Problem summary: rank: " + std::to_string(curr_rank) + + " timeout at collective: " + curr_comm_type + + ", group_id: " + std::to_string(group_id) + + ", seq: " + std::to_string(curr_seq); + result += AnalyzeTraceMsg(trace_map, group_id); + return result; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/type_defs.h b/paddle/phi/core/distributed/type_defs.h index cd201ac5c5aaf..1b7035c1a4528 100644 --- a/paddle/phi/core/distributed/type_defs.h +++ b/paddle/phi/core/distributed/type_defs.h @@ -18,12 +18,16 @@ #include #include +#include "paddle/utils/variant.h" + namespace phi { namespace distributed { class TensorDistAttr; -using SpmdInfo = - std::pair, std::vector>; +using ArgDistAttr = + paddle::variant>; + +using SpmdInfo = std::pair, std::vector>; } // namespace distributed } // namespace phi diff --git a/paddle/fluid/distributed/collective/types.h b/paddle/phi/core/distributed/types.h similarity index 97% rename from paddle/fluid/distributed/collective/types.h rename to paddle/phi/core/distributed/types.h index bd20f2705f22f..3d4d074efd735 100644 --- a/paddle/fluid/distributed/collective/types.h +++ b/paddle/phi/core/distributed/types.h @@ -20,7 +20,7 @@ #include "paddle/phi/common/place.h" -namespace paddle { +namespace phi { namespace distributed { // TODO(shenliang03): To support AVG for reduce @@ -58,4 +58,4 @@ struct ReduceScatterOptions { }; } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/core/distributed/utils.h b/paddle/phi/core/distributed/utils.h index f635b7d99fa61..40b28bb2a3e6f 100644 --- a/paddle/phi/core/distributed/utils.h +++ b/paddle/phi/core/distributed/utils.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. #pragma once - #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -28,13 +27,119 @@ inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor, return tensor_flattened.Slice(offset, offset + numel); } -#define NCCL_CHECK(cmd) \ - do { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) { \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +inline void* GetPointerByOffset(void* raw_pointer, + size_t offset, + phi::DataType type) { + if (type == phi::DataType::FLOAT32) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::FLOAT64) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::FLOAT16) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::INT32) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::INT64) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::INT8) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::UINT8) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::BOOL) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::BFLOAT16) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Datatype %s in NCCL is not supported.", type)); + } + return nullptr; +} + +inline void CheckSizeOnEachRank(const phi::DDim& tensor_dim, + const std::vector& size_on_each_rank, + int world_size) { + int length_size_on_each_rank = size_on_each_rank.size(); + PADDLE_ENFORCE_EQ( + length_size_on_each_rank, + world_size, + phi::errors::InvalidArgument( + "The length of size_on_each_rank must be equal to world_size.")); + + int64_t sum_size_on_each_rank = + std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0); + PADDLE_ENFORCE_EQ( + sum_size_on_each_rank, + tensor_dim[0], + phi::errors::InvalidArgument( + "The sum of size_on_each_rank must be equal to tensor's dim[0].")); +} + +enum class CommType : std::uint8_t { + BROADCAST = 0, + ALLREDUCE = 1, + ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce + REDUCE = 3, + ALLGATHER = 4, + GATHER = 5, + SCATTER = 6, + REDUCE_SCATTER = 7, + ALLTOALL = 8, + SEND = 9, + RECV = 10, + BARRIER = 11, + UNKNOWN = 100, +}; + +inline bool IsP2POP(CommType comm_type, bool is_batch_p2p = false) { + if (is_batch_p2p) { + return false; + } else { + return comm_type == CommType::SEND || comm_type == CommType::RECV; + } +} + +inline std::string CommTypeToString(CommType CommType) { + switch (CommType) { + case CommType::BROADCAST: + return "Broadcast"; + case CommType::ALLREDUCE: + return "AllReduce"; + case CommType::ALLREDUCE_SPARSE: + return "AllReduce_Sparse"; + case CommType::REDUCE: + return "Reduce"; + case CommType::ALLGATHER: + return "AllGather"; + case CommType::GATHER: + return "Gather"; + case CommType::SCATTER: + return "Scatter"; + case CommType::REDUCE_SCATTER: + return "ReduceScatter"; + case CommType::ALLTOALL: + return "AllToAll"; + case CommType::SEND: + return "Send"; + case CommType::RECV: + return "Recv"; + case CommType::BARRIER: + return "Barrier"; + case CommType::UNKNOWN: + return "Unknown"; + default: + return "Unknown"; + } + return "Unknown"; +} } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/extended_tensor.h b/paddle/phi/core/extended_tensor.h index d02dbabde179f..73cae43c0b54c 100644 --- a/paddle/phi/core/extended_tensor.h +++ b/paddle/phi/core/extended_tensor.h @@ -18,12 +18,14 @@ limitations under the License. */ #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/utils/test_macros.h" + namespace phi { /// \brief The ExtendedTensor is a interface for custom designed class. /// If you want to pass some self-designed data as input/output to kernels, /// you can inherit from this class to store your self-designed data. -class ExtendedTensor : public TensorBase { +class TEST_API ExtendedTensor : public TensorBase { public: ExtendedTensor() = default; virtual ~ExtendedTensor() = default; diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index c7a0a81c7fb4f..8e237c4c48367 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1134,10 +1134,17 @@ PHI_DEFINE_EXPORTED_bool(gpugraph_debug_gpu_memory, * Example: * Note: nccl blocking wait. */ + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PHI_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PHI_DEFINE_EXPORTED_bool(benchmark_nccl, + false, + "enable nccl debug mode to synchronize nccl comm"); +#endif + /** * Autotune related FLAG * Name: FLAGS_use_autotune @@ -1266,13 +1273,13 @@ PHI_DEFINE_EXPORTED_string(tensor_operants_mode, /** * Using new IR in executor FLAG - * Name: enable_new_ir_in_executor + * Name: enable_pir_in_executor * Since Version: 2.6.0 * Value Range: bool, default=false * Example: * Note: If Ture, executor will use new IR */ -PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor, +PHI_DEFINE_EXPORTED_bool(enable_pir_in_executor, false, "Enable new IR in executor"); @@ -1288,30 +1295,35 @@ PHI_DEFINE_EXPORTED_bool(enable_pir_api, false, "Enable new IR API in Python"); /** * Using new IR in executor FLAG - * Name: enable_new_ir_in_executor_trace_run + * Name: enable_pir_in_executor_trace_run * Since Version: 2.6.0 * Value Range: bool, default=false * Example: * Note: If Ture, executor will use new IR and run in beta version by for trace * version. */ -PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run, +PHI_DEFINE_EXPORTED_bool(enable_pir_in_executor_trace_run, false, "Enable new IR in executor"); /** * Apply inplace pass to new IR FLAG - * Name: new_ir_apply_inplace_pass + * Name: pir_apply_inplace_pass * Since Version: 2.6.0 * Value Range: bool, default=true * Example: * Note: If Ture, will apply inplace pass to new IR. */ -PHI_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass, +PHI_DEFINE_EXPORTED_bool(pir_apply_inplace_pass, true, "Whether to apply inplace pass on lowering " "::pir::Program to Kernel Dialect"); +PHI_DEFINE_EXPORTED_string( + ir_inplace_kernel_blacklist, + "", + "It controls the ir inplace kernel subset do not use."); + PHI_DEFINE_EXPORTED_bool(enable_record_memory, false, "Enable memory recorder"); PHI_DEFINE_EXPORTED_bool( @@ -1350,3 +1362,18 @@ PHI_DEFINE_EXPORTED_bool(dynamic_static_unified_comm, "Whether to use new communication library in auto " "parallel and static mode."); #endif // FLAGS_dynamic_static_unified_comm + +/** + * ProcessGroupNCCL related FLAG + * Name: enable_async_trace + * Since Version: + * Value Range: bool, default=false + * Example: + * Note: enable nccl async trace. + */ + +PHI_DEFINE_EXPORTED_bool(enable_async_trace, + false, + "enable collective async trace"); + +PHI_DEFINE_EXPORTED_int32(async_trace_count, 5, "collective async trace count"); diff --git a/paddle/phi/core/generator.cc b/paddle/phi/core/generator.cc index 4541b81de4630..b3f8a2d19caba 100644 --- a/paddle/phi/core/generator.cc +++ b/paddle/phi/core/generator.cc @@ -281,6 +281,8 @@ std::pair Generator::IncrementOffset( #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::lock_guard lock(this->mu_); uint64_t cur_offset = this->state_.thread_offset; + VLOG(10) << "cur_offset = " << cur_offset + << " increment_offset = " << increment_offset; this->state_.thread_offset += increment_offset; return std::make_pair(this->state_.current_seed, cur_offset); #else diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 7f156463ca17b..8f63dc5d4d56c 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -124,7 +124,13 @@ void MetaTensor::set_dtype(DataType dtype) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) ->dtype = dtype; } else if (phi::distributed::DistTensor::classof(tensor_)) { - // skip, DistTensor no need to set dtype + // For pipeline parallelism, DistTensor holds an uninitialized DenseTensor, + // But kernel launch needs to get it's placement, dtype and layout. + VLOG(3) << "DistTensor set dtype: " << dtype; + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_) + ->unsafe_mutable_value()) + ->dtype = dtype; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported settting dtype for `%s`.", tensor_->type_info().name())); @@ -158,7 +164,11 @@ void MetaTensor::set_layout(DataLayout layout) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) ->layout = layout; } else if (phi::distributed::DistTensor::classof(tensor_)) { - // skip, DistTensor no need to set dtype + VLOG(3) << "DistTensor set layout: " << layout; + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_) + ->unsafe_mutable_value()) + ->layout = layout; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported settting layout for `%s`.", tensor_->type_info().name())); diff --git a/paddle/phi/core/utils/type_info.h b/paddle/phi/core/utils/type_info.h index b4d908e2c1d9c..9e31343ed04a4 100644 --- a/paddle/phi/core/utils/type_info.h +++ b/paddle/phi/core/utils/type_info.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include "paddle/utils/test_macros.h" namespace phi { @@ -40,7 +41,7 @@ class TypeInfo { }; template -class TypeInfoTraits { +class TEST_API TypeInfoTraits { public: static const TypeInfo kType; TypeInfoTraits(); diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index c5612b203d233..5206da6ec3785 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -148,6 +148,33 @@ namespace phi { } \ }() +///////// BOOL and Floating and Integral Dispatch Marco /////////// + +#define PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::BOOL, bool, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + ///////// Floating and Complex Dispatch Marco /////////// #define PD_VISIT_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 2aa8543eb82c3..add5013298d30 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1545,6 +1545,8 @@ void GatherInferMeta(const MetaTensor& x, if (input_dim.size() == 1) { // the index is a 0D tensor and the x is a 1D tensor out->set_dims(phi::DDim(phi::Dim<0>())); + out->set_dtype(x.dtype()); + out->share_lod(x); } else { if (axis.FromTensor() || axis_v == 0) { // decrease the output dimension diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 0aca25103f80a..e7062879573c5 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -228,6 +228,8 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const MetaTensor& bias, const MetaTensor& branch, const MetaTensor& branch_max, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -378,6 +380,8 @@ void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& w_max, const MetaTensor& bias, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -1914,8 +1918,8 @@ void FusedEmbeddingEltWiseLayerNormInferMeta( auto dim_output = phi::make_ddim({batch, seq_len, hidden}); out->set_dims(dim_output); - // out->share_lod(ids); - // context->ShareLoD("Ids", /*->*/ "Out"); + out->share_lod(*ids[0]); + out->set_dtype((*embs[0]).dtype()); } void FusionTransposeFlattenConcatInferMeta( @@ -1977,6 +1981,7 @@ void FusionTransposeFlattenConcatInferMeta( out_dims[concat_axis] = -1; } out->set_dims(phi::make_ddim(out_dims)); + out->set_dtype((*x[0]).dtype()); } void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, @@ -2158,13 +2163,304 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, } out->set_dims(y_dims); + out->set_dtype(x.dtype()); if (mean) { + mean->set_dtype(x.dtype()); mean->set_dims({dim_0}); } if (variance) { variance->set_dims({dim_0}); + variance->set_dtype(x.dtype()); } out->share_lod(x); } +void FusionGRUInferMeta(const MetaTensor& x, + const MetaTensor& h0, + const MetaTensor& weight_x, + const MetaTensor& weight_h, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + MetaTensor* reordered_h0, + MetaTensor* xx, + MetaTensor* batched_input, + MetaTensor* batched_out, + MetaTensor* hidden) { + std::string mkldnn_data_type_list[] = {"float32", "int8", "bfloat16"}; + PADDLE_ENFORCE_EQ( + std::find(std::begin(mkldnn_data_type_list), + std::end(mkldnn_data_type_list), + mkldnn_data_type) != std::end(mkldnn_data_type_list), + true, + phi::errors::InvalidArgument("The mkldnn_data_type shoule be [float32, " + "int8, bfloat16], but found %s.", + mkldnn_data_type.c_str())); + + DDim x_dims = x.dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? phi::flatten_to_2d(x_dims, 1) + : x_dims; + PADDLE_ENFORCE_EQ( + x_mat_dims.size(), + 2, + phi::errors::InvalidArgument("The size of input X dims should be 2, " + "or 3 with second dimension equal to " + "1, but now Input X dim is:[%s] ", + x_dims)); + + auto wx_dims = weight_x.dims(); + PADDLE_ENFORCE_EQ(wx_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(WeightX) should be 2, but received " + "WeightX dim size is:%d, WeightX dim is:[%s] ", + wx_dims.size(), + wx_dims)); + PADDLE_ENFORCE_EQ( + wx_dims[0], + x_mat_dims[1], + phi::errors::InvalidArgument( + "The first dimension of flattened WeightX" + "should equal to last dimension of flattened input X, but " + "received fattened WeightX dimension is:%d, flattened X dimension " + "is:%d", + wx_dims[0], + x_mat_dims[1])); + + int frame_size = static_cast(wx_dims[1] / 3); + auto wh_dims = weight_h.dims(); + + PADDLE_ENFORCE_EQ(wh_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(WeightH) should be 2, but received " + "WeightH dim size is:%d, WeightH dim is:[%s]", + wh_dims.size(), + wh_dims)); + PADDLE_ENFORCE_EQ(wh_dims[0], + frame_size, + phi::errors::InvalidArgument( + "The first dimension of WeightH " + "should equal to frame_size, but received WeightH's " + "first dimension is: " + "%d, frame size is:%d", + wh_dims[0], + frame_size)); + PADDLE_ENFORCE_EQ(wh_dims[1], + 3 * frame_size, + phi::errors::InvalidArgument( + "The second dimension of Input(WeightH) " + "should equal to 3 * frame_size, but received WeightH " + "is:%d, frame size is:%d", + wh_dims[1], + frame_size)); + + if (h0) { + auto h0_dims = h0.dims(); + PADDLE_ENFORCE_EQ(h0_dims[1], + frame_size, + phi::errors::InvalidArgument( + "The width of H0 must be equal to frame_size, but " + "receiced the width of H0 is:%d, frame size is:%d", + h0_dims[1], + frame_size)); + reordered_h0->set_dtype(x.dtype()); + } + if (bias) { + auto b_dims = bias.dims(); + PADDLE_ENFORCE_EQ(b_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(Bias) should be 2, but received " + "Bias rank is:%d, Bias dim is:[%s]", + b_dims.size(), + b_dims)); + PADDLE_ENFORCE_EQ(b_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dimension of Input(Bias) should be 1, but " + "received Bias first dim is:%d, Bias dim is:[%s]", + b_dims[0], + b_dims)); + PADDLE_ENFORCE_EQ(b_dims[1], + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but " + "received bias dim is:[%s], frame size is:%d", + b_dims, + frame_size)); + } + DDim out_dims({x_mat_dims[0], frame_size}); + hidden->set_dims(out_dims); + hidden->share_lod(x); + hidden->set_dtype(x.dtype()); + int xx_width = 0; + if (use_seq) { + xx_width = static_cast(wx_dims[1]); + } else { + xx_width = static_cast(x_mat_dims[1] > wx_dims[1] ? wx_dims[1] + : x_mat_dims[1]); + batched_input->set_dims({x_mat_dims[0], wx_dims[1]}); + batched_input->set_dtype(x.dtype()); + batched_out->set_dims(out_dims); + batched_out->set_dtype(x.dtype()); + } + xx->set_dims({x_mat_dims[0], xx_width}); + xx->set_dtype(x.dtype()); + xx->share_lod(x); +} + +void FusionSeqConvEltAddReluInferMeta(const MetaTensor& x, + const MetaTensor& filter, + const MetaTensor& bias, + const int context_length, + const int context_start, + const int context_stride, + MetaTensor* out, + MetaTensor* col_mat) { + auto x_dims = x.dims(); + auto w_dims = filter.dims(); + PADDLE_ENFORCE_GT( + context_length, + 0, + phi::errors::InvalidArgument("context_length should be greater than 0, " + "but received context_length is: %d", + context_length)); + PADDLE_ENFORCE_EQ(context_stride, + 1, + phi::errors::InvalidArgument( + "Currently, FusionSeqConvEltAddReluOp only supports " + "contextStride=1, but received value is: %d.", + context_stride)); + + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(X) should be 2-D tensor, but reveiced value is: %d.", + x_dims.size())); + + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + phi::errors::InvalidArgument( + "Filter should be 2-D tensor, but reveiced value is: %d.", + w_dims.size())); + + PADDLE_ENFORCE_EQ(w_dims[0], + context_length * x_dims[1], + phi::errors::InvalidArgument( + "Filter's height should be equal to context_length * " + "input_hidden_size, but received Filter height is: %d," + "context_length is: %d, input_hidden_size is: %d.", + w_dims[0], + context_length, + x_dims[1])); + + PADDLE_ENFORCE_GT( + context_length + context_start, + 0, + phi::errors::InvalidArgument( + "contextStart size should be smaller than contextLength, " + "but received context_length is: %d, contextStart is: " + "%d.", + context_length, + context_start)); + out->set_dims({x_dims[0], w_dims[1]}); + col_mat->set_dims({x_dims[0], w_dims[0]}); + out->share_lod(x); + col_mat->set_dtype(x.dtype()); + out->set_dtype(x.dtype()); +} + +void FusionSeqExpandConcatFCInferMeta(const std::vector& x, + const MetaTensor& fc_weight, + const MetaTensor& fc_bias, + const std::string& fc_activation, + MetaTensor* out, + MetaTensor* fc_out) { + PADDLE_ENFORCE_GT(x.size(), + 1UL, + phi::errors::InvalidArgument( + "Inputs(X) of FusionSeqExpandConcatFCOp should larger " + "than 1, but received value is: %d.", + x.size())); + + std::vector ins_dims; + ins_dims.reserve(x.size()); + std::transform(x.begin(), + x.end(), + std::back_inserter(ins_dims), + [](const MetaTensor* var) { return var->dims(); }); + + auto w_dims = fc_weight.dims(); // (M0+M1+M2+..) x D + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(FCWeight)'s rank must be 2, but received value is: %d.", + w_dims.size())); + const int D = static_cast(w_dims[1]); + int sum = static_cast(ins_dims[0][1]); + for (size_t i = 1; i < ins_dims.size(); ++i) { + sum += static_cast(ins_dims[i][1]); + } + PADDLE_ENFORCE_EQ( + sum, + w_dims[0], + phi::errors::InvalidArgument("FC height should be sum of all inputs " + "width, but received FC height is: %d, " + "sum of all inputs width is: %d.", + w_dims[0], + sum)); + if (fc_bias) { + auto b_dims = fc_bias.dims(); + PADDLE_ENFORCE_EQ( + b_dims.size() == 1 || b_dims.size() == 2, + true, + phi::errors::InvalidArgument( + "FCBias dim should be 1 or 2, but received value is: %d.", + b_dims.size())); + if (b_dims.size() == 1) { + PADDLE_ENFORCE_EQ(b_dims[0], + D, + phi::errors::InvalidArgument( + "FCBias shapes must be %d when FCBias dim = 1, but " + "received value is: %d.", + D, + b_dims[0])); + } else { + PADDLE_ENFORCE_EQ(b_dims[0], + 1, + phi::errors::InvalidArgument( + "FCBias shapes must be 1x%d, when FCBias dim = 2, " + "but received dim[0] is: %d.", + D, + b_dims[0])); + PADDLE_ENFORCE_EQ(b_dims[1], + D, + phi::errors::InvalidArgument( + "FCBias shapes must be 1x%d, when FCBias dim = 2, " + "but received dim[1] is: %d.", + D, + b_dims[1])); + } + } + fc_out->set_dtype((*x[0]).dtype()); + out->set_dims({ins_dims[0][0], D}); + out->set_dtype((*x[0]).dtype()); + // fcout should be reshape when run since can not get lod in infershape + // explicit share the ref lod + out->share_lod(*x[0]); +} } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index c022a4257e4dc..b6b9c64314ca8 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -62,6 +62,8 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const MetaTensor& bias, const MetaTensor& branch, const MetaTensor& branch_max, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -86,6 +88,8 @@ void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& w_max, const MetaTensor& bias, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -515,4 +519,41 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, MetaTensor* variance, MetaConfig config = MetaConfig()); +void FusionGRUInferMeta(const MetaTensor& x, + const MetaTensor& h0, + const MetaTensor& weight_x, + const MetaTensor& weight_h, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + MetaTensor* reordered_h0, + MetaTensor* xx, + MetaTensor* batched_input, + MetaTensor* batched_out, + MetaTensor* hidden); + +void FusionSeqConvEltAddReluInferMeta(const MetaTensor& x, + const MetaTensor& filter, + const MetaTensor& bias, + const int context_length, + const int context_start, + const int context_stride, + MetaTensor* out, + MetaTensor* col_mat); + +void FusionSeqExpandConcatFCInferMeta(const std::vector& x, + const MetaTensor& fc_weight, + const MetaTensor& fc_bias, + const std::string& fc_activation, + MetaTensor* out, + MetaTensor* fc_out); } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 0cd5534a9c44a..64cf9b010ae07 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -617,48 +617,52 @@ void BatchNormInferMeta(const MetaTensor& x, (data_layout == DataLayout::kNCHW) ? x_dims[1] : x_dims[x_dims.size() - 1]); - auto scale_dim = scale.dims(); - auto bias_dim = bias.dims(); + if (scale) { + PADDLE_ENFORCE_EQ( + scale.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 1." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale.dims().size(), + scale.dims().size())); + } - PADDLE_ENFORCE_EQ( - scale_dim.size(), - 1UL, - phi::errors::InvalidArgument( - "ShapeError: the dimension of scale must equal to 1." - "But received: the shape of scale is [%s], the dimension " - "of scale is [%d]", - scale_dim, - scale_dim.size())); - PADDLE_ENFORCE_EQ(bias_dim.size(), - 1UL, - phi::errors::InvalidArgument( - "ShapeError: the dimension of bias must equal to 1." - "But received: the shape of bias is [%s],the dimension " - "of bias is [%d]", - bias_dim, - bias_dim.size())); + if (bias) { + PADDLE_ENFORCE_EQ( + bias.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias.dims(), + bias.dims().size())); + } bool check = true; - if ((!config.is_runtime) && - (phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) { + if (!scale || !bias || + ((!config.is_runtime) && + (phi::product(scale.dims()) <= 0 || phi::product(bias.dims()) <= 0))) { check = false; } if (check) { - PADDLE_ENFORCE_EQ(scale_dim[0], + PADDLE_ENFORCE_EQ(scale.dims()[0], C, phi::errors::InvalidArgument( "ShapeError: the shape of scale must equal to [%d]" "But received: the shape of scale is [%d]", C, - scale_dim[0])); - PADDLE_ENFORCE_EQ(bias_dim[0], + scale.dims()[0])); + PADDLE_ENFORCE_EQ(bias.dims()[0], C, phi::errors::InvalidArgument( "ShapeError: the shape of bias must equal to [%d]" "But received: the shape of bias is [%d]", C, - bias_dim[0])); + bias.dims()[0])); } y->set_dims(x_dims); mean_out->set_dims({C}); @@ -4279,9 +4283,21 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, MetaTensor* beam_cache_offset_out) { int bsz = static_cast(x.dims()[0]); auto cache_kv_dims = cache_kv.dims(); - int num_head = static_cast(cache_kv.dims()[2]); + int k_num_head = static_cast(cache_kv.dims()[2]); + int v_num_head = k_num_head; int dim_head = static_cast(cache_kv.dims()[4]); + // below's num_head is q's head actually. + int num_head = + x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head; + PADDLE_ENFORCE_EQ( + num_head % k_num_head, + 0, + errors::InvalidArgument( + "The num_head of query must be divisible by the num_head of key, but " + "recived num_head of query is %d, and the num_head of key is %d", + num_head, + k_num_head)); PADDLE_ENFORCE_EQ( cache_kv_dims.size(), 5, diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 0e3ac3fb5ca2c..c341cfdf51682 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -16,6 +16,24 @@ limitations under the License. */ namespace phi { +void ArangeInferMeta(const Scalar& start, + const Scalar& end, + const Scalar& step, + DataType dtype, + MetaTensor* out) { + if (!start.FromTensor() && !end.FromTensor() && !step.FromTensor()) { + double start_value = start.to(); + double end_value = end.to(); + double step_value = step.to(); + int numel = + static_cast(std::ceil((end_value - start_value) / step_value)); + out->set_dims(phi::make_ddim(std::vector(1, numel))); + } else { + out->set_dims({-1}); + } + out->set_dtype(dtype); +} + void AssignValueInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out) { @@ -41,13 +59,11 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); } -void CreateIntArrayInferMeta(const IntArray& data, +void CreateVecShapeInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out) { - CreateInferMetaBase({static_cast(data.GetData().size())}, - dtype, - DataLayout::NCHW, - out); + CreateInferMetaBase( + {static_cast(shape.size())}, dtype, DataLayout::NCHW, out); } void CreateInferMetaBase(const std::vector& shape, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 2f9c9a69a13f1..b33f5c4b77e5e 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -31,11 +31,17 @@ namespace phi { // // The InferMeta Functions in this file are arranged in alphabetic order. +void ArangeInferMeta(const Scalar& start, + const Scalar& end, + const Scalar& step, + DataType dtype, + MetaTensor* out); + void AssignValueInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out); -void CreateIntArrayInferMeta(const IntArray& data, +void CreateVecShapeInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out); diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc new file mode 100644 index 0000000000000..fd036cfad603a --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/concat.h" + +#include +#include + +#include "paddle/phi/infermeta/spmd_rules/elementwise.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +static bool IsEmpty(const std::vector& shape) { + return shape.empty() || shape.at(0) == 0; +} + +SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { + /* +# paddle.concat requires all tensors must either have the same shape (except +# in the concatenating dimension) or be "empty". "Empty" here strictly means +# tensor.shape is torch.Size([0]). When tensor.ndim > 1, it will be treated +# as a non-empty tensor and the shape must match on non-cat dimensions. + */ + + // 1、check tensors shapes + std::vector> tensor_shapes; + std::transform(x.begin(), + x.end(), + std::back_inserter(tensor_shapes), + [](const DistMetaTensor& meta) { + return phi::vectorize(meta.dims()); + }); + bool all_empty = + std::all_of(tensor_shapes.begin(), tensor_shapes.end(), IsEmpty); + if (all_empty) { + return SpmdInfo(); + } + + auto non_empty_iter = + std::find_if(tensor_shapes.begin(), tensor_shapes.end(), [](auto& shape) { + return !IsEmpty(shape); + }); + auto non_empty_index = non_empty_iter - tensor_shapes.begin(); + int64_t ndim = static_cast(tensor_shapes[non_empty_index].size()); + // normlize dim + int64_t dim = axis; + dim = dim < 0 ? dim + ndim : dim; + + std::vector input_attrs; + // 2、make sure all tensors replicated on concat dim + auto n_inputs = x.size(); + for (size_t i = 0; i < n_inputs; ++i) { + const auto& dist_attr = x[i].dist_attr(); + if ((!IsEmpty(tensor_shapes[i])) && IsDimSharded(dist_attr, dim)) { + auto sharded_dist_attr = ReplicateTensorDim(dist_attr, dim); + input_attrs.emplace_back(sharded_dist_attr); + } else { + input_attrs.emplace_back(dist_attr); + } + } + // 3、align non-concat dimensions according to cost + std::vector>> inputs_placements; + std::transform( + input_attrs.begin(), + input_attrs.end(), + std::back_inserter(inputs_placements), + [](const TensorDistAttr& attr) { return attr.to_placement(); }); + const auto& process_mess = input_attrs[non_empty_index].process_mesh(); + auto has_mismatch = [&](int32_t mesh_dim) { + bool mismatch = false; + for (size_t i = 0; i < n_inputs; i++) { + if ((!IsEmpty(tensor_shapes[i])) && + !PlacementEqual(inputs_placements[non_empty_index][mesh_dim], + inputs_placements[i][mesh_dim])) { + mismatch = true; + break; + } + } + return mismatch; + }; + bool need_reshard = false; + int32_t n_mesh_dim = process_mess.ndim(); + std::vector> best_placements( + n_mesh_dim, std::make_shared()); + // a dim can not be sharded twice along diffrent mesh_dim + std::set sharded_dims = {dim}; + + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + // use the old placement + auto& best = inputs_placements[non_empty_index][mesh_dim]; + if (best->is_shard()) { + auto shard_placement = std::dynamic_pointer_cast(best); + sharded_dims.insert(shard_placement->get_axis()); + } + best_placements[mesh_dim] = best; + } + } + + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + continue; + } + need_reshard = true; + std::vector costs; + for (int32_t shard_dim = 0; shard_dim < ndim; shard_dim++) { + double cost = std::numeric_limits::infinity(); + if (!sharded_dims.count(shard_dim)) { + cost = 0.0; + for (size_t i = 0; i < n_inputs; i++) { + auto& tensor_shape = tensor_shapes[i]; + auto& tensor_dist_attr = input_attrs[i]; + if (IsEmpty(tensor_shape)) { + continue; + } + + if (tensor_shape[shard_dim] < process_mess.dim_size(mesh_dim)) { + // should not be selected + cost += std::numeric_limits::infinity(); + continue; + } + if (IsDimSharded(tensor_dist_attr, shard_dim)) { + continue; + } + int64_t num = std::accumulate(tensor_shape.begin(), + tensor_shape.end(), + 1, + std::multiplies()); + if (num == static_cast(0)) { + continue; + } + std::vector local_shape = + GetLocalShape(tensor_shape, process_mess, inputs_placements[i]); + cost += std::accumulate(local_shape.begin(), + local_shape.end(), + 1, + std::multiplies()) * + process_mess.dim_size(mesh_dim); + } + } + costs.push_back(cost); + } + auto min_itr = std::min_element(costs.begin(), costs.end()); + auto min_dim = min_itr - costs.begin(); + if (!sharded_dims.count(min_dim)) { + best_placements[mesh_dim] = std::make_shared(min_dim); + sharded_dims.insert(min_dim); + } + } + // set placement to the best placements + if (need_reshard) { + std::vector new_input_attrs; + for (auto& e : input_attrs) { + new_input_attrs.emplace_back(FromPlacements(e, best_placements)); + } + std::swap(input_attrs, new_input_attrs); + } + return {{input_attrs}, {input_attrs[non_empty_index]}}; +} + +SpmdInfo ConcatInferSpmdReverse(const std::vector& x, + const DistMetaTensor& output, + int axis) { + // TODO(liuzhenhai): add latter + return SpmdInfo(); +} +SpmdInfo ConcatInferSpmdDynamic(const std::vector& x, + const Scalar& axis) { + return ConcatInferSpmd(x, axis.to()); +} +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/concat.h b/paddle/phi/infermeta/spmd_rules/concat.h new file mode 100644 index 0000000000000..0f7435bec0b23 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/concat.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { +SpmdInfo ConcatInferSpmd(const std::vector& x, int axis); + +SpmdInfo ConcatInferSpmdReverse(const std::vector& x, + const DistMetaTensor& output, + int axis); + +SpmdInfo ConcatInferSpmdDynamic(const std::vector& x, + const Scalar& axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc index eb469200a7ec8..7a3639147f1ee 100644 --- a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc +++ b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc @@ -95,7 +95,8 @@ SpmdInfo DefaultDataParallelInferSpmd( << str_join(output_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } SpmdInfo DefaultDataParallelInferSpmdReverse( const std::vector& ins, @@ -157,7 +158,8 @@ SpmdInfo DefaultDataParallelInferSpmdReverse( << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc index 24d6bed03c52d..9ec18bdaf50ce 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.cc +++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc @@ -309,6 +309,18 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}}; } +SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad) { + return {{out_grad.dist_attr(), out_grad.dist_attr()}, {out_grad.dist_attr()}}; +} + +SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out, + const DistMetaTensor& out_grad) { + return {{out_grad.dist_attr(), out_grad.dist_attr(), out_grad.dist_attr()}, + {out_grad.dist_attr()}}; +} + SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& out_grad, @@ -376,5 +388,17 @@ SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, {x_grad_dist_attr, y_grad_dist_attr}}; } +SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, + const DistMetaTensor& out_grad, + int64_t axis) { + // The out's dist_attr is the same with out_grad's dist_attr, reuse + // ElementwiseBinaryGradInferSpmd(x, y, out_grad, axis) to infer dist_attrs of + // {{x, y, out_grad}, {x_grad, y_grad}}, then insert out's dist_attr into it. + SpmdInfo info = ElementwiseBinaryGradInferSpmd(x, y, out_grad, axis); + info.first.emplace(info.first.begin() + 2, out_grad.dist_attr()); + return info; +} } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h index 736aeec35ed0a..2dd8d4c764a40 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.h +++ b/paddle/phi/infermeta/spmd_rules/elementwise.h @@ -27,6 +27,13 @@ SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x); SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out); +SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad); + +SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out, + const DistMetaTensor& out_grad); + SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y); @@ -37,7 +44,13 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& out_grad, - int64_t axis); + int64_t axis = -1); + +SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, + const DistMetaTensor& out_grad, + int64_t axis = -1); } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/flatten.cc b/paddle/phi/infermeta/spmd_rules/flatten.cc new file mode 100644 index 0000000000000..0a9c4111d8e7f --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/flatten.cc @@ -0,0 +1,203 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/flatten.h" +#include + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/dim_trans.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +int PreprocessAxis(int axis, int ndim) { + if (axis < 0) { + axis += ndim; + } + + PADDLE_ENFORCE_LT( + axis, + ndim, + phi::errors::InvalidArgument("The Start_axis or Stop_axis [%d] is not " + "less than the Tensor X's rank [%d].", + axis, + ndim)); + + return axis; +} + +std::vector MakeFlattenDimTrans( + const std::vector& src_shape, int start_axis, int stop_axis) { + std::vector ret; + + std::vector input_dims; + for (int64_t i = 0; i < static_cast(src_shape.size()); i++) { + if (i < start_axis || i > stop_axis) { + ret.emplace_back(new InputDim(i)); + } else { + input_dims.emplace_back(new InputDim(i)); + } + + if (i == stop_axis) { + ret.emplace_back(make_flatten(input_dims)); + } + } + + return ret; +} + +std::vector MakeFlattenDimTransReverse( + const std::vector& src_shape, int start_axis, int stop_axis) { + std::vector ret; + + std::vector tgt_splitted_shape; + for (int i = start_axis; i <= stop_axis; i++) { + tgt_splitted_shape.emplace_back(src_shape[i]); + } + + for (int64_t i = 0; i < static_cast(src_shape.size()); i++) { + if (i < start_axis) { + ret.emplace_back(new InputDim(i)); + } else if (i > stop_axis) { + ret.emplace_back(new InputDim(i - (stop_axis - start_axis))); + } else { + ret.emplace_back(make_split( + new InputDim(start_axis), tgt_splitted_shape, i - start_axis)); + } + } + + return ret; +} + +SpmdInfo FlattenInferSpmd(const DistMetaTensor& x, + int start_axis, + int stop_axis) { + // Step0: Verify input args based on flatten logic + auto src_shape = phi::vectorize(x.dims()); + int x_ndim = static_cast(src_shape.size()); + auto x_dist_attr_src = x.dist_attr(); + std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + x_ndim, + x_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping.size())); + + // Step1: Build the transformation from + // the original shape to the target shape + + start_axis = PreprocessAxis(start_axis, x_ndim); + stop_axis = PreprocessAxis(stop_axis, x_ndim); + std::vector trans = + MakeFlattenDimTrans(src_shape, start_axis, stop_axis); + + // Step2: Infer the dims mapping of input (if reshard is + // needed) and output from the dimension transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(x, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping. + TensorDistAttr x_dist_attr_dst(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr out_dist_attr(x_dist_attr_src); + out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "FlattenInferSpmd: X shape: [" << str_join(src_shape) << "]"; + VLOG(4) << "Start_axis: " << start_axis; + VLOG(4) << "Stop_axis: " << start_axis; + VLOG(4) << "Transformation from input to output:"; + for (int64_t i = 0, n = static_cast(trans.size()); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string(); + } + VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping) + << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; + + CleanUp(); + + return {{x_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int start_axis, + int stop_axis) { + // Step0: Verify input args based on flatten logic + auto x_shape = phi::vectorize(x.dims()); + auto x_ndim = x_shape.size(); + auto out_shape = phi::vectorize(out.dims()); + int out_ndim = out_shape.size(); + auto out_dist_attr_src = out.dist_attr(); + std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + out_ndim, + out_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's " + "dims_mapping size [%d] are not matched.", + out_ndim, + out_dims_mapping.size())); + + // Step1: Build the transformation from the output shape + // to original shape. This function infers the dims mapping + // from output to input, we first get the transformation + // from output to input so that we can infer the dims mapping + // with the map from output axes to input axes. + + start_axis = PreprocessAxis(start_axis, x_ndim); + stop_axis = PreprocessAxis(stop_axis, x_ndim); + + std::vector trans = + MakeFlattenDimTransReverse(x_shape, start_axis, stop_axis); + + // Step2: Infer the dims mapping of input with + // output's dims_mapping and the transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(out, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping + TensorDistAttr out_dist_attr_dst(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr x_dist_attr(x.dist_attr()); + x_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "FlattenInferSpmdReverse: Out shape: [" << str_join(out_shape) + << "] X shape: [" << str_join(x_shape) << "]"; + VLOG(4) << "Transformation from output to input:"; + for (int64_t i = 0, n = trans.size(); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tX axis[" << i << "]: " << t->to_string(); + } + VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] " + << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; + + CleanUp(); + + return {{x_dist_attr}, {out_dist_attr_dst}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/flatten.h b/paddle/phi/infermeta/spmd_rules/flatten.h new file mode 100644 index 0000000000000..bb62d8c0d7b0a --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/flatten.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo FlattenInferSpmd(const DistMetaTensor& x, + int start_axis, + int stop_axis); + +SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int start_axis, + int stop_axis); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 6befef19cfef1..1dfe8bf19c296 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -275,7 +275,7 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, } VLOG(4) << std::endl; - return {input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(input_dist_attrs), ToArgDistAttr(output_dist_attrs)}; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index 4893c7071f19e..60c7acacf0478 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -291,17 +291,22 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& out_grad, bool trans_x, bool trans_y) { - auto confirm_dist_attr_same_fn = [&](const TensorDistAttr& x_dist_attr, + auto get_attr = [](const ArgDistAttr& attr) -> const TensorDistAttr& { + return paddle::get(attr); + }; + + auto confirm_dist_attr_same_fn = [&](const ArgDistAttr& x_dist_attr, const DistMetaTensor& y, const char* debug_msg) { + const auto& x_single_dist_attr = get_attr(x_dist_attr); PADDLE_ENFORCE_EQ( - DistAttrsAreBasicallyEqual(x_dist_attr, y.dist_attr()), + DistAttrsAreBasicallyEqual(x_single_dist_attr, y.dist_attr()), true, phi::errors::Unavailable("The matmul grad infer spmd `%s` verify " "error: left dist attr is %s, " "right dist attr is %s.", debug_msg, - x_dist_attr, + x_single_dist_attr, y.dist_attr())); }; @@ -313,8 +318,8 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, // so it cannot be handled correctly in the backward for the time being // For this case, we uniformly transition the input to the Replicated state. auto fwd_spmd_info = MatmulInferSpmd(x, y, trans_x, trans_y); - if (x.dist_attr() != fwd_spmd_info.first[0] || - y.dist_attr() != fwd_spmd_info.first[1]) { + if (x.dist_attr() != get_attr(fwd_spmd_info.first[0]) || + y.dist_attr() != get_attr(fwd_spmd_info.first[1])) { auto x_r_dist_attr = GetReplicatedDistAttr(x.dist_attr()); auto y_r_dist_attr = GetReplicatedDistAttr(y.dist_attr()); return {{x_r_dist_attr, diff --git a/paddle/phi/infermeta/spmd_rules/reduction.cc b/paddle/phi/infermeta/spmd_rules/reduction.cc index 24c90a1792341..24fc64484f418 100644 --- a/paddle/phi/infermeta/spmd_rules/reduction.cc +++ b/paddle/phi/infermeta/spmd_rules/reduction.cc @@ -29,8 +29,15 @@ using phi::distributed::auto_parallel::str_join; ////////////////// Utils Functions ////////////////// std::string GetOutputNotation(int input_ndim, const std::string& input_axes, - std::vector reduce_dims, + std::vector reduce_dims, bool keep_dim) { + // if input_axes is empty means reduce all + if (reduce_dims.empty()) { + for (int i = 0; i < input_ndim; ++i) { + reduce_dims.emplace_back(i); + } + } + // convert the negative dim value to normal dim value for (auto& reduce_dim : reduce_dims) { if (reduce_dim < 0) { @@ -40,7 +47,7 @@ std::string GetOutputNotation(int input_ndim, std::string output_axes = ""; for (int i = 0; i < input_ndim; i++) { - std::vector::iterator iter = + std::vector::iterator iter = std::find(reduce_dims.begin(), reduce_dims.end(), i); if (iter != reduce_dims.end()) { // if i is reduce dim, the corresponding input axis @@ -58,9 +65,10 @@ std::string GetOutputNotation(int input_ndim, return output_axes; } -SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, - const std::vector& axis, - bool keep_dim) { +SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x, + const std::vector& axis, + bool keep_dim, + int reduce_type) { // Step0: Verify input args based on reduction logic auto x_shape = phi::vectorize(x.dims()); int x_ndim = x_shape.size(); @@ -102,8 +110,8 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, // Step3.1 Output Partial std::vector partial_on_dims = ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); - out_dist_attr.set_partial_status( - partial_on_dims /*, handle reduce_type in future */); + out_dist_attr.set_partial_status(partial_on_dims, + static_cast(reduce_type)); // Step3.2 handle input tensor partial (TODO) // If the op is a linear op, i.e. `linearity` is true, it supports @@ -116,14 +124,37 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, VLOG(4) << "Input0 shape: [" << str_join(x_shape) << "] " << "dims_mapping: [" << str_join(x_dims_mapping) << "]"; VLOG(4) << "Output dims_mapping: [" + str_join(out_dims_mapping) + "] " - << "partial_on_dims: [" + str_join(partial_on_dims) + "]\n\n"; + << "partial_on_dims: [" + str_join(partial_on_dims) + << " with reduce_type " << reduce_type << "]\n\n"; return {{x_dist_attr_src}, {out_dist_attr}}; } +SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, + const std::vector& axis, + bool keep_dim) { + return ReductionInferSpmdBase( + x, axis, keep_dim, static_cast(ReduceType::kRedSum)); +} + +SpmdInfo ReductionMeanInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + bool keep_dim) { + return ReductionInferSpmdBase( + x, axis.GetData(), keep_dim, static_cast(ReduceType::kRedAvg)); +} + +SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim) { + return ReductionInferSpmdBase( + x, axis.GetData(), keep_dim, static_cast(ReduceType::kRedSum)); +} + SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out, - const std::vector& axis, + const std::vector& axis, bool keep_dim) { // Step0: Verify input args based on reduction logic auto x_shape = phi::vectorize(x.dims()); @@ -174,5 +205,44 @@ SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, return {{x_dist_attr_dst}, {out_dist_attr_src}}; } +SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const IntArray& axis, + bool keep_dim, + bool reduce_all) { + TensorDistAttr x_dist_attr = out_grad.dist_attr(); + TensorDistAttr x_grad_dist_attr = out_grad.dist_attr(); + + std::vector x_dim = phi::vectorize(x.dims()); + std::vector out_grad_dim = phi::vectorize(out_grad.dims()); + + if (x_dim.size() != out_grad_dim.size()) { + auto dims_mapping = x_dist_attr.dims_mapping(); + auto axis_value = axis.GetData(); + + for (size_t i = 0; i < axis_value.size(); ++i) { + if (axis_value[i] < 0) { + axis_value[i] += x_dim.size(); + } + } + std::sort(axis_value.begin(), axis_value.end()); + + // if the input_axes is empty means to reduce all + if (axis_value.empty()) { + for (size_t i = 0; i < x_dim.size(); ++i) { + axis_value.emplace_back(i); + } + } + + for (const auto& axis : axis_value) { + dims_mapping.insert(dims_mapping.begin() + axis, -1); + } + x_dist_attr.set_dims_mapping(dims_mapping); + x_grad_dist_attr.set_dims_mapping(dims_mapping); + } + + return {{x_dist_attr, out_grad.dist_attr()}, {x_grad_dist_attr}}; +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/reduction.h b/paddle/phi/infermeta/spmd_rules/reduction.h index ed9341ddc6904..e010abbb1f60c 100644 --- a/paddle/phi/infermeta/spmd_rules/reduction.h +++ b/paddle/phi/infermeta/spmd_rules/reduction.h @@ -16,6 +16,7 @@ limitations under the License. */ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/type_defs.h" @@ -23,13 +24,32 @@ namespace phi { namespace distributed { SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, - const std::vector& axis, + const std::vector& axis, bool keep_dim); +// This infer spmd function only use in dynamic mode for it uses +// IntArray as parameter. The IntArray may contain vector of tensor +// which is not support in static mode. So we separate these two and +// use dynamic infer_spmd invoke static infer_spmd function. +SpmdInfo ReductionMeanInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + bool keep_dim); + +SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim); + SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out, - const std::vector& axis, + const std::vector& axis, bool keep_dim); +SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const IntArray& axis, + bool keep_dim, + bool reduce_all); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/replicated.cc b/paddle/phi/infermeta/spmd_rules/replicated.cc index fce5a1d89b263..d0c90f7b2d2a9 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.cc +++ b/paddle/phi/infermeta/spmd_rules/replicated.cc @@ -54,7 +54,11 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, // Step3: Merge and get Inputs' Batch Axis New Dims Mapping. std::vector dst_input_dist_attrs; for (int64_t i = 0; i < ninputs; i++) { + // `ndim == -1` means input is nullptr int ndim = ins[i]->dims().size(); + if (ndim == -1) { + continue; + } TensorDistAttr dist_attr_dst = CopyTensorDistAttrForOutput(ins[i]->dist_attr()); std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); @@ -64,6 +68,9 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, VLOG(4) << "ReplicatedSpmd InferForward:"; for (int64_t i = 0; i < ninputs; i++) { + if (ins[i]->dims().size() == -1) { + continue; + } VLOG(4) << "Input" << std::to_string(i) << " shape: [" << str_join(phi::vectorize(ins[i]->dims())) << "] " << "src_dims_mapping: [" @@ -79,7 +86,8 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, << str_join(output_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } SpmdInfo ReplicatedInferSpmdReverse( @@ -128,7 +136,53 @@ SpmdInfo ReplicatedInferSpmdReverse( << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; +} + +SpmdInfo ReplicatedInferDynamic( + const std::vector*>>& + inputs) { + std::vector nonnull_inputs; + int64_t ninputs = inputs.size(); + SpmdInfo spmd_info; + + auto build_tensor_dist_attr = + [&nonnull_inputs](const DistMetaTensor& dist_meta_tensor) { + int ndim = dist_meta_tensor.dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(dist_meta_tensor.dist_attr()); + // `ndim == -1` means input is nullptr + if (ndim >= 0) { + std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + nonnull_inputs.push_back(&dist_meta_tensor); + } + return dist_attr_dst; + }; + + for (int64_t i = 0; i < ninputs; i++) { + if (paddle::holds_alternative(inputs[i])) { + auto dist_meta_tensor_ptr = paddle::get<0>(inputs[i]); + auto& dist_meta_tensor = *dist_meta_tensor_ptr; + auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor); + VLOG(4) << "input " << i << ": dist attr: " << dist_attr_dst.to_string(); + spmd_info.first.emplace_back(dist_attr_dst); + } else { + std::vector list_dist_attr; + auto dist_meta_tensors_ptr = paddle::get<1>(inputs[i]); + auto& dist_meta_tensors = *dist_meta_tensors_ptr; + for (const auto& dist_meta_tensor : dist_meta_tensors) { + auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor); + VLOG(4) << "input " << i + << ": dist attr: " << dist_attr_dst.to_string(); + list_dist_attr.emplace_back(std::move(dist_attr_dst)); + } + spmd_info.first.emplace_back(std::move(list_dist_attr)); + } + } + return spmd_info; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/replicated.h b/paddle/phi/infermeta/spmd_rules/replicated.h index a8d6c0719f2ec..1f3a26cb426d4 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.h +++ b/paddle/phi/infermeta/spmd_rules/replicated.h @@ -41,6 +41,19 @@ SpmdInfo ReplicatedInferSpmdReverse( const std::vector& ins, const std::vector& outs); +SpmdInfo ReplicatedInferDynamic( + const std::vector*>>& + inputs); + +// For phi api +template +SpmdInfo VariadicReplicatedInferSpmdDynamic(const Args&... args) { + return detail::ReplicateInferSpmdDynamicHelper() + .apply(args...) + .Infer(); +} + // For phi api template SpmdInfo VariadicReplicatedInferSpmd(const Args&... args) { diff --git a/paddle/phi/infermeta/spmd_rules/reshape.cc b/paddle/phi/infermeta/spmd_rules/reshape.cc index 4c95b846c87d0..42e946c732161 100644 --- a/paddle/phi/infermeta/spmd_rules/reshape.cc +++ b/paddle/phi/infermeta/spmd_rules/reshape.cc @@ -50,7 +50,7 @@ std::vector InferTargetShape(const std::vector& shape, PADDLE_ENFORCE_EQ( product, len, - phi::errors::InvalidArgument("The total size are not matched")); + phi::errors::InvalidArgument("The total size are not matched.")); return std::vector(shape); } else { std::vector new_shape(shape); @@ -59,7 +59,7 @@ std::vector InferTargetShape(const std::vector& shape, PADDLE_ENFORCE_EQ(len % infer_size, 0, phi::errors::InvalidArgument( - "The total is not diviable by infer_size")); + "The total is not diviable by infer_size.")); new_shape[infer_idx] = infer_size; return new_shape; } @@ -143,8 +143,11 @@ std::vector MakeReshapeDimTrans( SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, const std::vector& shape) { // Step0: Verify input args based on reshape logic - auto src_shape = phi::vectorize(x.dims()); - int x_ndim = src_shape.size(); + VLOG(2) << "Debug Info for reshape"; + VLOG(2) << "shape: " << str_join(shape); + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + int out_ndim = shape.size(); auto x_dist_attr_src = x.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -154,20 +157,31 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, "dims_mapping size [%d] are not matched.", x_ndim, x_dims_mapping.size())); + VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(x_shape) << "]"; + VLOG(4) << "Out shape: [" << str_join(shape) << "]"; // Step1: Build the transformation from // the original shape to the target shape + // handle the case of dynamic shape, like [-1, -1, ...] --> [0, 0, ...]. + // This is used in inference but reshape allows only one '-1' in the + // target shape, so set the shape to a special value '256' + for (int i = 0; i < x_ndim; i++) { + if (x_shape[i] == -1) { + x_shape[i] = 256; + } + } + // handle the '0' values in target shape, '0' indicates // that the target shape is equal to the source shape std::vector tgt_shape(shape); - for (int64_t i = 0, n = static_cast(tgt_shape.size()); i < n; i++) { + for (int64_t i = 0; i < out_ndim; i++) { if (tgt_shape[i] == 0) { - tgt_shape[i] = src_shape[i]; + tgt_shape[i] = x_shape[i]; } } - std::vector trans = MakeReshapeDimTrans(src_shape, tgt_shape); + std::vector trans = MakeReshapeDimTrans(x_shape, tgt_shape); // Step2: Infer the dims mapping of input (if reshard is // needed) and output from the dimension transformation. @@ -181,17 +195,14 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, TensorDistAttr out_dist_attr(x_dist_attr_src); out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); - VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape) - << "] Out shape: [" << str_join(tgt_shape) << "]"; VLOG(4) << "Transformation from input to output:"; for (int64_t i = 0, n = static_cast(trans.size()); i < n; i++) { DimTrans* t = trans[i]; VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string(); } VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping) - << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) - << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1]) - << "]\n\n"; + << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; CleanUp(); @@ -201,9 +212,12 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out, const std::vector& shape) { + VLOG(2) << "Debug Info for reshape_reverse"; + VLOG(2) << "shape: " << str_join(shape); // Step0: Verify input args based on reshape logic auto x_shape = phi::vectorize(x.dims()); auto out_shape = phi::vectorize(out.dims()); + int x_ndim = x_shape.size(); int out_ndim = out_shape.size(); auto out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); @@ -214,14 +228,39 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x, "dims_mapping size [%d] are not matched.", out_ndim, out_dims_mapping.size())); + VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape) + << "], X shape: [" << str_join(x_shape) << "]"; // Step1: Build the transformation from the output shape // to original shape. This function infers the dims mapping // from output to input, we first get the transformation // from output to input so that we can infer the dims mapping // with the map from output axes to input axes. - // Shapes in InferSpmdReverse don't contain -1 or 0, so they will - // not be modified and we can directly use them. + + // handle the case of dynamic shape, like [-1, -1, ...] --> [0, 0, ...]. + // This is used in inference but reshape allows only one '-1' in the + // target shape, so set the shape to a special value '256' + for (int i = 0; i < x_ndim; i++) { + if (x_shape[i] == -1) { + x_shape[i] = 256; + } + } + + // handle the '0' values in target shape, '0' indicates + // that the target shape is equal to the source shape + std::vector tgt_shape(shape); + for (int64_t i = 0; i < out_ndim; i++) { + if (shape[i] == 0) { + out_shape[i] = x_shape[i]; + } + } + + // The out_shape may contain '-1', which will cause error + // when inferring the transformation from out_shape to + // x_shape, so infer the '-1' value before inferrng DimTrans + int64_t nelm = std::accumulate( + x_shape.begin(), x_shape.end(), 1, std::multiplies()); + out_shape = InferTargetShape(out_shape, nelm); std::vector trans = MakeReshapeDimTrans(out_shape, x_shape); // Step2: Infer the dims mapping of input with @@ -236,8 +275,6 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x, TensorDistAttr x_dist_attr(x.dist_attr()); x_dist_attr.set_dims_mapping(dims_mapping_vec[1]); - VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape) - << "] X shape: [" << str_join(x_shape) << "]"; VLOG(4) << "Transformation from output to input:"; for (int64_t i = 0, n = trans.size(); i < n; i++) { DimTrans* t = trans[i]; diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index dd89de7229b9a..4d7d9d8a2d07b 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -16,17 +16,21 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/infermeta/spmd_rules/concat.h" #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" #include "paddle/phi/infermeta/spmd_rules/embedding.h" +#include "paddle/phi/infermeta/spmd_rules/flatten.h" #include "paddle/phi/infermeta/spmd_rules/layer_norm.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" #include "paddle/phi/infermeta/spmd_rules/reduction.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" #include "paddle/phi/infermeta/spmd_rules/reshape.h" +#include "paddle/phi/infermeta/spmd_rules/slice.h" #include "paddle/phi/infermeta/spmd_rules/softmax.h" #include "paddle/phi/infermeta/spmd_rules/split.h" #include "paddle/phi/infermeta/spmd_rules/transpose.h" +#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" /** * Design Notes: @@ -68,7 +72,7 @@ PD_REGISTER_SPMD_RULE( // default data parallel rule PD_REGISTER_SPMD_RULE( - unsqueeze, + default_data_parallel, PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); PD_REGISTER_SPMD_RULE( @@ -82,6 +86,12 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); +// unsqueeze rule +PD_REGISTER_SPMD_RULE( + unsqueeze, + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); + // elementwise unary rule PD_REGISTER_SPMD_RULE( assign, @@ -492,6 +502,11 @@ PD_REGISTER_SPMD_RULE(reshape2, PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); +// flatten rule +PD_REGISTER_SPMD_RULE(flatten, + PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), + PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse)); + // embedding rule PD_REGISTER_SPMD_RULE( embedding, @@ -511,6 +526,15 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); +// slice rule +PD_REGISTER_SPMD_RULE(slice, + PD_INFER_SPMD(phi::distributed::SliceInferSpmd), + PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(concat, + PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), + PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); + // transpose rule PD_REGISTER_SPMD_RULE( transpose, diff --git a/paddle/phi/infermeta/spmd_rules/slice.cc b/paddle/phi/infermeta/spmd_rules/slice.cc new file mode 100644 index 0000000000000..d73fdfe8629ef --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/slice.cc @@ -0,0 +1,176 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/slice.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo SliceInferSpmd(const DistMetaTensor& input, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis) { + auto input_shape = phi::vectorize(input.dims()); + int input_ndim = input_shape.size(); + auto input_dist_attr_src = input.dist_attr(); + std::vector input_dims_mapping = input_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + input_ndim, + input_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor Input's rank [%d] and Input's " + "dims_mapping size [%d] are not matched.", + input_ndim, + input_dims_mapping.size())); + + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + std::string input_axes = alphabet.substr(0, input_ndim); + std::string special_axes = alphabet.substr(input_ndim); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_axes[axis] = special_axes[i]; + } + + std::string out_axes(input_axes); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + out_axes[axis] = '1'; + } + + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{input_axes, input_dims_mapping}}); + + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map); + + TensorDistAttr out_dist_attr = + CopyTensorDistAttrForOutput(input_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + TensorDistAttr input_dist_attr_dst(input_dist_attr_src); + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_dims_mapping[axis] = -1; + } + input_dist_attr_dst.set_dims_mapping(input_dims_mapping); + + VLOG(4) << "SliceInferSpmd:"; + VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes; + VLOG(4) << "Input shape: [" << str_join(input_shape) << "] " + << "src_dims_mapping: [" + << str_join(input_dist_attr_src.dims_mapping()) << "] " + << "dst_dims_mapping: [" << str_join(input_dims_mapping) << "]"; + VLOG(4) << "Output" + << " dims_mapping: [" << str_join(out_dims_mapping) << "]"; + VLOG(4) << std::endl; + + return {{input_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input, + const DistMetaTensor& output, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis) { + auto output_shape = phi::vectorize(output.dims()); + int out_ndim = output_shape.size(); + auto out_dist_attr = output.dist_attr(); + int out_dims_mapping_size = out_dist_attr.dims_mapping().size(); + auto input_shape = phi::vectorize(input.dims()); + int input_ndim = input_shape.size(); + auto input_dist_attr = input.dist_attr(); + std::vector input_dims_mapping = input_dist_attr.dims_mapping(); + + PADDLE_ENFORCE_EQ( + input_ndim, + out_ndim, + phi::errors::InvalidArgument("The Tensor Input's rank [%d] is not equal " + "to the Tensor Output's rank [%d]", + input_ndim, + out_ndim)); + + PADDLE_ENFORCE_EQ( + out_ndim, + out_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor Output's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + out_ndim, + out_dims_mapping_size)); + + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + std::string input_axes = alphabet.substr(0, input_ndim); + std::string special_axes = alphabet.substr(input_ndim); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_axes[axis] = special_axes[i]; + } + + std::string out_axes(input_axes); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + out_axes[axis] = special_axes[i]; + } + + std::vector>> axes_sharding_info; + std::vector out_dims_mapping = output.dist_attr().dims_mapping(); + axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping)); + + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + + input_dims_mapping = GetDimsMappingForAxes(input_axes, axis_to_dim_map, true); + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_dims_mapping[axis] = -1; + } + input_dist_attr.set_dims_mapping(input_dims_mapping); + out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + out_dims_mapping[axis] = -1; + } + out_dist_attr.set_dims_mapping(out_dims_mapping); + + VLOG(4) << "SliceInferSpmdReverse:"; + VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes; + VLOG(4) << "Output" + << " shape: [" << str_join(phi::vectorize(output.dims())) << "] " + << "src_dims_mapping: [" + << str_join(output.dist_attr().dims_mapping()) << "] " + << "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping()) + << "]"; + VLOG(4) << "Input shape: [" << str_join(input_shape) << "] " + << "dims_mapping: [" << str_join(input_dims_mapping) << "]\n\n"; + + return {{input_dist_attr}, {out_dist_attr}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/slice.h b/paddle/phi/infermeta/spmd_rules/slice.h new file mode 100644 index 0000000000000..5a49ad9e0c48d --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/slice.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo SliceInferSpmd(const DistMetaTensor& input, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis); + +SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input, + const DistMetaTensor& output, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/split.cc b/paddle/phi/infermeta/spmd_rules/split.cc index 4bc2a9ce0bdb1..0856fec2e89df 100644 --- a/paddle/phi/infermeta/spmd_rules/split.cc +++ b/paddle/phi/infermeta/spmd_rules/split.cc @@ -92,8 +92,10 @@ SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis) { << str_join(out_dims_mapping) << "]"; } VLOG(4) << std::endl; - - return {{x_dist_attr_dst}, out_dist_attrs}; + // TODO(liuzhenhai): remedy this + // should return list in list [] + // return {{x_dist_attr_dst}, {out_dist_attrs}}; + return {{x_dist_attr_dst}, ToArgDistAttr(out_dist_attrs)}; } SpmdInfo SplitWithNumInferSpmdReverse( @@ -193,8 +195,9 @@ SpmdInfo SplitWithNumInferSpmdReverse( } VLOG(4) << "Input shape: [" << str_join(x_shape) << "] " << "dims_mapping: [" << str_join(x_dims_mapping) << "]\n\n"; - - return {{x_dist_attr}, out_dist_attrs}; + // TODO(liuzhenhai): remedy this + // return {{x_dist_attr}, {out_dist_attrs}}; + return {{x_dist_attr}, ToArgDistAttr(out_dist_attrs)}; } SpmdInfo SplitInferSpmd(const DistMetaTensor& x, diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc new file mode 100644 index 0000000000000..6af4210f92d80 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights resized. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" +#include +#include + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/dim_trans.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +std::vector MakeUnsqueezeDimTrans( + const std::vector& x_shape, + std::vector* out_shape, + const std::vector& axis) { + int64_t n = static_cast(x_shape.size() + axis.size()); + std::vector ret; + ret.resize(n); + out_shape->resize(n); + fill(ret.begin(), ret.end(), new Singleton()); + fill(out_shape->begin(), out_shape->end(), 1); + + for (int64_t i = 0, j = 0; i < n; i++) { + auto it = find(axis.begin(), axis.end(), i); + + if (it == axis.end()) { + if (x_shape[j] != 1) { + ret[i] = new InputDim(j); + (*out_shape)[i] = x_shape[j]; + } + + j++; + } + } + + return ret; +} + +std::vector MakeUnsqueezeDimTransReverse( + const std::vector& out_shape, + const std::vector& axis, + const int& x_ndim, + const int& out_ndim) { + std::vector ret; + ret.resize(x_ndim); + fill(ret.begin(), ret.end(), new Singleton()); + + for (int64_t i = 0, j = 0; i < out_ndim; i++) { + auto it = find(axis.begin(), axis.end(), i); + + if (it == axis.end()) { + if (out_shape[i] != 1) { + ret[j] = new InputDim(i); + } + + j++; + } + } + + return ret; +} + +SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x, + const std::vector& axis) { + // Step0: Verify input args based on unsqueeze logic + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + auto x_dist_attr_src = x.dist_attr(); + std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + x_ndim, + x_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping.size())); + + // Step1: Build the transformation from + // the original shape to the target shape + + std::vector out_shape; + std::vector axis_copy(axis); + + for (int64_t i = 0; i < static_cast(axis_copy.size()); i++) { + if (axis_copy[i] < 0) { + axis_copy[i] += x_ndim + 1; + } + } + + std::vector trans = + MakeUnsqueezeDimTrans(x_shape, &out_shape, axis_copy); + + // Step2: Infer the dims mapping of input (if reshard is + // needed) and output from the dimension transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(x, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping. + TensorDistAttr x_dist_attr_dst(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr out_dist_attr(x_dist_attr_src); + out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "UnsqueezeInferSpmd: X shape: [" << str_join(x_shape) + << "] Out shape: [" << str_join(out_shape) << "]"; + VLOG(4) << "Transformation from input to output:"; + for (int64_t i = 0, n = static_cast(trans.size()); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string(); + } + VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping) + << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) + << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1]) + << "]\n\n"; + + CleanUp(); + + return {{x_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + const std::vector& axis) { + // Step0: Verify input args based on unsqueeze logic + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + auto out_shape = phi::vectorize(out.dims()); + int out_ndim = out_shape.size(); + auto out_dist_attr_src = out.dist_attr(); + std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + out_ndim, + out_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's " + "dims_mapping size [%d] are not matched.", + out_ndim, + out_dims_mapping.size())); + + // Step1: Build the transformation from the output shape + // to original shape. This function infers the dims mapping + // from output to input, we first get the transformation + // from output to input so that we can infer the dims mapping + // with the map from output axes to input axes. + + std::vector axis_copy(axis); + + for (int64_t i = 0; i < static_cast(axis_copy.size()); i++) { + if (axis_copy[i] < 0) { + axis_copy[i] += x_ndim + 1; + } + } + + std::vector trans = + MakeUnsqueezeDimTransReverse(out_shape, axis_copy, x_ndim, out_ndim); + + // Step2: Infer the dims mapping of input with + // output's dims_mapping and the transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(out, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping + TensorDistAttr out_dist_attr_dst(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr x_dist_attr(x.dist_attr()); + x_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "UnsqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape) + << "] X shape: [" << str_join(x_shape) << "]"; + VLOG(4) << "Transformation from output to input:"; + for (int64_t i = 0, n = trans.size(); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tX axis[" << i << "]: " << t->to_string(); + } + VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] " + << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; + + CleanUp(); + + return {{x_dist_attr}, {out_dist_attr_dst}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.h b/paddle/phi/infermeta/spmd_rules/unsqueeze.h new file mode 100644 index 0000000000000..a2f3490409b83 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x, + const std::vector& axis); + +SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + const std::vector& axis); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 31bfba2a0d433..42bbc659b2f2b 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -135,7 +135,7 @@ TensorDistAttr CopyTensorDistAttrForOutput( TensorDistAttr new_dist_attr = TensorDistAttr(); new_dist_attr.set_process_mesh(src_dist_attr.process_mesh()); new_dist_attr.set_batch_dim(src_dist_attr.batch_dim()); - new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims()); + // new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims()); // new_dist_attr.set_annotated(false); TODO unset field is false by default. new_dist_attr.clean_partial_status(); // in partial-stage I, partial is allow // to propagate @@ -164,6 +164,99 @@ TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr) { return dst_dist_attr; } +TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim) { + TensorDistAttr dst_dist_attr = CopyTensorDistAttrForOutput(dist_attr); + std::vector dims_mapping = dist_attr.dims_mapping(); + dims_mapping[dim] = kReplicateDim; + dst_dist_attr.set_dims_mapping(dims_mapping); + return dst_dist_attr; +} + +bool IsDimSharded(const TensorDistAttr& dist_attr, int dim) { + return dist_attr.is_shard(-1, dim); +} + +bool PlacementEqual(const std::shared_ptr& a, + const std::shared_ptr& b) { + if (a->is_partial()) { + if (!b->is_partial()) { + return false; + } + auto a_partial = std::dynamic_pointer_cast(a); + auto b_partial = std::dynamic_pointer_cast(b); + return a_partial->get_reduce_type() == b_partial->get_reduce_type(); + } + if (a->is_replicated()) { + if (b->is_replicated()) { + return true; + } + return false; + } + if (!b->is_shard()) { + return false; + } + + auto a_shard = std::dynamic_pointer_cast(a); + auto b_shard = std::dynamic_pointer_cast(b); + return a_shard->get_axis() == b_shard->get_axis(); +} + +TensorDistAttr FromPlacements( + const TensorDistAttr& dist_attr, + const std::vector>& placements) { + TensorDistAttr dst_dist_attr = CopyTensorDistAttrForOutput(dist_attr); + std::vector dims_mapping(dist_attr.dims_mapping().size(), -1); + paddle::flat_hash_map partial_status; + + for (size_t mesh_dim = 0; mesh_dim < placements.size(); mesh_dim++) { + auto& placement = placements[mesh_dim]; + if (placement->is_shard()) { + auto shard_placement = std::dynamic_pointer_cast(placement); + dims_mapping[shard_placement->get_axis()] = mesh_dim; + } + if (placement->is_partial()) { + auto partial_placement = + std::dynamic_pointer_cast(placement); + auto reduce_type = partial_placement->get_reduce_type(); + partial_status[mesh_dim] = reduce_type; + } + } + dst_dist_attr.set_dims_mapping(dims_mapping); + dst_dist_attr.set_partial_status(partial_status); + return dst_dist_attr; +} + +std::vector ToArgDistAttr( + const std::vector& dist_attrs) { + std::vector items_dist_attrs; + std::transform( + dist_attrs.begin(), + dist_attrs.end(), + std::back_inserter(items_dist_attrs), + [](const TensorDistAttr& attr) -> ArgDistAttr { return {attr}; }); + return items_dist_attrs; +} + +std::vector GetLocalShape( + const std::vector shape, + const ProcessMesh& mesh, + const std::vector>& placements) { + auto local_shape = shape; + auto n_placement = placements.size(); + for (size_t i = 0; i < n_placement; i++) { + auto& placement = placements.at(i); + if (placement->is_shard()) { + auto mesh_dim_size = mesh.dim_size(i); + auto shard_dim = + std::dynamic_pointer_cast(placement)->get_axis(); + auto split_size = + (shape.at(shard_dim) + mesh_dim_size - 1) / mesh_dim_size; + local_shape[shard_dim] = split_size; + } + } + return local_shape; +} + std::vector GetDimsMappingForAxes( const std::string& axes, const std::unordered_map& axis_to_dim_map, diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index cd16a95bceac7..b5b5e207a0ee6 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -69,6 +69,25 @@ std::vector ResoluteOutputPartialDimension( // Repliacated state TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr); +bool IsDimSharded(const TensorDistAttr& dist_attr, int dim); + +std::vector GetLocalShape( + const std::vector shape, + const ProcessMesh& mesh, + const std::vector>& placements); + +TensorDistAttr FromPlacements( + const TensorDistAttr& dist_attr, + const std::vector>& placements); + +std::vector ToArgDistAttr( + const std::vector& dist_attrs); + +TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim); + +bool PlacementEqual(const std::shared_ptr& a, + const std::shared_ptr& b); + // Adaptor for variadic arguments template struct ArgsIterator { @@ -112,6 +131,12 @@ struct VariadicSpmdRuleArgumentParser } } + void operator()(const std::vector& x) { + for (auto& t : x) { + inputs.emplace_back(&t); + } + } + // deal with outputs void operator()(DistMetaTensor* out) { outputs.emplace_back(out); } @@ -125,6 +150,28 @@ struct VariadicSpmdRuleArgumentParser SpmdInfo InferBackward() { return Fn(inputs, outputs); } }; + +using DynamicSpmdFn = SpmdInfo (*)( + const std::vector*>>&); + +template +struct ReplicateInferSpmdDynamicHelper + : public ArgsIterator> { + SpmdInfo Infer() { return Fn(inputs); } + + void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } + void operator()(const std::vector& x) { + inputs.emplace_back(&x); + } + + void operator()(std::vector&& x) = delete; + void operator()(DistMetaTensor&& x) = delete; + + std::vector*>> + inputs; +}; } // namespace detail // Get dims mapping for the given axes according to sharding information of diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index d97a16e57fa61..d86b25b7ba224 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -295,10 +295,10 @@ void FlashAttnInferMeta(const MetaTensor& q, out->set_layout(q.layout()); } -void ArangeInferMeta(const MetaTensor& start, - const MetaTensor& end, - const MetaTensor& step, - MetaTensor* out) { +void ArangeTensorInferMeta(const MetaTensor& start, + const MetaTensor& end, + const MetaTensor& step, + MetaTensor* out) { PADDLE_ENFORCE_EQ(phi::product(start.dims()), 1, phi::errors::InvalidArgument( @@ -1258,6 +1258,35 @@ void SendURecvInferMeta(const MetaTensor& x, } } +void SparseMomentumInferMeta(const MetaTensor& param, + const MetaTensor& learning_rate, + const MetaTensor& velocity, + MetaTensor* param_out, + MetaTensor* velocity_out, + MetaTensor* master_param_out) { + auto lr_dims = phi::product(learning_rate.dims()); + PADDLE_ENFORCE_EQ(lr_dims != 0 && lr_dims == 1, + true, + phi::errors::InvalidArgument( + "Learning_rate should be a scalar. But Received " + "LearningRate's dim [%s]", + lr_dims)); + auto param_dim = param.dims(); + PADDLE_ENFORCE_EQ( + param_dim, + velocity.dims(), + phi::errors::InvalidArgument( + "Param and Velocity of SparseMomentumOp should have the same " + "dimension. But received Param's dim [%s] and Velocity [%s].", + param_dim, + velocity.dims())); + param_out->set_dims(param_dim); + velocity_out->set_dims(param_dim); + if (master_param_out != nullptr) { + master_param_out->set_dims(param_dim); + } +} + void SpectralNormInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, @@ -1382,4 +1411,116 @@ void ViterbiDecodeInferMeta(const MetaTensor& input, scores->set_dtype(length.dtype()); } +void QuantLinearInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& bias, + int in_num_col_dims, + const std::string& activation_type, + bool padding_weights, + float scale_in, + const std::vector& scale_weights, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + MetaTensor* y) { + auto w_dims = w.dims(); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + phi::errors::InvalidArgument( + "The input Weight of quant_linear is expected to be a 2-D tensor. " + "But received the number of Weight's dimensions is %d, " + "Weight's shape is %s.", + w_dims.size(), + w_dims)); + if (bias) { + auto bias_dims = bias.dims(); + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; + + PADDLE_ENFORCE_LE(bias_dims.size(), + 2, + phi::errors::InvalidArgument( + "The input Bias of quant_linear is expected to be a " + "1-D or 2-D tensor. But " + "received the number of Bias's dimensions is %d, " + "Bias's shape is %s.", + bias_dims.size(), + bias_dims)); + + PADDLE_ENFORCE_EQ( + bias_dims[bias_dims.size() - 1], + w_dims1, + phi::errors::InvalidArgument( + "The last dimension of input Bias is expected be equal " + "to the actual width of input Weight. But received the last " + "dimension of Bias is %d, Bias's shape is %s; " + "the actual width of Weight is %d, Weight's shape is %s.", + bias_dims[bias_dims.size() - 1], + bias_dims, + w_dims1, + w_dims)); + + if (bias_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + bias_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dimension of input Bias is expected to be 1, " + "but received %d, Bias's shape is %s.", + bias_dims[0], + bias_dims)); + } + } + + auto in_dims = x.dims(); + PADDLE_ENFORCE_LT( + in_num_col_dims, + in_dims.size(), + phi::errors::InvalidArgument( + "The attribute in_num_col_dims used to flatten Input to " + "a 2-D tensor, is expected to be less than the number of " + "Input's dimensions. But received in_num_col_dims is %d, " + "the number of Input's dimensions is %d, Input's shape is %s.", + in_num_col_dims, + in_dims.size(), + in_dims)); + + if (!activation_type.empty()) { + PADDLE_ENFORCE_EQ( + activation_type, + "relu", + phi::errors::InvalidArgument( + "The attribute activation_type of quant_linear is expected " + "to be \"relu\", but received %s.", + activation_type.c_str())); + } + + std::vector output_dims; + + auto in_mat_dims = phi::flatten_to_2d(in_dims, in_num_col_dims); + auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0]; + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; + PADDLE_ENFORCE_EQ( + in_mat_dims[1], + w_dims0, + phi::errors::InvalidArgument( + "The input's second dimension and weight's first dimension is " + "expected to be the same. But received input's second dimension is " + "%d, input's shape is %s; weight's first dimension is %d, weight's " + "shape is %s.", + in_mat_dims[1], + in_mat_dims, + w_dims0, + phi::make_ddim({w_dims0, w_dims1}))); + output_dims.reserve(static_cast(in_num_col_dims + 1)); + for (int i = 0; i < in_num_col_dims; ++i) { + output_dims.push_back(in_dims[i]); + } + output_dims.push_back(w_dims1); + + y->set_dims(make_ddim(output_dims)); + y->share_lod(x); + y->set_dtype(x.dtype()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 797835a1abd51..7272941504ff2 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -48,10 +48,10 @@ void AddmmInferMeta(const MetaTensor& input, float alpha, MetaTensor* out); -void ArangeInferMeta(const MetaTensor& start, - const MetaTensor& end, - const MetaTensor& step, - MetaTensor* out); +void ArangeTensorInferMeta(const MetaTensor& start, + const MetaTensor& end, + const MetaTensor& step, + MetaTensor* out); void BoxCoderInferMeta(const MetaTensor& prior_box, const MetaTensor& prior_box_var, @@ -202,6 +202,12 @@ void SendURecvInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* dst_count); +void SparseMomentumInferMeta(const MetaTensor& param, + const MetaTensor& learning_rate, + const MetaTensor& velocity, + MetaTensor* param_out, + MetaTensor* velocity_out, + MetaTensor* master_param_out); void SpectralNormInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, @@ -219,4 +225,17 @@ void ViterbiDecodeInferMeta(const MetaTensor& input, MetaTensor* path, MetaConfig config = MetaConfig()); +void QuantLinearInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& bias, + int in_num_col_dims, + const std::string& activation_type, + bool padding_weights, + float scale_in, + const std::vector& scale_weights, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + MetaTensor* y); + } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 1dd9355549c02..ca470efc9b2a7 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -549,17 +549,17 @@ void CumScalarAxisInferMeta(const MetaTensor& x, void CumWithIndicesInferMeta(const MetaTensor& x, int axis, - int dtype, + DataType dtype, MetaTensor* out, MetaTensor* indices) { auto x_dims = x.dims(); - auto indices_type = phi::TransToPhiDataType(dtype); PADDLE_ENFORCE_EQ( - (indices_type == DataType::INT32 || indices_type == DataType::INT64), + (dtype == DataType::INT32 || dtype == DataType::INT64), true, - phi::errors::InvalidArgument("dtype of indices must be int32 or int64")); + phi::errors::InvalidArgument( + "dtype of indices must be DataType::INT32 or DataType::INT64")); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { int _axis = 0; if (axis < 0) { _axis = axis + x_dims.size(); @@ -606,7 +606,7 @@ void CumWithIndicesInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); out->share_lod(x); indices->set_dims(x_dims); - indices->set_dtype(indices_type); + indices->set_dtype(dtype); indices->share_lod(x); } @@ -672,6 +672,15 @@ void DecodeJpegInferMeta(const MetaTensor& x, } } +void DeQuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y) { + auto x_dims = x.dims(); + y->set_dims(x_dims); + y->set_dtype(out_dtype); +} + void DiagEmbedInferMeta( const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) { auto x_dims = x.dims(); @@ -3263,6 +3272,7 @@ void ReduceInferMeta(const MetaTensor& x, if (axis.empty()) { reduce_all = true; } + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out); } @@ -3768,6 +3778,15 @@ void FillSplitOutDims(const MetaTensor& x, } } +void QuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y) { + auto x_dims = x.dims(); + y->set_dims(x_dims); + y->set_dtype(out_dtype); +} + void SplitInferMeta(const MetaTensor& x, const IntArray& sections, const Scalar& axis, @@ -3969,8 +3988,8 @@ void SqueezeWithXShapeInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* xshape, MetaConfig config) { - SqueezeInferMeta(x, axes, out, config); const auto& x_dims = x.dims(); + SqueezeInferMeta(x, axes, out, config); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; for (int i = 0; i < x_dims.size(); ++i) { @@ -4833,7 +4852,7 @@ void UniqueConsecutiveInferMeta(const MetaTensor& x, bool return_inverse, bool return_counts, const std::vector& axis, - int dtype, + DataType dtype, MetaTensor* out, MetaTensor* index, MetaTensor* counts) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index d79b53a71097e..c88a12d34506d 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -137,7 +137,7 @@ void CumScalarAxisInferMeta(const MetaTensor& x, void CumWithIndicesInferMeta(const MetaTensor& x, int axis, - int dtype, + DataType dtype, MetaTensor* out, MetaTensor* indices); @@ -145,6 +145,11 @@ void DecodeJpegInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* out); +void DeQuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y); + void DiagEmbedInferMeta( const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out); @@ -453,6 +458,11 @@ void QrInferMeta(const MetaTensor& x, MetaTensor* q, MetaTensor* r); +void QuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y); + void WeightQuantizeInferMeta(const MetaTensor& x, const std::string& algo, MetaTensor* out, @@ -706,7 +716,7 @@ void UniqueConsecutiveInferMeta(const MetaTensor& x, bool return_inverse, bool return_counts, const std::vector& axis, - int dtype, + DataType dtype, MetaTensor* out, MetaTensor* index, MetaTensor* counts); diff --git a/paddle/phi/kernels/arange_kernel.h b/paddle/phi/kernels/arange_kernel.h index 6c879e27d79a6..b4edbbee1a5fc 100644 --- a/paddle/phi/kernels/arange_kernel.h +++ b/paddle/phi/kernels/arange_kernel.h @@ -14,15 +14,23 @@ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { +template +void ArangeTensorKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out); + template void ArangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, + const Scalar& start, + const Scalar& end, + const Scalar& step, DenseTensor* out); template diff --git a/paddle/phi/kernels/batch_norm_grad_kernel.h b/paddle/phi/kernels/batch_norm_grad_kernel.h index ec4753604283f..fc3d2f3d9886a 100644 --- a/paddle/phi/kernels/batch_norm_grad_kernel.h +++ b/paddle/phi/kernels/batch_norm_grad_kernel.h @@ -23,8 +23,8 @@ namespace phi { template void BatchNormGradFunctor(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -45,8 +45,8 @@ void BatchNormGradFunctor(const Context& dev_ctx, template void BatchNormGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -67,7 +67,7 @@ template void BatchNormDoubleGradKernel( const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, + const paddle::optional& scale, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, diff --git a/paddle/phi/kernels/batch_norm_kernel.h b/paddle/phi/kernels/batch_norm_kernel.h index edae79941f535..b81f9b0370096 100644 --- a/paddle/phi/kernels/batch_norm_kernel.h +++ b/paddle/phi/kernels/batch_norm_kernel.h @@ -25,8 +25,8 @@ void BatchNormKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& mean, const DenseTensor& variance, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, bool is_test, float momentum, float epsilon, @@ -57,8 +57,8 @@ void BatchNormInferKernel(const Context& dev_ctx, template void phi::BatchNormGradFunctor( \ const ::phi::backend##Context& dev_ctx, \ const DenseTensor& x, \ - const DenseTensor& scale, \ - const DenseTensor& bias, \ + const paddle::optional& scale, \ + const paddle::optional& bias, \ const paddle::optional& mean, \ const paddle::optional& variance, \ const DenseTensor& saved_mean, \ diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index 65bde5601128f..84ec899d9d399 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -423,7 +423,8 @@ PD_REGISTER_KERNEL(cos_triple_grad, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softsign_grad, + SoftsignGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_double_grad, SigmoidDoubleGradKernel) diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index a8169df1021d2..e704eefc54ebb 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -230,7 +230,7 @@ PD_REGISTER_KERNEL(expm1, PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {} PD_REGISTER_KERNEL( square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {} -PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel) diff --git a/paddle/phi/kernels/cpu/arange_kernel.cc b/paddle/phi/kernels/cpu/arange_kernel.cc index 7f7e555423176..b2684b2f6159a 100644 --- a/paddle/phi/kernels/cpu/arange_kernel.cc +++ b/paddle/phi/kernels/cpu/arange_kernel.cc @@ -21,14 +21,11 @@ limitations under the License. */ namespace phi { template -void ArangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, - DenseTensor* out) { - T start_value = start.data()[0]; - T end_value = end.data()[0]; - T step_value = step.data()[0]; +void ArangeFunc(const Context& dev_ctx, + const T& start_value, + const T& end_value, + const T& step_value, + DenseTensor* out) { int64_t size = 0; phi::funcs::GetSize(start_value, end_value, step_value, &size); out->Resize(phi::make_ddim({size})); @@ -40,7 +37,39 @@ void ArangeKernel(const Context& dev_ctx, } } +template +void ArangeTensorKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out) { + T start_value = start.data()[0]; + T end_value = end.data()[0]; + T step_value = step.data()[0]; + ArangeFunc(dev_ctx, start_value, end_value, step_value, out); +} + +template +void ArangeKernel(const Context& dev_ctx, + const Scalar& start, + const Scalar& end, + const Scalar& step, + DenseTensor* out) { + T start_value = start.to(); + T end_value = end.to(); + T step_value = step.to(); + ArangeFunc(dev_ctx, start_value, end_value, step_value, out); +} + } // namespace phi +PD_REGISTER_KERNEL(arange_tensor, + CPU, + ALL_LAYOUT, + phi::ArangeTensorKernel, + float, + double, + int, + int64_t) {} PD_REGISTER_KERNEL( arange, CPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index 32d06c354a1c2..7dc8f39da0513 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -38,8 +38,8 @@ using ConstEigenVectorArrayMap = template void BatchNormGradFunctor(const Context& ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -139,8 +139,6 @@ void BatchNormGradFunctor(const Context& ctx, inv_var_data = saved_variance.data(); } - ConstEigenVectorArrayMap scale_arr(scale.data(), C); - ConstEigenVectorArrayMap bias_arr(bias.data(), C); ConstEigenVectorArrayMap mean_arr(mean_data, C); ConstEigenVectorArrayMap inv_var_arr(inv_var_data, C); @@ -167,6 +165,20 @@ void BatchNormGradFunctor(const Context& ctx, phi::Copy(ctx, *d_y, ctx.GetPlace(), false, d_x); return; } + auto* Scale = scale.get_ptr(); + auto* Bias = bias.get_ptr(); + Eigen::Array scale_arr(C); + Eigen::Array bias_arr(C); + if (Scale) { + scale_arr = ConstEigenVectorArrayMap(Scale->data(), C); + } else { + scale_arr.setOnes(); + } + if (Bias) { + bias_arr = ConstEigenVectorArrayMap(Bias->data(), C); + } else { + bias_arr.setZero(); + } int scale_coefff = use_global_stats ? 1 : N * sample_size; const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff; @@ -295,8 +307,8 @@ void BatchNormGradFunctor(const Context& ctx, template void BatchNormGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -338,7 +350,7 @@ template void BatchNormDoubleGradKernel( const Context& ctx, const DenseTensor& x, - const DenseTensor& scale, + const paddle::optional& scale, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -357,7 +369,7 @@ void BatchNormDoubleGradKernel( DenseTensor* scale_grad, DenseTensor* y_grad_grad) { const auto* X = &x; - const auto* Scale = &scale; + const auto* Scale = scale.get_ptr(); const auto* dY = &y_grad; const auto* Saved_mean = &saved_mean; const auto* Saved_variance = &saved_variance; diff --git a/paddle/phi/kernels/cpu/batch_norm_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_kernel.cc index 4db0e2f3f5378..e6acb16a89185 100644 --- a/paddle/phi/kernels/cpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_kernel.cc @@ -37,8 +37,8 @@ void BatchNormKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& mean, const DenseTensor& variance, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, bool is_test, float momentum, float epsilon, @@ -167,11 +167,27 @@ void BatchNormKernel(const Context& ctx, // ((x - est_mean) * (inv_var) * scale + bias // formula transform ====> // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) - ConstEigenVectorArrayMap scale_arr(scale.data(), C); - ConstEigenVectorArrayMap bias_arr(bias.data(), C); - Eigen::Array new_scale = inv_std * scale_arr; - Eigen::Array new_bias = - bias_arr - mean_arr * inv_std * scale_arr; + auto* Scale = scale.get_ptr(); + auto* Bias = bias.get_ptr(); + Eigen::Array new_scale(C); + Eigen::Array new_bias(C); + if (Scale && Bias) { + ConstEigenVectorArrayMap scale_arr(Scale->data(), C); + ConstEigenVectorArrayMap bias_arr(Bias->data(), C); + new_scale = inv_std * scale_arr; + new_bias = bias_arr - mean_arr * inv_std * scale_arr; + } else if (Scale) { + ConstEigenVectorArrayMap scale_arr(Scale->data(), C); + new_scale = inv_std * scale_arr; + new_bias = -(mean_arr * inv_std * scale_arr); + } else if (Bias) { + ConstEigenVectorArrayMap bias_arr(Bias->data(), C); + new_scale = inv_std; + new_bias = bias_arr - mean_arr * inv_std; + } else { + new_scale = inv_std; + new_bias = -(mean_arr * inv_std); + } switch (data_layout) { case DataLayout::kNCHW: { diff --git a/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc b/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc index 88fb4f4feb91f..acd84a80be2ad 100644 --- a/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc @@ -28,7 +28,7 @@ void CummaxGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -36,11 +36,10 @@ void CummaxGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } @@ -52,7 +51,7 @@ void CumminGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -60,11 +59,10 @@ void CumminGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } diff --git a/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc b/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc index be1cfe3d86b1f..881664601b85c 100644 --- a/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc @@ -149,14 +149,13 @@ template void CummaxKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); } @@ -166,14 +165,13 @@ template void CumminKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); } diff --git a/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc b/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc index 071140a2a5420..a2cc99c59fe2d 100644 --- a/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc @@ -56,10 +56,10 @@ void CumprodGradKernel(const Context& dev_ctx, Allocator::AllocationPtr x_conj; Allocator::AllocationPtr out_conj; if (phi::IsComplexType(x.dtype())) { - x_conj = const_cast(dev_ctx.GetAllocator()) + x_conj = const_cast(dev_ctx.GetAllocator()) // NOLINT .Allocate(numel * sizeof(T)); auto* x_data_conj = reinterpret_cast(x_conj->ptr()); - out_conj = const_cast(dev_ctx.GetAllocator()) + out_conj = const_cast(dev_ctx.GetAllocator()) // NOLINT .Allocate(numel * sizeof(T)); auto* out_data_conj = reinterpret_cast(out_conj->ptr()); diff --git a/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc index a0e2611f92cfc..f09e09a1a14aa 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc @@ -45,8 +45,12 @@ PD_REGISTER_KERNEL(divide_grad, phi::DivideGradKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, + bool, phi::dtype::complex, phi::dtype::complex) {} @@ -58,5 +62,6 @@ PD_REGISTER_KERNEL(divide_double_grad, double, int, int64_t, + bool, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc index 20aae406136a2..b7fdefe023e73 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc @@ -59,7 +59,11 @@ PD_REGISTER_KERNEL(divide, phi::DivideKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, + bool, complex64, complex128) {} diff --git a/paddle/phi/kernels/cpu/expand_grad_kernel.cc b/paddle/phi/kernels/cpu/expand_grad_kernel.cc index 5cbbf253b747d..82db6a17101ab 100644 --- a/paddle/phi/kernels/cpu/expand_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/expand_grad_kernel.cc @@ -27,4 +27,12 @@ PD_REGISTER_KERNEL(expand_grad, float, double, int, - int64_t) {} + int64_t, + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/expand_kernel.cc b/paddle/phi/kernels/cpu/expand_kernel.cc index 2df833d0f9c30..f0a1f89762ffb 100644 --- a/paddle/phi/kernels/cpu/expand_kernel.cc +++ b/paddle/phi/kernels/cpu/expand_kernel.cc @@ -28,4 +28,11 @@ PD_REGISTER_KERNEL(expand, double, int, int64_t, - bool) {} + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index 5e37d3dfa262f..bb2533490cfc2 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -88,13 +88,14 @@ void FullLikeKernel(const Context& dev_ctx, template void FullIntArrayKernel(const Context& dev_ctx, - const IntArray& val, + const std::vector& shape, DataType dtype UNUSED, DenseTensor* out) { - out->Resize(phi::make_ddim({static_cast(val.GetData().size())})); + out->Resize(phi::make_ddim({static_cast(shape.size())})); T* out_data = dev_ctx.template Alloc(out); - for (size_t i = 0; i < val.GetData().size(); ++i) { - out_data[i] = static_cast(val.GetData()[i]); + for (size_t i = 0; i < shape.size(); ++i) { + int64_t val = shape[i]; + out_data[i] = static_cast(val); } } diff --git a/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc b/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc index 5aaec6f6139e5..59a3c2215a870 100644 --- a/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc @@ -58,8 +58,13 @@ PD_REGISTER_KERNEL(gather_nd_grad, CPU, ALL_LAYOUT, phi::GatherNdGradKernel, + bool, float, double, - int64_t, int, - uint8_t) {} + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/gather_nd_kernel.cc b/paddle/phi/kernels/cpu/gather_nd_kernel.cc index 8ae866a1c8add..1ccf76f6be986 100644 --- a/paddle/phi/kernels/cpu/gather_nd_kernel.cc +++ b/paddle/phi/kernels/cpu/gather_nd_kernel.cc @@ -52,10 +52,13 @@ PD_REGISTER_KERNEL(gather_nd, CPU, ALL_LAYOUT, phi::GatherNdKernel, + bool, float, double, - int64_t, int, + int8_t, + int64_t, int16_t, - bool, - uint8_t) {} + uint8_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 1a733e5c0ab87..7385a928c1791 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -238,4 +238,11 @@ PD_REGISTER_KERNEL(index_put_grad, double, int, int64_t, - bool) {} + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index b2a47d61de0e3..f587978c2c2ad 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -171,4 +171,11 @@ PD_REGISTER_KERNEL(index_put, double, int, int64_t, - bool) {} + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc index 14937ea613936..d798c6b81c966 100644 --- a/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc @@ -93,9 +93,9 @@ void InstanceNormGradKernel(const Context& dev_ctx, } auto scale_e = - scale_ptr - ? EigenVector::Flatten(*scale_ptr) - : EigenVector::Flatten(const_cast(scale_data)); + scale_ptr ? EigenVector::Flatten(*scale_ptr) + : EigenVector::Flatten( + const_cast(scale_data)); // NOLINT auto mean_e = EigenVector::Flatten(saved_mean); auto inv_var_e = EigenVector::Flatten(saved_variance); auto dy_e = EigenVector::Flatten(d_y); diff --git a/paddle/phi/kernels/cpu/instance_norm_kernel.cc b/paddle/phi/kernels/cpu/instance_norm_kernel.cc index 00069d5dec9a8..1242babaf0c83 100644 --- a/paddle/phi/kernels/cpu/instance_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/instance_norm_kernel.cc @@ -111,14 +111,14 @@ void InstanceNormKernel(const Context& dev_ctx, set_constant(dev_ctx, &bias_data, static_cast(0)); } auto scale_e = - scale_ptr - ? EigenVector::Flatten(*scale_ptr) - : EigenVector::Flatten(const_cast(scale_data)); + scale_ptr ? EigenVector::Flatten(*scale_ptr) + : EigenVector::Flatten( + const_cast(scale_data)); // NOLINT auto scale_arr = scale_e.reshape(C_shape); - auto bias_e = - bias_ptr - ? EigenVector::Flatten(*bias_ptr) - : EigenVector::Flatten(const_cast(bias_data)); + auto bias_e = bias_ptr + ? EigenVector::Flatten(*bias_ptr) + : EigenVector::Flatten( + const_cast(bias_data)); // NOLINT auto bias_arr = bias_e.reshape(C_shape); dev_ctx.template Alloc(y); diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc index 0713725127190..45fd22f8a0f18 100644 --- a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc @@ -34,7 +34,7 @@ void LapackSVD(const T* x_data, T* eigenvalues_data, int rows, int cols) { char jobz = 'N'; int mx = std::max(rows, cols); int mn = std::min(rows, cols); - T* a = const_cast(x_data); + T* a = const_cast(x_data); // NOLINT int lda = rows; int lwork = 3 * mn + std::max(mx, 7 * mn); std::vector work(lwork); diff --git a/paddle/phi/kernels/cpu/reduce_sum_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_kernel.cc index aa690227bda01..e18a50eddad34 100644 --- a/paddle/phi/kernels/cpu/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_sum_kernel.cc @@ -51,6 +51,8 @@ PD_REGISTER_KERNEL(sum_raw, phi::dtype::float16, phi::dtype::bfloat16, int16_t, + int8_t, + uint8_t, int, int64_t, complex64, diff --git a/paddle/phi/kernels/cpu/scale_kernel.cc b/paddle/phi/kernels/cpu/scale_kernel.cc index a7aea9210a1c6..fac805c90ba63 100644 --- a/paddle/phi/kernels/cpu/scale_kernel.cc +++ b/paddle/phi/kernels/cpu/scale_kernel.cc @@ -58,6 +58,7 @@ PD_REGISTER_KERNEL(scale, CPU, ALL_LAYOUT, phi::ScaleKernel, + bool, float, double, phi::dtype::bfloat16, diff --git a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc index dad7628dcf30a..ed35513d98550 100644 --- a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc @@ -28,6 +28,10 @@ PD_REGISTER_KERNEL(set_value_grad, int, int64_t, bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::bfloat16, phi::dtype::float16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/set_value_kernel.cc b/paddle/phi/kernels/cpu/set_value_kernel.cc index 4b0c0415e4834..dd48ae94a96e7 100644 --- a/paddle/phi/kernels/cpu/set_value_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_kernel.cc @@ -28,7 +28,11 @@ PD_REGISTER_KERNEL(set_value, int, int64_t, bool, + int16_t, + uint8_t, + int8_t, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} PD_REGISTER_KERNEL(set_value_with_tensor, @@ -40,6 +44,10 @@ PD_REGISTER_KERNEL(set_value_with_tensor, int, int64_t, bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::bfloat16, phi::dtype::float16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/sign_kernel.cc b/paddle/phi/kernels/cpu/sign_kernel.cc index 9ded252c5c592..f03f39f80dcbe 100644 --- a/paddle/phi/kernels/cpu/sign_kernel.cc +++ b/paddle/phi/kernels/cpu/sign_kernel.cc @@ -21,4 +21,13 @@ limitations under the License. */ // See Note [ Why still include the fluid headers? ] #include "paddle/phi/common/bfloat16.h" -PD_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, phi::SignKernel, float, double) {} +PD_REGISTER_KERNEL(sign, + CPU, + ALL_LAYOUT, + phi::SignKernel, + int8_t, + int16_t, + int32_t, + int64_t, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/slice_grad_kernel.cc b/paddle/phi/kernels/cpu/slice_grad_kernel.cc index 0ecb3940fb275..b7ff211bd004e 100644 --- a/paddle/phi/kernels/cpu/slice_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/slice_grad_kernel.cc @@ -23,37 +23,48 @@ PD_REGISTER_KERNEL(slice_grad, ALL_LAYOUT, phi::SliceGradKernel, bool, - uint8_t, int, + uint8_t, int64_t, float, double, + int16_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex, - phi::dtype::bfloat16) {} + phi::dtype::complex) {} PD_REGISTER_KERNEL(slice_array_grad, CPU, ALL_LAYOUT, phi::SliceArrayGradKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex, - phi::dtype::bfloat16) {} + phi::dtype::complex) {} PD_REGISTER_KERNEL(slice_array_dense_grad, CPU, ALL_LAYOUT, phi::SliceArrayDenseGradKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex, - phi::dtype::bfloat16) {} + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/slice_kernel.cc b/paddle/phi/kernels/cpu/slice_kernel.cc index a6eec654775eb..9c75f64214f12 100644 --- a/paddle/phi/kernels/cpu/slice_kernel.cc +++ b/paddle/phi/kernels/cpu/slice_kernel.cc @@ -23,39 +23,48 @@ PD_REGISTER_KERNEL(slice, ALL_LAYOUT, phi::SliceKernel, bool, - uint8_t, int, + uint8_t, int64_t, float, double, + int16_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex, - phi::dtype::bfloat16) {} + phi::dtype::complex) {} PD_REGISTER_KERNEL(slice_array, CPU, ALL_LAYOUT, phi::SliceArrayKernel, bool, - int, - uint8_t, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex, - phi::dtype::bfloat16) {} + phi::dtype::complex) {} PD_REGISTER_KERNEL(slice_array_dense, CPU, ALL_LAYOUT, phi::SliceArrayDenseKernel, bool, - int, - uint8_t, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex, - phi::dtype::bfloat16) {} + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/stack_grad_kernel.cc b/paddle/phi/kernels/cpu/stack_grad_kernel.cc index 36057ce07711c..e55ffa865c99a 100644 --- a/paddle/phi/kernels/cpu/stack_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/stack_grad_kernel.cc @@ -52,12 +52,15 @@ PD_REGISTER_KERNEL(stack_grad, CPU, ALL_LAYOUT, phi::StackGradKernel, + bool, float, double, - bool, - int64_t, int, - uint8_t, int8_t, + int16_t, + int64_t, + uint8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/stack_kernel.cc b/paddle/phi/kernels/cpu/stack_kernel.cc index 39a9dfba1b6a9..acf0425efa5a4 100644 --- a/paddle/phi/kernels/cpu/stack_kernel.cc +++ b/paddle/phi/kernels/cpu/stack_kernel.cc @@ -64,12 +64,15 @@ PD_REGISTER_KERNEL(stack, CPU, ALL_LAYOUT, phi::StackKernel, + bool, float, double, - bool, - int64_t, int, - uint8_t, int8_t, + int64_t, + int16_t, + uint8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc index 21ef18dbd90cf..1c15f04f7dd41 100644 --- a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc @@ -23,11 +23,13 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, CPU, ALL_LAYOUT, phi::StridedSliceRawGradKernel, - bool, - int, - int64_t, float, double, + bool, + int64_t, + int16_t, + int, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -36,11 +38,15 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, CPU, ALL_LAYOUT, phi::StridedSliceArrayGradKernel, - bool, - int, - int64_t, float, double, + bool, + int64_t, + int16_t, + int, + uint8_t, + int8_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/strided_slice_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_kernel.cc index e9a1671bcc4c9..785d7e55cb12f 100644 --- a/paddle/phi/kernels/cpu/strided_slice_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_kernel.cc @@ -24,10 +24,14 @@ PD_REGISTER_KERNEL(strided_slice_raw, ALL_LAYOUT, phi::StridedSliceRawKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -37,10 +41,14 @@ PD_REGISTER_KERNEL(strided_slice_array, ALL_LAYOUT, phi::StridedSliceArrayKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/svd_kernel.cc b/paddle/phi/kernels/cpu/svd_kernel.cc index a3f6f38fe4780..c2018867782f5 100644 --- a/paddle/phi/kernels/cpu/svd_kernel.cc +++ b/paddle/phi/kernels/cpu/svd_kernel.cc @@ -28,7 +28,7 @@ void LapackSvd( char jobz = full ? 'A' : 'S'; int mx = std::max(rows, cols); int mn = std::min(rows, cols); - T* a = const_cast(X); + T* a = const_cast(X); // NOLINT int lda = rows; int ldu = rows; int ldvt = full ? cols : mn; diff --git a/paddle/phi/kernels/cpu/transpose_grad_kernel.cc b/paddle/phi/kernels/cpu/transpose_grad_kernel.cc index cc3340edcb4ab..627bc942e4678 100644 --- a/paddle/phi/kernels/cpu/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_grad_kernel.cc @@ -31,6 +31,7 @@ PD_REGISTER_KERNEL(transpose_grad, uint8_t, int8_t, int16_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc b/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc index d0d674d06ee2b..8c3a14a5edf76 100644 --- a/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc +++ b/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc @@ -30,12 +30,11 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, bool return_inverse, bool return_counts, const std::vector& axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* index, DenseTensor* counts) { - auto data_type = phi::TransToPhiDataType(dtype); - if (data_type == phi::DataType::INT32) { + if (dtype == phi::DataType::INT32) { PADDLE_ENFORCE_LE( x.numel(), INT_MAX, @@ -48,14 +47,14 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, if (axis.empty()) { phi::VisitDataTypeTiny( - data_type, + dtype, UniqueConsecutiveFlattenedTensorFunctor( dev_ctx, x, out, return_inverse, return_counts, index, counts)); } else { int valid_axis = axis[0]; if (valid_axis < 0) valid_axis += x.dims().size(); phi::VisitDataTypeTiny( - data_type, + dtype, UniqueConsecutiveDimFunctor(dev_ctx, x, out, diff --git a/paddle/phi/kernels/cum_maxmin_grad_kernel.h b/paddle/phi/kernels/cum_maxmin_grad_kernel.h index 13a6b7ee6ec1e..a018a3bfcc940 100644 --- a/paddle/phi/kernels/cum_maxmin_grad_kernel.h +++ b/paddle/phi/kernels/cum_maxmin_grad_kernel.h @@ -24,7 +24,7 @@ void CummaxGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad); template @@ -33,7 +33,7 @@ void CumminGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/cum_maxmin_kernel.h b/paddle/phi/kernels/cum_maxmin_kernel.h index 37755deb5d91e..19e3fc9da0b80 100644 --- a/paddle/phi/kernels/cum_maxmin_kernel.h +++ b/paddle/phi/kernels/cum_maxmin_kernel.h @@ -22,7 +22,7 @@ template void CummaxKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices); @@ -30,7 +30,7 @@ template void CumminKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices); diff --git a/paddle/phi/kernels/elementwise_divide_kernel.h b/paddle/phi/kernels/elementwise_divide_kernel.h index c5c9993826b54..8a78435950c0f 100644 --- a/paddle/phi/kernels/elementwise_divide_kernel.h +++ b/paddle/phi/kernels/elementwise_divide_kernel.h @@ -25,14 +25,24 @@ void DivideKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out); +template +void Divide(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* dense_out) { + MetaTensor meta_out(dense_out); + ElementwiseInferMeta(x, y, &meta_out); + if (x.initialized()) { + DivideKernel(dev_ctx, x, y, dense_out); + } +} + template DenseTensor Divide(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { DenseTensor dense_out; - MetaTensor meta_out(&dense_out); - ElementwiseInferMeta(x, y, &meta_out); - DivideKernel(dev_ctx, x, y, &dense_out); + Divide(dev_ctx, x, y, &dense_out); return dense_out; } diff --git a/paddle/phi/kernels/frobenius_norm_grad_kernel.h b/paddle/phi/kernels/frobenius_norm_grad_kernel.h index 65db8dd9e0a10..78494c4423f7e 100644 --- a/paddle/phi/kernels/frobenius_norm_grad_kernel.h +++ b/paddle/phi/kernels/frobenius_norm_grad_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -25,7 +26,7 @@ void FrobeniusNormGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* dx); diff --git a/paddle/phi/kernels/frobenius_norm_kernel.h b/paddle/phi/kernels/frobenius_norm_kernel.h index 30122cb416094..45ddb6123b85d 100644 --- a/paddle/phi/kernels/frobenius_norm_kernel.h +++ b/paddle/phi/kernels/frobenius_norm_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,7 +24,7 @@ namespace phi { template void FrobeniusNormKernel(const Context& ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* out); diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h index cef58433e9e04..b10e02658fe75 100644 --- a/paddle/phi/kernels/full_kernel.h +++ b/paddle/phi/kernels/full_kernel.h @@ -92,7 +92,7 @@ DenseTensor FullLike(const Context& dev_ctx, template void FullIntArrayKernel(const Context& dev_ctx, - const IntArray& val, + const std::vector& shape, DataType dtype, DenseTensor* out); diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index b2c2d493c48ad..06b59644cf11d 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -107,6 +107,14 @@ struct Conj { } }; +// T is phi::dtype::complex or phi::dtype::complex +template +struct Real { + HOSTDEVICE ComplexType operator()(const ComplexType& val) const { + return ComplexType(val.real); + } +}; + // sine'(x) = cos(x) template struct SinGradFunctor : public BaseActivationFunctor { @@ -2129,6 +2137,24 @@ struct SoftsignGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct SoftsignGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + ComplexType one = static_cast>(1.0f); + auto temp = (-x / (one + x.abs()).square()).unaryExpr(Real()); + + dx.device(d) = dout * (one / (one + x.abs()) + temp * x / x.abs()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + // sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor : public BaseActivationFunctor { @@ -4339,6 +4365,17 @@ struct CudaSoftsignFunctor : public BaseActivationFunctor { } }; +template +struct CudaSoftsignFunctor> + : public BaseActivationFunctor> { + using Complex = ComplexType; + Complex one = static_cast(1.0f); + + __device__ __forceinline__ Complex operator()(const Complex x) const { + return x / (one + static_cast(abs(x))); + } +}; + template struct CudaSoftsignGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); @@ -4353,6 +4390,23 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaSoftsignGradFunctor> + : public BaseActivationFunctor> { + using Complex = ComplexType; + Complex one = static_cast(1.0f); + + __device__ __forceinline__ Complex operator()(const Complex dout, + const Complex x) const { + Complex abs_x = static_cast(abs(x)); + Complex abs_x_plus = one + abs_x; + Complex temp = static_cast((-x / (abs_x_plus * abs_x_plus)).real); + return dout * (one / abs_x_plus + temp * x / abs_x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaSigmoidFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index 23756b3bdde96..14696b45c78db 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -126,15 +126,18 @@ struct DstMaskFunctor { }; template -__global__ void VectorizedRandomGenerator(const size_t n, - uint64_t seed, - const float dropout_prob, - const T* src, - uint8_t* mask, - T* dst, - bool is_upscale_in_train, - uint64_t increment, - size_t main_offset) { +__global__ void VectorizedRandomGenerator( + unsigned int + identifier, /* This is used to relate kernel to cudaGraph nodes*/ + const size_t n, + uint64_t seed, + const float dropout_prob, + const T* src, + uint8_t* mask, + T* dst, + bool is_upscale_in_train, + uint64_t increment, + size_t main_offset) { size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); static constexpr int kCount = phi::funcs::uniform_distribution::kReturnsCount; @@ -334,7 +337,6 @@ void DropoutFwGPUKernelDriver( dropout_prob, x_data, mask_data, - increment, main_offset, mask_functor, @@ -347,30 +349,42 @@ void DropoutFwGPUKernelDriver( } else { bool copy_in_kernel = GetSeedDataAndIncrement( dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment); - -#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator - PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed, - PD_DROPOUT_KERNEL_NAME, - grid_size, - block_size, - 0, - stream, - offset, - KERNEL_PARAMS.As(1), - KERNEL_PARAMS.As(7), - size, - seed_data, - dropout_prob, - x_data, - mask_data, - y_data, - upscale_in_train, - increment, - main_offset); -#undef PD_DROPOUT_KERNEL_NAME + void* functionPtr = + reinterpret_cast(&(VectorizedRandomGenerator)); + cudaFunction_t cudaFunc; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); + const phi::GPUContext* dev_ctx_p = &dev_ctx; + phi::backends::gpu::CUDAGraphNodeLauncher::parameterSetter_t + parameterSetter = [offset, dev_ctx_p]( + phi::backends::gpu::CUDAKernelParams& params) { + uint64_t seed_data, increment; + phi::funcs::GetSeedDataAndIncrement( + *dev_ctx_p, nullptr, false, 0, offset, &seed_data, &increment); + params.As(2) = seed_data; + params.As(8) = increment; + VLOG(10) << "CUDA_GRAPH seed_data = " << seed_data + << ", increment = " << increment; + }; + phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t + cudaKernelCallback = [=](unsigned int id) { + VectorizedRandomGenerator + <<>>(id, + size, + seed_data, + dropout_prob, + x_data, + mask_data, + y_data, + upscale_in_train, + increment, + main_offset); + }; + phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch( + cudaFunc, parameterSetter, cudaKernelCallback); + + VLOG(10) << "NON_CUDA_GRAPH seed_data = " << seed_data + << ", increment = " << increment; } - VLOG(4) << "Dropout seed: " << seed << ", offset: " << offset - << ", seed_data:" << seed_data; } else { if (upscale_in_train) { // y = x diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc index e0b074548c91d..04e13a6799931 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cc +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc @@ -82,14 +82,21 @@ INSTANTIATION(EigenBroadcast, dtype::complex); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, int); +INSTANTIATION(EigenBroadcast, int8_t); +INSTANTIATION(EigenBroadcast, uint8_t); +INSTANTIATION(EigenBroadcast, int16_t); INSTANTIATION(EigenBroadcast, int64_t); INSTANTIATION(EigenBroadcastGrad, bool); INSTANTIATION(EigenBroadcastGrad, float); INSTANTIATION(EigenBroadcastGrad, dtype::float16); +INSTANTIATION(EigenBroadcastGrad, dtype::bfloat16); INSTANTIATION(EigenBroadcastGrad, dtype::complex); INSTANTIATION(EigenBroadcastGrad, dtype::complex); INSTANTIATION(EigenBroadcastGrad, double); INSTANTIATION(EigenBroadcastGrad, int); +INSTANTIATION(EigenBroadcastGrad, int8_t); +INSTANTIATION(EigenBroadcastGrad, uint8_t); +INSTANTIATION(EigenBroadcastGrad, int16_t); INSTANTIATION(EigenBroadcastGrad, int64_t); template struct EigenBroadcastGrad; template struct EigenBroadcastGrad; diff --git a/paddle/phi/kernels/funcs/eigen/pad.cc b/paddle/phi/kernels/funcs/eigen/pad.cc index 8041fc4ae175e..946bff40544ee 100644 --- a/paddle/phi/kernels/funcs/eigen/pad.cc +++ b/paddle/phi/kernels/funcs/eigen/pad.cc @@ -61,9 +61,12 @@ struct EigenPad { INSTANTIATION(EigenPad, bool); INSTANTIATION(EigenPad, uint8_t); INSTANTIATION(EigenPad, int); +INSTANTIATION(EigenPad, int8_t); +INSTANTIATION(EigenPad, int16_t); INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, double); +INSTANTIATION(EigenPad, dtype::float16); INSTANTIATION(EigenPad, dtype::bfloat16); INSTANTIATION(EigenPad, dtype::complex); INSTANTIATION(EigenPad, dtype::complex); diff --git a/paddle/phi/kernels/funcs/eigen/pad.cu b/paddle/phi/kernels/funcs/eigen/pad.cu index c4a3dd9ecc4f5..6d371e9a85291 100644 --- a/paddle/phi/kernels/funcs/eigen/pad.cu +++ b/paddle/phi/kernels/funcs/eigen/pad.cu @@ -61,6 +61,8 @@ struct EigenPad { INSTANTIATION(EigenPad, bool); INSTANTIATION(EigenPad, uint8_t); INSTANTIATION(EigenPad, int); +INSTANTIATION(EigenPad, int8_t); +INSTANTIATION(EigenPad, int16_t); INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, double); diff --git a/paddle/phi/kernels/funcs/eigen/scale.cc b/paddle/phi/kernels/funcs/eigen/scale.cc index 7e2d463a9fab1..b3e5246a57226 100644 --- a/paddle/phi/kernels/funcs/eigen/scale.cc +++ b/paddle/phi/kernels/funcs/eigen/scale.cc @@ -39,6 +39,7 @@ struct EigenScale { } }; +template struct EigenScale; template struct EigenScale; template struct EigenScale; template struct EigenScale; diff --git a/paddle/phi/kernels/funcs/eigen/scale.cu b/paddle/phi/kernels/funcs/eigen/scale.cu index 0474068fc40eb..ffc8118e0adae 100644 --- a/paddle/phi/kernels/funcs/eigen/scale.cu +++ b/paddle/phi/kernels/funcs/eigen/scale.cu @@ -38,6 +38,7 @@ struct EigenScale { } }; +template struct EigenScale; template struct EigenScale; template struct EigenScale; template struct EigenScale; diff --git a/paddle/phi/kernels/funcs/eigen/sign.cc b/paddle/phi/kernels/funcs/eigen/sign.cc index 450df3c764c12..e71257f3f74aa 100644 --- a/paddle/phi/kernels/funcs/eigen/sign.cc +++ b/paddle/phi/kernels/funcs/eigen/sign.cc @@ -29,6 +29,10 @@ struct EigenSign { } }; +template struct EigenSign; +template struct EigenSign; +template struct EigenSign; +template struct EigenSign; template struct EigenSign; template struct EigenSign; diff --git a/paddle/phi/kernels/funcs/eigen/sign.cu b/paddle/phi/kernels/funcs/eigen/sign.cu index b630ba7bb6c40..58a4fe36232b6 100644 --- a/paddle/phi/kernels/funcs/eigen/sign.cu +++ b/paddle/phi/kernels/funcs/eigen/sign.cu @@ -29,6 +29,10 @@ struct EigenSign { } }; +template struct EigenSign; +template struct EigenSign; +template struct EigenSign; +template struct EigenSign; template struct EigenSign; template struct EigenSign; template struct EigenSign; diff --git a/paddle/phi/kernels/funcs/eigen/slice.cu b/paddle/phi/kernels/funcs/eigen/slice.cu index ade58d0698759..5591fc076fd8f 100644 --- a/paddle/phi/kernels/funcs/eigen/slice.cu +++ b/paddle/phi/kernels/funcs/eigen/slice.cu @@ -61,6 +61,8 @@ struct EigenSlice { INSTANTIATION(EigenSlice, bool); INSTANTIATION(EigenSlice, uint8_t); INSTANTIATION(EigenSlice, int); +INSTANTIATION(EigenSlice, int8_t); +INSTANTIATION(EigenSlice, int16_t); INSTANTIATION(EigenSlice, int64_t); INSTANTIATION(EigenSlice, float); INSTANTIATION(EigenSlice, double); diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu index 481d6fea2d4ae..716d5c3979459 100644 --- a/paddle/phi/kernels/funcs/fc_functor.cu +++ b/paddle/phi/kernels/funcs/fc_functor.cu @@ -19,6 +19,12 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/fc_functor.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" +#include "paddle/phi/kernels/funcs/quant_dequant.h" +#include "paddle/phi/kernels/matmul_kernel.h" + namespace phi { namespace funcs { @@ -367,5 +373,86 @@ template class FCFunctor; template class FCFunctor; template class FCFunctor; +template +void FCInt8Functor::operator()( + const DeviceContext& context, + const int M, + const int N, + const int K, + const T* X, + const DenseTensor* w_tensor, + T* Y, + float scale_in, + std::vector scale_weights, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + const T* B, + bool relu, + bool padding_weights) { + PADDLE_ENFORCE_EQ(padding_weights, + false, + errors::PermissionDenied( + "Weight padding in fc can not be used in GPU scope.")); + const int8_t* W = w_tensor->data(); + + DenseTensor quant_x_tensor, quant_y_tensor; + quant_x_tensor.Resize(phi::make_ddim({M, K})); + quant_y_tensor.Resize(phi::make_ddim({M, N})); + context.template Alloc(&quant_x_tensor, + quant_x_tensor.numel() * sizeof(int8_t)); + context.template Alloc(&quant_y_tensor, + quant_y_tensor.numel() * sizeof(int32_t)); + LaunchQuantKernelWithVecSize(X, + quant_x_tensor.data(), + scale_in, + M, + K, + quant_round_type, + quant_max_bound, + quant_min_bound, + context.stream()); + + MatmulKernel( + context, quant_x_tensor, *w_tensor, false, false, &quant_y_tensor); + + DenseTensor scale_weights_dev; + scale_weights_dev.Resize(phi::make_ddim({N})); + context.template Alloc(&scale_weights_dev, + scale_weights_dev.numel() * sizeof(float)); + float* scale_weights_dev_ptr = scale_weights_dev.data(); + cudaMemcpyAsync(scale_weights_dev_ptr, + scale_weights.data(), + N * sizeof(float), + cudaMemcpyHostToDevice); + + phi::backends::gpu::GpuLaunchConfig config; + if (N % DequantKernelVecSize == 0) { + config = phi::backends::gpu::GetGpuLaunchConfig1D( + context, M * N, DequantKernelVecSize); + } else { + config = phi::backends::gpu::GetGpuLaunchConfig1D(context, M * N, 1); + } + LaunchDequantKernelWithScaleOfInputAndWeight(quant_y_tensor.data(), + Y, + M, + N, + context.stream(), + &config, + scale_in, + scale_weights_dev_ptr, + quant_max_bound); + + if (B == NULL) { + return; + } + + // M * N + AddReluKernel(context.stream(), M, N, Y, B, relu); +} + +template class FCInt8Functor; +template class FCInt8Functor; +template class FCInt8Functor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/fc_functor.h b/paddle/phi/kernels/funcs/fc_functor.h index 9be644e771800..2bc7ec4bb99b2 100644 --- a/paddle/phi/kernels/funcs/fc_functor.h +++ b/paddle/phi/kernels/funcs/fc_functor.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" namespace phi { namespace funcs { @@ -36,5 +37,25 @@ class FCFunctor { bool weight_pass = false); }; +template +class FCInt8Functor { + public: + void operator()(const DeviceContext& context, + const int M, + const int N, + const int K, + const T* X, + const DenseTensor* W, + T* Y, + float scale_in, + std::vector scale_weights, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + const T* B = nullptr, + bool relu = false, + bool weight_pass = false); +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/gpc.cc b/paddle/phi/kernels/funcs/gpc.cc index 47a3001b4fda2..cd02f27639208 100644 --- a/paddle/phi/kernels/funcs/gpc.cc +++ b/paddle/phi/kernels/funcs/gpc.cc @@ -145,7 +145,7 @@ static edge_node **bound_list(lmt_node **lmt, double y) { if (!*lmt) { /* Add node onto the tail end of the LMT */ gpc_malloc( - *lmt, sizeof(lmt_node), const_cast("LMT insertion")); + *lmt, sizeof(lmt_node), const_cast("LMT insertion")); // NOLINT (*lmt)->y = y; (*lmt)->first_bound = nullptr; (*lmt)->next = nullptr; @@ -154,7 +154,7 @@ static edge_node **bound_list(lmt_node **lmt, double y) { /* Insert a new LMT node before the current node */ existing_node = *lmt; gpc_malloc( - *lmt, sizeof(lmt_node), const_cast("LMT insertion")); + *lmt, sizeof(lmt_node), const_cast("LMT insertion")); // NOLINT (*lmt)->y = y; (*lmt)->first_bound = nullptr; (*lmt)->next = existing_node; @@ -173,9 +173,10 @@ static edge_node **bound_list(lmt_node **lmt, double y) { static void add_to_sbtree(int *entries, sb_tree **sbtree, double y) { if (!*sbtree) { /* Add a new tree node here */ - gpc_malloc(*sbtree, - sizeof(sb_tree), - const_cast("scanbeam tree insertion")); + gpc_malloc( + *sbtree, + sizeof(sb_tree), + const_cast("scanbeam tree insertion")); // NOLINT (*sbtree)->y = y; (*sbtree)->less = nullptr; (*sbtree)->more = nullptr; @@ -253,7 +254,7 @@ static edge_node *build_lmt(lmt_node **lmt, /* Create the entire input polygon edge table in one go */ gpc_malloc(edge_table, total_vertices * static_cast(sizeof(edge_node)), - const_cast("edge table creation")); + const_cast("edge table creation")); // NOLINT for (c = 0; c < p->num_contours; c++) { if (p->contour[c].num_vertices < 0) { @@ -412,7 +413,7 @@ static void add_intersection( if (!*it) { /* Append a new node to the tail of the list */ gpc_malloc( - *it, sizeof(it_node), const_cast("IT insertion")); + *it, sizeof(it_node), const_cast("IT insertion")); // NOLINT (*it)->ie[0] = edge0; (*it)->ie[1] = edge1; (*it)->point.x = x; @@ -423,7 +424,7 @@ static void add_intersection( /* Insert a new node mid-list */ existing_node = *it; gpc_malloc( - *it, sizeof(it_node), const_cast("IT insertion")); + *it, sizeof(it_node), const_cast("IT insertion")); // NOLINT (*it)->ie[0] = edge0; (*it)->ie[1] = edge1; (*it)->point.x = x; @@ -449,7 +450,7 @@ static void add_st_edge(st_node **st, if (!*st) { /* Append edge onto the tail end of the ST */ gpc_malloc( - *st, sizeof(st_node), const_cast("ST insertion")); + *st, sizeof(st_node), const_cast("ST insertion")); // NOLINT (*st)->edge = edge; (*st)->xb = edge->xb; (*st)->xt = edge->xt; @@ -464,7 +465,7 @@ static void add_st_edge(st_node **st, /* No intersection - insert edge here (before the ST edge) */ existing_node = *st; gpc_malloc( - *st, sizeof(st_node), const_cast("ST insertion")); + *st, sizeof(st_node), const_cast("ST insertion")); // NOLINT (*st)->edge = edge; (*st)->xb = edge->xb; (*st)->xt = edge->xt; @@ -548,7 +549,9 @@ static void add_left(polygon_node *p, double x, double y) { /* Create a new vertex node and set its fields */ gpc_malloc( - nv, sizeof(vertex_node), const_cast("vertex node creation")); + nv, + sizeof(vertex_node), + const_cast("vertex node creation")); // NOLINT nv->x = x; nv->y = y; @@ -586,7 +589,9 @@ static void add_right(polygon_node *p, double x, double y) { /* Create a new vertex node and set its fields */ gpc_malloc( - nv, sizeof(vertex_node), const_cast("vertex node creation")); + nv, + sizeof(vertex_node), + const_cast("vertex node creation")); // NOLINT nv->x = x; nv->y = y; nv->next = nullptr; @@ -631,11 +636,15 @@ static void add_local_min(polygon_node **p, existing_min = *p; gpc_malloc( - *p, sizeof(polygon_node), const_cast("polygon node creation")); + *p, + sizeof(polygon_node), + const_cast("polygon node creation")); // NOLINT /* Create a new vertex node and set its fields */ gpc_malloc( - nv, sizeof(vertex_node), const_cast("vertex node creation")); + nv, + sizeof(vertex_node), + const_cast("vertex node creation")); // NOLINT nv->x = x; nv->y = y; nv->next = nullptr; @@ -666,9 +675,10 @@ static int count_tristrips(polygon_node *tn) { void add_vertex(vertex_node **t, double x, double y) { if (!(*t)) { - gpc_malloc(*t, - sizeof(vertex_node), - const_cast("tristrip vertex creation")); + gpc_malloc( + *t, + sizeof(vertex_node), + const_cast("tristrip vertex creation")); // NOLINT (*t)->x = x; (*t)->y = y; (*t)->next = nullptr; @@ -690,9 +700,10 @@ static void new_tristrip(polygon_node **tn, double x, double y) { if (!(*tn)) { - gpc_malloc(*tn, - sizeof(polygon_node), - const_cast("tristrip node creation")); + gpc_malloc( + *tn, + sizeof(polygon_node), + const_cast("tristrip node creation")); // NOLINT (*tn)->next = nullptr; (*tn)->v[LEFT] = nullptr; (*tn)->v[RIGHT] = nullptr; @@ -712,7 +723,7 @@ static bbox *create_contour_bboxes(gpc_polygon *p) { gpc_malloc(box, p->num_contours * static_cast(sizeof(bbox)), - const_cast("Bounding box creation")); + const_cast("Bounding box creation")); // NOLINT PADDLE_ENFORCE_NOT_NULL( box, phi::errors::ResourceExhausted("Failed to malloc box memory.")); @@ -757,7 +768,7 @@ static void minimax_test(gpc_polygon *subj, gpc_polygon *clip, gpc_op op) { gpc_malloc( o_table, subj->num_contours * clip->num_contours * static_cast(sizeof(int)), - const_cast("overlap table creation")); + const_cast("overlap table creation")); // NOLINT /* Check all subject contour bounding boxes against clip boxes */ for (s = 0; s < subj->num_contours; s++) { @@ -879,7 +890,7 @@ void gpc_add_contour(gpc_polygon *p, gpc_vertex_list *new_contour, int hole) { /* Create an extended hole array */ gpc_malloc(extended_hole, (p->num_contours + 1) * static_cast(sizeof(int)), - const_cast("contour hole addition")); + const_cast("contour hole addition")); // NOLINT PADDLE_ENFORCE_NOT_NULL( extended_hole, phi::errors::ResourceExhausted("Failed to malloc extended hole memory.")); @@ -888,7 +899,7 @@ void gpc_add_contour(gpc_polygon *p, gpc_vertex_list *new_contour, int hole) { gpc_malloc( extended_contour, (p->num_contours + 1) * static_cast(sizeof(gpc_vertex_list)), - const_cast("contour addition")); + const_cast("contour addition")); // NOLINT /* Copy the old contour and hole data into the extended arrays */ for (c = 0; c < p->num_contours; c++) { @@ -903,7 +914,7 @@ void gpc_add_contour(gpc_polygon *p, gpc_vertex_list *new_contour, int hole) { gpc_malloc( extended_contour[c].vertex, new_contour->num_vertices * static_cast(sizeof(gpc_vertex)), - const_cast("contour addition")); + const_cast("contour addition")); // NOLINT for (v = 0; v < new_contour->num_vertices; v++) { extended_contour[c].vertex[v] = new_contour->vertex[v]; // NOLINT } @@ -1004,7 +1015,7 @@ void gpc_polygon_clip(gpc_op op, /* Build scanbeam table from scanbeam tree */ gpc_malloc(sbt, sbt_entries * static_cast(sizeof(double)), - const_cast("sbt creation")); + const_cast("sbt creation")); // NOLINT PADDLE_ENFORCE_NOT_NULL(sbt, phi::errors::ResourceExhausted( "Failed to malloc scanbeam table memory.")); @@ -1501,11 +1512,11 @@ void gpc_polygon_clip(gpc_op op, if (result->num_contours > 0) { gpc_malloc(result->hole, result->num_contours * static_cast(sizeof(int)), - const_cast("hole flag table creation")); + const_cast("hole flag table creation")); // NOLINT gpc_malloc( result->contour, result->num_contours * static_cast(sizeof(gpc_vertex_list)), - const_cast("contour creation")); + const_cast("contour creation")); // NOLINT c = 0; for (poly = out_poly; poly; poly = npoly) { @@ -1513,10 +1524,11 @@ void gpc_polygon_clip(gpc_op op, if (poly->active) { result->hole[c] = poly->proxy->hole; result->contour[c].num_vertices = poly->active; - gpc_malloc(result->contour[c].vertex, - result->contour[c].num_vertices * - static_cast(sizeof(gpc_vertex)), - const_cast("vertex creation")); + gpc_malloc( + result->contour[c].vertex, + result->contour[c].num_vertices * + static_cast(sizeof(gpc_vertex)), + const_cast("vertex creation")); // NOLINT v = result->contour[c].num_vertices - 1; for (vtx = poly->proxy->v[LEFT]; vtx; vtx = nv) { @@ -1651,7 +1663,7 @@ void gpc_tristrip_clip(gpc_op op, /* Build scanbeam table from scanbeam tree */ gpc_malloc(sbt, sbt_entries * static_cast(sizeof(double)), - const_cast("sbt creation")); + const_cast("sbt creation")); // NOLINT PADDLE_ENFORCE_NOT_NULL(sbt, phi::errors::ResourceExhausted( "Failed to malloc scanbeam table memory.")); @@ -2190,7 +2202,7 @@ void gpc_tristrip_clip(gpc_op op, gpc_malloc( result->strip, result->num_strips * static_cast(sizeof(gpc_vertex_list)), - const_cast("tristrip list creation")); + const_cast("tristrip list creation")); // NOLINT s = 0; for (tn = tlist; tn; tn = tnn) { @@ -2201,7 +2213,7 @@ void gpc_tristrip_clip(gpc_op op, gpc_malloc( result->strip[s].vertex, tn->active * static_cast(sizeof(gpc_vertex)), - const_cast("tristrip creation")); + const_cast("tristrip creation")); // NOLINT v = 0; if (false) { lt = tn->v[RIGHT]; @@ -2253,5 +2265,3 @@ void gpc_tristrip_clip(gpc_op op, } // namespace funcs } // namespace phi - -/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */ diff --git a/paddle/phi/kernels/funcs/quant_dequant.h b/paddle/phi/kernels/funcs/quant_dequant.h index f640dcc369bb7..c0ba1df5c6344 100644 --- a/paddle/phi/kernels/funcs/quant_dequant.h +++ b/paddle/phi/kernels/funcs/quant_dequant.h @@ -42,6 +42,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) { : xUpper); } +template +inline HOSTDEVICE T roundWithTiesAwayFromZero(T x) { + return static_cast(x > 0 ? ceil(x) : floor(x)); +} + template __forceinline__ __device__ int8_t quant_helper(const T input, const float scale, @@ -60,6 +65,25 @@ __forceinline__ __device__ int8_t quant_helper(const T input, return static_cast(quant_value); } +template +__forceinline__ __device__ int8_t +quant_helper_ties_to_even_or_away_from_zero(const T input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * static_cast(input); + + if (round_type == 0) { + quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(roundWithTiesAwayFromZero(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + return static_cast(quant_value); +} + template __global__ void QuantKernel(const T* input, char4* output, @@ -87,6 +111,102 @@ __global__ void QuantKernel(const T* input, } } +template +__global__ void QuantKernelWithVecSize(const T* input, + char4* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + char4 tmp; + tmp.x = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + tmp.y = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); + tmp.z = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); + tmp.w = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); + output[(m_id * n + n_id) >> 2] = tmp; + } +} + +template +__global__ void QuantKernelWithVecSize(const T* input, + char3* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + char3 tmp; + tmp.x = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + tmp.y = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); + tmp.z = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); + output[(m_id * n + n_id) / 3] = tmp; + } +} + +template +__global__ void QuantKernelWithVecSize(const T* input, + char2* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + char2 tmp; + tmp.x = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + tmp.y = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); + output[(m_id * n + n_id) >> 1] = tmp; + } +} + +template +__global__ void QuantKernelWithVecSize(const T* input, + char* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x); + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + char tmp; + tmp = quant_helper_ties_to_even_or_away_from_zero( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + output[m_id * n + n_id] = tmp; + } +} + template void LaunchQuantKernel(const T* input, int8_t* output, @@ -98,7 +218,7 @@ void LaunchQuantKernel(const T* input, const float min_bound, gpuStream_t stream) { // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 - dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); + dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32); dim3 block(32, 32); QuantKernel<<>>(input, @@ -111,6 +231,78 @@ void LaunchQuantKernel(const T* input, min_bound); } +template +void LaunchQuantKernelWithVecSize(const T* input, + int8_t* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound, + gpuStream_t stream) { + int vec_size = 1; + if (n % 4 == 0) { + vec_size = 4; + } else if (n % 3 == 0) { + vec_size = 3; + } else if (n % 2 == 0) { + vec_size = 2; + } + + dim3 grid(((n / vec_size) + 31) / 32, (m + 31) / 32); + dim3 block(32, 32); + + switch (vec_size) { + case 4: + QuantKernelWithVecSize<<>>( + input, + reinterpret_cast(output), + scale, + m, + n, + round_type, + max_bound, + min_bound); + break; + case 3: + QuantKernelWithVecSize<<>>( + input, + reinterpret_cast(output), + scale, + m, + n, + round_type, + max_bound, + min_bound); + break; + case 2: + QuantKernelWithVecSize<<>>( + input, + reinterpret_cast(output), + scale, + m, + n, + round_type, + max_bound, + min_bound); + break; + case 1: + QuantKernelWithVecSize<<>>( + input, + reinterpret_cast(output), + scale, + m, + n, + round_type, + max_bound, + min_bound); + break; + default: + return; + } +} + template __global__ void DequantKernel(T* output, const int32_t* input, @@ -155,4 +347,72 @@ void LaunchDequantKernel(const int32_t* input, output, input, m, n, quant_in_scale, dequant_out_scale_data); } +template +__global__ void DequantKernelWithScaleOfInputAndWeight( + T* output, + const int32_t* input, + const int m, // batch size + const int n, // hidden + const float quant_in_scale, + const float* quant_weight_scale, + float quant_max_bound) { + int numel = m * n; + int stride = blockDim.x * gridDim.x * VecSize; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int col_id = idx % n; + + phi::AlignedVector in_vec; + phi::AlignedVector out_scale_vec; + phi::AlignedVector out_vec; + + for (; idx < numel; idx += stride) { + phi::Load(input + idx, &in_vec); + phi::Load(quant_weight_scale + col_id, &out_scale_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = static_cast(static_cast(in_vec[i]) / + (quant_max_bound * quant_max_bound * + quant_in_scale * out_scale_vec[i])); + } + + phi::Store(out_vec, output + idx); + } +} + +template +void LaunchDequantKernelWithScaleOfInputAndWeight( + const int32_t* input, + T* output, + const int m, // m + const int n, // n + gpuStream_t stream, + GpuLaunchConfig* gpu_config, + const float quant_in_scale, + const float* quant_weight_scale, + float quant_max_bound) { + if (n % DequantKernelVecSize != 0) { + DequantKernelWithScaleOfInputAndWeight<<block_per_grid, + gpu_config->thread_per_block, + 0, + stream>>>(output, + input, + m, + n, + quant_in_scale, + quant_weight_scale, + quant_max_bound); + return; + } + DequantKernelWithScaleOfInputAndWeight + <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( + output, + input, + m, + n, + quant_in_scale, + quant_weight_scale, + quant_max_bound); +} + } // namespace phi diff --git a/paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc new file mode 100644 index 0000000000000..3b140091fc69c --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc @@ -0,0 +1,439 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // for memcpy +#include +#include + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" +#include "paddle/phi/kernels/funcs/jit/kernels.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" + +namespace phi { +namespace fusion { + +#define INIT_BASE_DEFINES \ + auto x_lod = x.lod(); \ + auto x_dims = x.dims(); /* T x M*/ \ + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \ + ? phi::flatten_to_2d(x_dims, 1) \ + : x_dims; \ + auto wh_dims = weight_h.dims(); /* D x 3D*/ \ + const int total_T = x_mat_dims[0]; \ + const int D3 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + const int M = x_mat_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const phi::jit::gru_attr_t attr(D, \ + phi::jit::to_kerneltype(gate_activation), \ + phi::jit::to_kerneltype(activation)); \ + phi::jit::gru_t one_step; \ + auto ComputeH1 = \ + phi::jit::KernelFuncs, phi::CPUPlace>::Cache() \ + .At(attr); \ + auto ComputeHtPart1 = phi::jit::KernelFuncs, \ + phi::CPUPlace>::Cache() \ + .At(attr); \ + auto ComputeHtPart2 = phi::jit::KernelFuncs, \ + phi::CPUPlace>::Cache() \ + .At(attr); \ + const T* x_data = x.data(); \ + const T* wx_data = weight_x.data(); \ + const T* wh_data = weight_h.data(); \ + T* xx_data = dev_ctx.template Alloc(xx) + +template +void SeqCompute(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; + const int N = static_cast(x_lod[0].size() - 1); + const T* h0_data = h0 ? h0->data() : nullptr; + const T* wh_state_data = wh_data + D * D2; + T* hidden_out_data = dev_ctx.template Alloc(hidden); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + phi::funcs::FCFunctor fc; + fc(dev_ctx, + total_T, + D3, + M, + x_data, + wx_data, + xx_data, + bias ? bias->data() : nullptr); + + int xx_offset = D3; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 3; + hidden_out_data = hidden_out_data + offset; + xx_offset = -D3; + gate_offset = -D; + } + auto move_step = [&]() { + xx_data = xx_data + xx_offset; + hidden_out_data = hidden_out_data + gate_offset; + }; + for (int i = 0; i < N; ++i) { + int bid = is_reverse ? N - 1 - i : i; + int seq_len = static_cast(x_lod[0][bid + 1] - x_lod[0][bid]); + const T* prev_hidden_data = nullptr; + int tstart = 0; + if (h0_data) { + prev_hidden_data = h0_data + bid * D; + } else { + one_step.gates = xx_data; + one_step.ht = hidden_out_data; + ComputeH1(&one_step, &attr); + prev_hidden_data = hidden_out_data; + tstart = 1; + move_step(); + } + for (int step = tstart; step < seq_len; ++step) { + // gemm prev * (Wu + Wr) + blas.GEMM(CblasNoTrans, + CblasNoTrans, + 1, + D2, + D, + static_cast(1), + prev_hidden_data, + D, + wh_data, + D2, + static_cast(1), + xx_data, + D3); + one_step.gates = xx_data; + one_step.ht_1 = prev_hidden_data; + one_step.ht = hidden_out_data; + ComputeHtPart1(&one_step, &attr); + // gemm rt * Ws + blas.GEMM(CblasNoTrans, + CblasNoTrans, + 1, + D, + D, + static_cast(1), + hidden_out_data, + D, + wh_state_data, + D, + static_cast(1), + xx_data + D2, + D3); + ComputeHtPart2(&one_step, &attr); + // save prev + prev_hidden_data = hidden_out_data; + move_step(); + } + } +} + +template +void BatchCompute(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + INIT_BASE_DEFINES; + if (x_lod[0].size() == 2) { + xx->Resize({total_T, D3}); + SeqCompute(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + return; + } + INIT_OTHER_DEFINES; + T* batched_input_data = dev_ctx.template Alloc(batched_input); + T* batched_out_data = dev_ctx.template Alloc(batched_out); + dev_ctx.template Alloc(hidden); + auto blas = phi::funcs::GetBlas(dev_ctx); + phi::funcs::LoDTensor2BatchFunctor to_batch; + + phi::funcs::FCFunctor fc; + if (M > D3) { + fc(dev_ctx, + total_T, + D3, + M, + x_data, + wx_data, + xx_data, + bias ? bias->data() : nullptr); + to_batch(dev_ctx, *xx, batched_input, true, is_reverse); + } else { + to_batch(dev_ctx, x, xx, true, is_reverse); + batched_input->set_lod(xx->lod()); + fc(dev_ctx, + total_T, + D3, + M, + xx_data, + wx_data, + batched_input_data, + bias ? bias->data() : nullptr); + } + + auto batched_lod = batched_input->lod(); + const auto& seq_order = batched_lod[2]; + const int max_bs = static_cast(seq_order.size()); + reordered_h0->Resize({max_bs, D}); + + int tstart = 0; + T* prev_hidden_data = nullptr; + if (h0) { + // reorder h0 + T* reordered_h0_data = dev_ctx.template Alloc(reordered_h0); + const T* h0_data = h0->data(); + prev_hidden_data = reordered_h0_data; + size_t sz = sizeof(T) * D; + for (int i = 0; i < max_bs; ++i) { + std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); + reordered_h0_data += D; + } + } else { + // compute without h0 + T* cur_in_data = batched_input_data; + T* cur_out_data = batched_out_data; + // W: {W_update, W_reset; W_state} + for (int i = 0; i < max_bs; ++i) { + one_step.gates = cur_in_data; + one_step.ht = cur_out_data; + ComputeH1(&one_step, &attr); + // add offset + cur_in_data += D3; + cur_out_data += D; + } + tstart = 1; + prev_hidden_data = batched_out_data; + } + // Then start from next + const T* wh_state_data = wh_data + D * D2; + const auto& batch_starts = batched_lod[0]; + const int max_seq_len = static_cast(batch_starts.size() - 1); + batched_input_data = batched_input_data + tstart * max_bs * D3; + batched_out_data = batched_out_data + tstart * max_bs * D; + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = + static_cast(batch_starts[step + 1] - batch_starts[step]); + // gemm prev * (Wu + Wr) + blas.GEMM(CblasNoTrans, + CblasNoTrans, + cur_bs, + D2, + D, + static_cast(1), + prev_hidden_data, + D, + wh_data, + D2, + static_cast(1), + batched_input_data, + D3); + + T* cur_batched_data = batched_input_data; + T* cur_out_data = batched_out_data; + T* cur_prev_hidden_data = prev_hidden_data; + for (int i = 0; i < cur_bs; ++i) { + one_step.gates = cur_batched_data; + one_step.ht_1 = cur_prev_hidden_data; + one_step.ht = cur_out_data; + ComputeHtPart1(&one_step, &attr); + + cur_batched_data += D3; + cur_prev_hidden_data += D; + cur_out_data += D; + } + + cur_batched_data = batched_input_data; + cur_out_data = batched_out_data; + blas.GEMM(CblasNoTrans, + CblasNoTrans, + cur_bs, + D, + D, + static_cast(1), + cur_out_data, + D, + wh_state_data, + D, + static_cast(1), + cur_batched_data + D2, + D3); + + cur_prev_hidden_data = prev_hidden_data; + for (int i = 0; i < cur_bs; ++i) { + one_step.gates = cur_batched_data; + one_step.ht_1 = cur_prev_hidden_data; + one_step.ht = cur_out_data; + ComputeHtPart2(&one_step, &attr); + cur_batched_data += D3; + cur_prev_hidden_data += D; + cur_out_data += D; + } + prev_hidden_data = batched_out_data; + batched_out_data = cur_out_data; + batched_input_data = cur_batched_data; + } + + phi::funcs::Batch2LoDTensorFunctor to_seq; + batched_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_out, hidden); +} + +template +void FusionGRUKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + if (use_seq) { + SeqCompute(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } else { + BatchCompute(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL( + fusion_gru, CPU, ALL_LAYOUT, phi::fusion::FusionGRUKernel, float, double) {} diff --git a/paddle/phi/kernels/fusion/cpu/fusion_seqconv_eltadd_relu_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_seqconv_eltadd_relu_kernel.cc new file mode 100644 index 0000000000000..fbe2ea8d12bc2 --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_seqconv_eltadd_relu_kernel.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // for min, max +#include + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" + +namespace phi { +namespace fusion { + +template +void FusionSeqConvEltAddReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& filter, + const DenseTensor& bias, + const int context_length, + const int context_start, + const int context_stride, + DenseTensor* out, + DenseTensor* col_mat) { + auto x_lod = x.lod(); + auto x_dims = phi::vectorize(x.dims()); + auto w_dims = phi::vectorize(filter.dims()); + PADDLE_ENFORCE_EQ( + bias.numel(), + w_dims[1], + phi::errors::InvalidArgument( + "bias size should be equal to weights feature size, but received " + "bias size is: %d, weights feature size is: %d.", + bias.numel(), + w_dims[1])); + PADDLE_ENFORCE_EQ( + x_lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support one level sequence now, but received value is: %d.", + x_lod.size())); + + const T* x_data = x.data(); + const T* w_data = filter.data(); + const T* b_data = bias.data(); + T* y_data = dev_ctx.template Alloc(out); + T* col_data = dev_ctx.template Alloc(col_mat); + + int up_pad = std::max(0, -context_start); + int down_pad = std::max(0, context_start + context_length - 1); + // im2col + int src_mat_w = static_cast(x_dims[1]); + int src_mat_w_sz = src_mat_w * sizeof(T); + int col_mat_w = static_cast(w_dims[0]); + int col_mat_w_sz = col_mat_w * sizeof(T); + for (int i = 0; i < static_cast(x_lod[0].size()) - 1; ++i) { + int st = static_cast(x_lod[0][i]); + int ed = static_cast(x_lod[0][i + 1]); + const T* src_data = x_data + st * src_mat_w; + T* dst_data = col_data + st * col_mat_w; + int seq_len = ed - st; + if (seq_len > up_pad + down_pad) { + // zero all up_pad and fill data + std::memset(dst_data, 0, up_pad * col_mat_w_sz); + dst_data = dst_data + up_pad * src_mat_w; + int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz; + for (int j = 0; j < up_pad; ++j) { + // blas.VCOPY? + std::memcpy(dst_data, src_data, copy_size); + dst_data += (col_mat_w - src_mat_w); + copy_size += src_mat_w_sz; + } + // fill data + if (context_start > 0) { + src_data += context_start * src_mat_w; + } + for (int j = 0; j < seq_len - up_pad - down_pad; ++j) { + std::memcpy(dst_data, src_data, copy_size); + dst_data += col_mat_w; + src_data += src_mat_w; + } + // zero all down_pad and fill data + std::memset(dst_data, 0, down_pad * col_mat_w_sz); + copy_size -= src_mat_w_sz; + for (int j = 0; j < down_pad; ++j) { + if (copy_size < 0) { + copy_size = 0; + } + std::memcpy(dst_data, src_data, copy_size); + dst_data += col_mat_w; + src_data += src_mat_w; + copy_size -= src_mat_w_sz; + } + } else { + std::memset(dst_data, 0, seq_len * col_mat_w_sz); + dst_data = dst_data + up_pad * src_mat_w; + int zero_sz = up_pad * src_mat_w_sz; + int cur_src_sz = seq_len * src_mat_w_sz; + for (int j = 0; j < std::min(up_pad, seq_len); ++j) { + int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); + std::memcpy(dst_data, src_data, copy_size); + dst_data += (col_mat_w - src_mat_w); + zero_sz -= src_mat_w_sz; + } + // from bottom + dst_data = col_data + ed * col_mat_w; + src_data = x_data + st * src_mat_w; + if (context_start > 0) { + src_data += context_start * src_mat_w; + } + zero_sz = down_pad * src_mat_w_sz; + for (int j = 1; j <= std::min(down_pad, seq_len); ++j) { + int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); + if (copy_size < 0) { + copy_size = 0; + } + std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T), + src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w, + copy_size); + dst_data -= col_mat_w; + zero_sz -= src_mat_w_sz; + } + } + } + phi::funcs::FCFunctor fc; + fc(dev_ctx, + x_dims[0], + w_dims[1], + w_dims[0], + col_data, + w_data, + y_data, + b_data, + true); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_seqconv_eltadd_relu, + CPU, + ALL_LAYOUT, + phi::fusion::FusionSeqConvEltAddReluKernel, + float, + double) {} diff --git a/paddle/phi/kernels/fusion/cpu/fusion_seqexpand_concat_fc_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_seqexpand_concat_fc_kernel.cc new file mode 100644 index 0000000000000..d5eb7894455f1 --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_seqexpand_concat_fc_kernel.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/backends/cpu/cpu_info.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" + +namespace phi { +namespace fusion { +template +void FusionSeqExpandConcatFCKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& fc_weight, + const paddle::optional& fc_bias, + const std::string& fc_activation, + DenseTensor* out, + DenseTensor* fc_out) { + auto* ref_in = x[0]; + auto ref_lod = ref_in->lod(); + auto in1_lod = x[1]->lod(); + auto ref_dims = ref_in->dims(); // T x M0 + auto in1_dims = x[1]->dims(); // N x M1 + auto w_dims = fc_weight.dims(); + const int N = static_cast(ref_lod[0].size() - 1); + const int total_T = static_cast(ref_dims[0]); + const int M0 = static_cast(ref_dims[1]); + const int M1 = static_cast(in1_dims[1]); + const int D = static_cast(w_dims[1]); + + // some check and fcout should be reshape here + // since infershape can not get lod info + PADDLE_ENFORCE_EQ( + ref_lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support input lod size is 1, but received value is: %d.", + ref_lod.size())); + PADDLE_ENFORCE_EQ( + in1_lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support input lod size is 1, but received value is: %d.", + in1_lod.size())); + PADDLE_ENFORCE_EQ(static_cast(in1_lod[0].size() - 1), + N, + phi::errors::InvalidArgument( + "Batch size of all inputs should be equal to %d, but " + "received value is: %d.", + N, + static_cast(in1_lod[0].size() - 1))); + PADDLE_ENFORCE_EQ( + static_cast(in1_lod[0][N]), + N, + phi::errors::InvalidArgument("Seq_length of other inputs should " + "be %d, but received value is: %d.", + N, + static_cast(in1_lod[0][N]))); + PADDLE_ENFORCE_EQ( + in1_dims[0], + N, + phi::errors::InvalidArgument( + "input height should be batch size: %d, but received value is %d.", + N, + in1_dims[0])); + for (size_t i = 2; i < x.size(); ++i) { + PADDLE_ENFORCE_EQ(x[i]->dims()[0], + N, + phi::errors::InvalidArgument( + "All other inputs height should be equal to %d, " + "but received value is: %d.", + N, + x[i]->dims()[0])); + PADDLE_ENFORCE_EQ(x[i]->lod(), + in1_lod, + phi::errors::InvalidArgument( + "All other inputs should have same lod: %d, but " + "received value is: %d.", + in1_lod, + x[i]->lod())); + } + fc_out->Resize({N, D}); + + std::function fc_act; + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { + phi::funcs::VecActivations act_functor; + fc_act = act_functor(fc_activation); + } else { + phi::funcs::VecActivations act_functor; + fc_act = act_functor(fc_activation); + } + + const T* ref_in_data = ref_in->data(); + const T* in1_data = x[1]->data(); + const T* w_data = fc_weight.data(); + T* out_data = dev_ctx.template Alloc(out); + T* fc_out_data = dev_ctx.template Alloc(fc_out); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + phi::funcs::FCFunctor fc; + fc(dev_ctx, + total_T, + D, + M0, + ref_in_data, + w_data, + out_data, + fc_bias ? fc_bias->data() : NULL); + w_data = w_data + M0 * D; + // first write on + blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); + w_data = w_data + M1 * D; + for (size_t i = 2; i < x.size(); ++i) { + // add on + const T* in_data = x[i]->data(); + const int K = static_cast(x[i]->dims()[1]); + blas.GEMM(CblasNoTrans, + CblasNoTrans, + N, + D, + K, + static_cast(1), + in_data, + K, + w_data, + D, + static_cast(1), + fc_out_data, + D); + w_data = w_data + K * D; + } + T* cur_out_data = out_data; + for (int i = 0; i < N; ++i) { + int seq_len = static_cast(ref_lod[0][i + 1] - ref_lod[0][i]); + T* src = fc_out_data + i * D; + for (int step = 0; step < seq_len; ++step) { + blas.VADD(D, cur_out_data, src, cur_out_data); + cur_out_data = cur_out_data + D; + } + } + fc_act(total_T * D, out_data, out_data); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_seqexpand_concat_fc, + CPU, + ALL_LAYOUT, + phi::fusion::FusionSeqExpandConcatFCKernel, + float, + double) {} diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py index 8dd51f0c797a4..d08187a1453a6 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py @@ -418,10 +418,10 @@ def write_decl_impl( def write_main_header(): - main_header_content = ''' + main_header_content = f''' #pragma once -#ifdef {} +#ifdef {ENABLE_MACRO} #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" @@ -542,9 +542,7 @@ def write_main_header(): #include "./cutlass_forward.h" #endif -'''.format( - ENABLE_MACRO - ) +''' path = Path(args.dst_path) / "autogen_variable" os.makedirs(path, exist_ok=True) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu index 3b9618db02db0..894903fb0fab8 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu @@ -59,11 +59,7 @@ void FusedBatchNormAddActGradKernel(const Context &dev_ctx, DenseTensor *z_grad, DenseTensor *scale_grad, DenseTensor *bias_grad) { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_bn_add_activation operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif +#if defined(PADDLE_WITH_CUDA) and CUDNN_VERSION >= 7401 bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; PADDLE_ENFORCE_EQ(is_gpu_place, true, @@ -208,6 +204,11 @@ void FusedBatchNormAddActGradKernel(const Context &dev_ctx, phi::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The fused_bn_add_activation operator is not supported on GPU " + "when CUDNN version < 7.4.1")); +#endif } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu index 7b5b4119cf970..52152476e4aca 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu @@ -59,11 +59,7 @@ void FusedBatchNormAddActKernel(const Context &dev_ctx, DenseTensor *saved_mean, DenseTensor *saved_variance, DenseTensor *reserve_space) { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_bn_add_activation operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif +#if defined(PADDLE_WITH_CUDA) and CUDNN_VERSION >= 7401 bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; PADDLE_ENFORCE_EQ(is_gpu_place, true, @@ -210,6 +206,11 @@ void FusedBatchNormAddActKernel(const Context &dev_ctx, phi::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The fused_bn_add_activation operator is not supported on GPU " + "when CUDNN version < 7.4.1")); +#endif } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu index dce2f8e5247e7..093b23728eaab 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu @@ -89,14 +89,17 @@ struct NoMaskBwFunctor { }; template -__global__ void VectorizedDropoutBackward(const size_t n, - uint64_t seed, - T* x, - T* y, - const T* out_grad, - uint64_t increment, - size_t main_offset, - Functor functor) { +__global__ void VectorizedDropoutBackward( + /* This is used to relate kernel to cudaGraph nodes*/ + unsigned int identifier, + const size_t n, + uint64_t seed, + T* x, + T* y, + const T* out_grad, + uint64_t increment, + size_t main_offset, + Functor functor) { size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); static constexpr int kCount = phi::funcs::uniform_distribution::kReturnsCount; @@ -198,25 +201,44 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx, auto functor = upscale_in_train ? NoMaskBwFunctor(1.0f - dropout_rate) : NoMaskBwFunctor(1.0f - dropout_rate, 1.0f); -#define PD_DROPOUT_KERNEL_NAME \ - VectorizedDropoutBackward> - PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed, - PD_DROPOUT_KERNEL_NAME, - grid_size, - block_size, - 0, - stream, - offset, - KERNEL_PARAMS.As(1), - KERNEL_PARAMS.As(5), - numel, - seed_data, // need save - x_grad_data, - y_grad_data, - out_grad_data, // grad - increment, // need save - main_offset, - functor); + + // we assume seed/offset is same across iterations + // seed_offset_data should preserved by cudaGraph pool + const phi::GPUContext* dev_ctx_p = &dev_ctx; + auto parameterSetter = [offset, dev_ctx_p, seed_offset]( + phi::backends::gpu::CUDAKernelParams& params) { + const auto* seed_offset_data = seed_offset.data(); + const uint64_t seed_data = static_cast(seed_offset_data[0]); + const uint64_t increment = static_cast(seed_offset_data[1]); + + params.As(2) = seed_data; + params.As(6) = increment; + VLOG(10) << "CUDA_GRAPH seed_data = " << seed_data + << ", increment = " << increment; + }; + void* functionPtr = reinterpret_cast( + &(VectorizedDropoutBackward>)); + cudaFunction_t cudaFunc; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); + phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t + cudaKernelCallback = [=](unsigned int id) { + VectorizedDropoutBackward> + <<>>( + id, + numel, + seed_data, // idx: 2 need save + x_grad_data, + y_grad_data, + out_grad_data, + increment, // idx: 6 need save + main_offset, + functor); + }; + phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch( + cudaFunc, parameterSetter, cudaKernelCallback); + + VLOG(10) << "NON_CUDA_GRAPH seed_data = " << seed_data + << ", increment = " << increment; } } diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu index 3cb1a6742543a..7c675bcbe264c 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu @@ -75,14 +75,17 @@ struct ScaleAddFuctor { }; template -__global__ void VectorizedDropoutForward(const size_t n, - uint64_t seed, - const T* src, - const T* res, - T* dst, - uint64_t increment, - size_t main_offset, - Functor functor) { +__global__ void VectorizedDropoutForward( + /* This is used to relate kernel to cudaGraph nodes*/ + unsigned int identifier, + const size_t n, + uint64_t seed, + const T* src, + const T* res, + T* dst, + uint64_t increment, + size_t main_offset, + Functor functor) { size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); static constexpr int kCount = phi::funcs::uniform_distribution::kReturnsCount; @@ -169,8 +172,9 @@ void FusedDropoutAddKernel(const Context& dev_ctx, size_t block_size = random_prop[1]; size_t offset = random_prop[2]; size_t main_offset = random_prop[3]; + auto seed_tensor_ptr = seed_tensor.get_ptr(); funcs::GetSeedDataAndIncrement(dev_ctx, - seed_tensor.get_ptr(), + seed_tensor_ptr, fix_seed, seed, offset, @@ -179,32 +183,54 @@ void FusedDropoutAddKernel(const Context& dev_ctx, seed_offset_data[0] = static_cast(seed_data); seed_offset_data[1] = static_cast(increment); - VLOG(4) << "FusedDropoutAdd seed: " << seed << ", offset: " << offset - << ", seed_data:" << seed_data; - auto dst_functor = NoMaskFwFunctor(1.0f - dropout_rate, upscale_in_train); -#define PD_DROPOUT_KERNEL_NAME \ - VectorizedDropoutForward> - PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed, - PD_DROPOUT_KERNEL_NAME, - grid_size, - block_size, - 0, - stream, - offset, - KERNEL_PARAMS.As(1), - KERNEL_PARAMS.As(5), - numel, - seed_data, // need save - x_data, - y_data, - out_data, - increment, // need save - main_offset, - dst_functor); -#undef PD_DROPOUT_KERNEL_NAME + // we assume seed/offset is same across iterations + // seed_offset_data should preserved by cudaGraph pool + const phi::GPUContext* dev_ctx_p = &dev_ctx; + void* functionPtr = reinterpret_cast( + &(VectorizedDropoutForward>)); + cudaFunction_t cudaFunc; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); + auto parameterSetter = + [numel, dev_ctx_p, seed, offset, seed_offset_data, seed_tensor_ptr]( + phi::backends::gpu::CUDAKernelParams& params) { + uint64_t seed_data, increment; + // we get the seed_data/increment from seed/offset + phi::funcs::GetSeedDataAndIncrement(*dev_ctx_p, + seed_tensor_ptr, + false, // fix_seed + seed, + offset, + &seed_data, + &increment); + params.As(2) = seed_data; + params.As(6) = increment; + VLOG(10) << "CUDA_GRAPH seed_data = " << seed_data + << ", increment = " << increment; + + seed_offset_data[0] = static_cast(seed_data); + seed_offset_data[1] = static_cast(increment); + }; + phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t + cudaKernelCallback = [=](unsigned int id) { + VectorizedDropoutForward> + <<>>(id, + numel, + seed_data, // need save + x_data, + y_data, + out_data, + increment, // need save + main_offset, + dst_functor); + }; + phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch( + cudaFunc, parameterSetter, cudaKernelCallback); + + VLOG(10) << "NON_CUDA_GRAPH seed_data = " << seed_data + << ", increment = " << increment; } else { using MT = typename phi::dtype::MPTypeTrait::Type; MT factor = static_cast(1.0f - dropout_rate); diff --git a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu index e19996d63c791..f891b94bf1eb7 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu @@ -24,8 +24,8 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" -DECLARE_bool(cudnn_deterministic); -DECLARE_bool(cudnn_exhaustive_search); +PHI_DECLARE_bool(cudnn_deterministic); +PHI_DECLARE_bool(cudnn_exhaustive_search); namespace phi { namespace fusion { diff --git a/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu b/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu index 954fbd67b96ab..b71f814fd4c98 100644 --- a/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu @@ -37,6 +37,7 @@ void TransposeFlattenConcatFusionKernel( const int flatten_axis, const int concat_axis, DenseTensor* out) { +#if defined(PADDLE_WITH_CUDA) dev_ctx.template Alloc(out, out->numel() * sizeof(T)); auto odims = out->dims(); @@ -114,6 +115,10 @@ void TransposeFlattenConcatFusionKernel( phi::dynload::cudnnDestroyTensorDescriptor(in_desc)); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnDestroyTensorDescriptor(out_desc)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The fusion_transpose_flatten_concat operator is not supported on HIP.")); +#endif } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu index 47ceb7ba1fdbc..0d65c4436b23d 100644 --- a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu +++ b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu @@ -92,6 +92,9 @@ struct Masked_multihead_attention_params { int beam_width; int cache_batch_size; int num_head; + // k_num_head and v_num_head must be equal, we unify them. + // kv_num_head = k_num_head && kv_num_head == v_num_head + int kv_num_head; int timestep; // cache_seq_length int seq_len; int max_seq_length; @@ -403,6 +406,14 @@ __global__ void masked_multihead_attention_kernel( const int bbi = bi / params.beam_width; const int hi = blockIdx.x; const int bhi = bi * params.num_head + hi; + + const int kv_num_head = params.kv_num_head; + const int num_head_per_group = params.num_head / kv_num_head; + // hi means the head index in query processed by this cuda thread. + // kv_bhi means the merged batch and head index in key and value processed by + // this cuda thread. + const int kv_bhi = bi * kv_num_head + hi / num_head_per_group; + const int bbhi = bbi * params.beam_width * params.num_head + hi; const int ti = params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1; @@ -418,8 +429,9 @@ __global__ void masked_multihead_attention_kernel( ? params.timestep : params.sequence_lengths[bi]; - // qkv [B, S=1, 3, num_head, head_dim] - int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; + // qkv [B, S=1, num_head + 2 * kv_num_head, head_dim] + // this hi means the head index in query! + int qkv_base_offset = bi * (params.num_head + 2 * kv_num_head) * Dh + hi * Dh; constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); @@ -444,7 +456,8 @@ __global__ void masked_multihead_attention_kernel( if (tid < QK_VECS_PER_WARP) { int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; - int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + int q_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + int k_bias_offset = hi / num_head_per_group * Dh + tid * QK_VEC_SIZE; Qk_vec q; zero(q); @@ -461,7 +474,10 @@ __global__ void masked_multihead_attention_kernel( // ? *reinterpret_cast(&k_base[qk_offset]) // : k; if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { - load_func.template load(k, params.num_head * Dh + qk_offset); + load_func.template load(k, + params.num_head * Dh + qk_offset - + hi * Dh + + hi / num_head_per_group * Dh); } if (params.add_qkv_bias) { @@ -472,11 +488,11 @@ __global__ void masked_multihead_attention_kernel( q_bias = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) + ? *reinterpret_cast(&q_bias_base[q_bias_offset]) : q_bias; k_bias = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) + ? *reinterpret_cast(&k_bias_base[k_bias_offset]) : k_bias; q = add(q, q_bias); @@ -582,7 +598,7 @@ __global__ void masked_multihead_attention_kernel( int co = tid / QK_VECS_IN_16B; int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; - int offset = bhi * params.max_seq_length * Dh + + int offset = kv_bhi * params.max_seq_length * Dh + co * params.max_seq_length * QK_ELTS_IN_16B + act_time_step * QK_ELTS_IN_16B + ci; if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { @@ -640,7 +656,7 @@ __global__ void masked_multihead_attention_kernel( constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; + T *k_cache = ¶ms.cache_kv[kv_bhi * params.max_seq_length * Dh + ki]; T *k_cache_batch = ¶ms.cache_kv[bbhi * params.max_seq_length * Dh + ki]; int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; @@ -737,12 +753,20 @@ __global__ void masked_multihead_attention_kernel( constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; using V_vec = typename V_vec_::Type; + // now we have got [1, seq] ,distributed in logits_smem. + // next we compute [1, seq] * [seq, head_dim] = [1, head_dim] + // THREADS_PER_VALUE means num of threads per value's head_dim. + // we split the seq dimension for more cuda threads to compute. + // vo means the first seq index processed by this cuda thread in the value. + // vi means the head_dim index processed by this cuda thread in the value. + // so this cuda thread compute [1, k] * [k, vi:vi+V_VEC_SIZE] and k starts + // from vo and increases by a step V_PER_ITER. int vo = tid / THREADS_PER_VALUE; int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; - T *v_cache = ¶ms.cache_kv[params.cache_batch_size * params.num_head * + T *v_cache = ¶ms.cache_kv[params.cache_batch_size * kv_num_head * params.max_seq_length * Dh + - bhi * params.max_seq_length * Dh + vi]; + kv_bhi * params.max_seq_length * Dh + vi]; T *v_cache_batch = ¶ms.cache_kv[params.batch_size * params.num_head * params.max_seq_length * Dh + bbhi * params.max_seq_length * Dh + vi]; @@ -755,7 +779,7 @@ __global__ void masked_multihead_attention_kernel( V_vec_acum out; zero(out); - + // V_PER_ITER is used to strip-mined the seq dimension. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; if (Dh == Dh_MAX || vi < Dh) { for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { @@ -783,15 +807,19 @@ __global__ void masked_multihead_attention_kernel( V_vec v_bias; zero(v_bias); + // now we process the last v. if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { // V_vec v = *reinterpret_cast( // ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); V_vec v; - load_func.template load( - v, 2 * params.num_head * Dh + qkv_base_offset + vi); + load_func.template load(v, + qkv_base_offset + vi - hi * Dh + + params.num_head * Dh + kv_num_head * Dh + + hi / num_head_per_group * Dh); if (params.add_qkv_bias) { v_bias = *reinterpret_cast( - ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); + ¶ms + .qkv_bias[(kv_num_head + params.num_head) * Dh + hi * Dh + vi]); v = add(v, v_bias); } @@ -806,6 +834,7 @@ __global__ void masked_multihead_attention_kernel( __syncthreads(); + // now we do the reduction in the seq dimension to get [1, head_dim]. if (Dh == Dh_MAX || vi < Dh) { #pragma unroll for (int active_groups = V_PER_ITER; active_groups >= 2; @@ -830,6 +859,7 @@ __global__ void masked_multihead_attention_kernel( } } + // write the [1, head_dim] result back to global memory. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { #ifdef MMHA_USE_FP32_ACUM_FOR_OUT // convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + @@ -1319,12 +1349,17 @@ void DispatchWithDtype(const Context &dev_ctx, const auto &x_dims = x.dims(); int bsz = x_dims[0]; int cache_bsz = cache_kv.dims()[1]; - int num_head = cache_kv.dims()[2]; int max_seq_len = cache_kv.dims()[3]; int dim_head = cache_kv.dims()[4]; int timestep = max_seq_len; float inv_sqrt_dh = 1. / sqrt(dim_head); + int k_num_head = cache_kv.dims()[2]; + int v_num_head = k_num_head; + // this num_head means query's head + int num_head = + x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head; + Masked_multihead_attention_params params; bool mask_broadcast_num_heads = true; @@ -1385,6 +1420,7 @@ void DispatchWithDtype(const Context &dev_ctx, params.batch_size = bsz; params.cache_batch_size = cache_bsz; params.num_head = num_head; + params.kv_num_head = k_num_head; params.timestep = timestep; params.seq_len = seq_len; params.max_seq_length = max_seq_len; diff --git a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h index ed311e520681f..12e64caa54b0a 100644 --- a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h +++ b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h @@ -1325,16 +1325,18 @@ inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT k.y = rotary_embedding_transform(k.y, cos.y, sin.x); } -inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT - uint2& k, // NOLINT - float4& cos, // NOLINT - float4& sin) { // NOLINT +inline __device__ void apply_rotary_embedding( + uint2& q, // NOLINT equals 4 half. + uint2& k, // NOLINT + float4& cos, // NOLINT 2 float2 cos. + float4& sin) { // NOLINT Float4_& cos_ = *reinterpret_cast(&cos); Float4_& sin_ = *reinterpret_cast(&sin); + // cos_.x is float2 q.x = rotary_embedding_transform(q.x, cos_.x, sin_.x); k.x = rotary_embedding_transform(k.x, cos_.x, sin_.x); q.y = rotary_embedding_transform(q.y, cos_.y, sin_.y); - k.y = rotary_embedding_transform(k.y, cos_.y, sin_.x); + k.y = rotary_embedding_transform(k.y, cos_.y, sin_.y); } inline __device__ void apply_rotary_embedding(uint4& q, // NOLINT diff --git a/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc b/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc index 00c3f9ba7ecdf..964263424f097 100644 --- a/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc +++ b/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc @@ -96,8 +96,9 @@ void FusedTransposeKernel(const Context& dev_ctx, errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace")); if (!(fused_squeeze2_axes.empty())) { - SetInMemDescWithSqueeze2FuseSupport( - fused_squeeze2_axes, const_cast(&x), x.mem_desc()); + SetInMemDescWithSqueeze2FuseSupport(fused_squeeze2_axes, + const_cast(&x), + x.mem_desc()); // NOLINT } if (axis.size() == 1) { @@ -158,7 +159,7 @@ void FusedTransposeKernel(const Context& dev_ctx, auto scales_md = dnnl::memory::desc( {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x); auto scales = dnnl::memory( - scales_md, dev_ctx.GetEngine(), const_cast(&scale)); + scales_md, dev_ctx.GetEngine(), const_cast(&scale)); // NOLINT args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales}); } diff --git a/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc b/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc new file mode 100644 index 0000000000000..e3fa939aad753 --- /dev/null +++ b/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc @@ -0,0 +1,638 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/expect.h" +#include "paddle/phi/core/utils/data_type.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +using phi::OneDNNContext; +using phi::funcs::CreateKey; +using phi::funcs::OneDNNGetDataType; +using phi::funcs::OneDNNMemDesc; +using phi::funcs::RNNReorderType; +using OneDNNMemoryFormat = dnnl::memory::format_tag; + +template +class GRUOneDNNHandler + : public phi::funcs::OneDNNHandlerT { + public: + GRUOneDNNHandler(const OneDNNContext& dev_ctx, + const dnnl::engine onednn_engine, + phi::Place cpu_place UNUSED, + const phi::DenseTensor* input, + const phi::DenseTensor* weight_h, + const phi::DenseTensor* h0, + const bool is_reverse, + const float scale_data, + const float shift_data, + const std::string& gate_activation, + const std::string& activation, + const std::vector& scale_weights, + const int64_t N, + const int64_t Ti, + const int64_t IC, + const int64_t OC) + : phi::funcs::OneDNNHandlerT( + dev_ctx, + dev_ctx.GetEngine(), + cpu_place, + CreateKey(dev_ctx, + dev_ctx.GetInputsName("X")[0] + + dev_ctx.GetInputsName("WeightH")[0], + OneDNNGetDataType(), + Ti)), + N(N), + Ti(Ti), + IC(IC), + OC(OC), + G(3) { + std::string unique_name = + dev_ctx.GetInputsName("X")[0] + dev_ctx.GetInputsName("WeightH")[0]; + // Create memory key without Ti because weights, bias and h0 memories + // do not depend on Ti size but primitive and input/output memory do + memory_key_ = phi::funcs::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType())); + // Is it int8 kernel + const bool is_INT8 = std::is_same::value; + if (is_INT8) { + const int weights_scale_mask = + 0 + + (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo` + + + (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo` + + attr_.set_rnn_data_qparams(scale_data, shift_data); + attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights); + } + + if (unlikely(!this->isCached())) { + // oneDNN kernel has hardcoded activation functions + PADDLE_ENFORCE_EQ( + gate_activation, + "sigmoid", + phi::errors::Unimplemented( + "oneDNN fusion_gru supports only sigmoid as a gate activation.")); + PADDLE_ENFORCE_EQ( + activation, + "tanh", + phi::errors::Unimplemented( + "oneDNN fusion_gru supports only tanh as an activation.")); + + // Weights for int8 kernel are of a type s8 + const auto weights_dt = + is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType(); + + // oneDNN RNN dimensions + const int64_t D = 1; // Directions + const int64_t L = 1; // Layers (PP supports only 1 stacked layer) + const int64_t G = 3; // Number of Gates, 3 for GRU + + // Create memory descriptors + auto input_md = OneDNNMemDesc( + {Ti, N, IC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); + auto weight_x_md = + OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any); + auto weight_h_md = + OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any); + auto bias_md = OneDNNMemDesc( + {L, D, G, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldgo); + auto hidden_md = OneDNNMemDesc( + {Ti, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); + auto h0_md = OneDNNMemDesc( + {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc); + + // Create GRU oneDNN primitive + const auto direction = + is_reverse ? dnnl::rnn_direction::unidirectional_right2left + : dnnl::rnn_direction::unidirectional_left2right; + + this->AcquireForwardPrimitiveDescriptor( + this->attr_, + dnnl::prop_kind::forward_inference, + direction, + input_md, + h0_md, + weight_x_md, + weight_h_md, + bias_md, + hidden_md, + dnnl::memory::desc()); + } + } + + bool is_NTC() { return this->is_NTC(this->fwd_pd_->dst_desc()); } + + bool is_NTC(const dnnl::memory::desc& md) { + auto ntc_md = dnnl::memory::desc( + md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::ntc); + return md == ntc_md; + } + + void reorderRNNdata(void* input_data, + void* output_data, + std::vector lod, + const bool is_reverse, + RNNReorderType reorder_type) { + switch (reorder_type) { + // Reorder input memory [WORDS, C] + LoD -> [N, T, C] + case RNNReorderType::PP_NTC: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * IC; + const auto offset = is_reverse ? (Ti * IC - num_elements) : 0; + memcpy(output_data_iter + n * Ti * IC + offset, + input_data_iter, + sizeof(T) * num_elements); + input_data_iter += num_elements; + } + } break; + // Reorder input memory [WORDS, C] + LoD -> [T, N, C] + case RNNReorderType::PP_TNC: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]); + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter + (t + offset) * N * IC + n * IC, + input_data_iter, + sizeof(T) * IC); + input_data_iter += IC; + } + } + } break; + // Reorder output values to PP format [N, T, C] -> [WORDS, C] + case RNNReorderType::NTC_PP: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * OC; + const auto offset = is_reverse ? (Ti * OC - num_elements) : 0; + memcpy(output_data_iter, + input_data_iter + n * Ti * OC + offset, + sizeof(T_out) * num_elements); + output_data_iter += num_elements; + } + } break; + // Reorder output values to PP format [T, N, C] -> [WORDS, C] + case RNNReorderType::TNC_PP: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = lod[n + 1] - lod[n]; + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter, + input_data_iter + (t + offset) * N * OC + n * OC, + sizeof(T_out) * OC); + output_data_iter += OC; + } + } + } break; + } + } + + std::shared_ptr AcquireInputMemoryWithReorder( + const phi::DenseTensor* input, const bool is_reverse) { + const auto name = this->key_ + "@input_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->src_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + + const auto& input_lod = input->lod()[0]; + auto* x_data = phi::funcs::to_void_cast(input->data()); + + auto* x_onednn_data = memory_p->get_data_handle(); + memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); + + if (is_NTC(this->fwd_pd_->src_desc())) { + reorderRNNdata( + x_data, x_onednn_data, input_lod, is_reverse, RNNReorderType::PP_NTC); + } else { + reorderRNNdata( + x_data, x_onednn_data, input_lod, is_reverse, RNNReorderType::PP_TNC); + } + return memory_p; + } + + std::shared_ptr AcquireOutputMemory() { + const auto name = this->key_ + "@output_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->dst_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + return memory_p; + } + + // H0 is for now persistable + template + std::shared_ptr AcquireH0Memory(const phi::DenseTensor* h0) { + const std::string h0_key = memory_key_ + "@h0"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(h0_key)); + + if (!memory_p) { + auto user_h0_memory = dnnl::memory(); + if (h0) { + user_h0_memory = dnnl::memory( + {{1, 1, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc}, + this->engine_, + phi::funcs::to_void_cast(h0->data())); + } else { + user_h0_memory = dnnl::memory( + {{1, 1, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc}, + this->engine_); + memset(user_h0_memory.get_data_handle(), 0, sizeof(U) * N * OC); + } + memory_p = std::make_shared(this->fwd_pd_->src_iter_desc(), + this->engine_); + + auto& astream = phi::OneDNNContext::tls().get_stream(); + dnnl::reorder(user_h0_memory, *memory_p, attr_) + .execute(astream, user_h0_memory, *memory_p); + + this->dev_ctx_.SetBlob(h0_key, memory_p); + } + return memory_p; + } + + template + std::shared_ptr AcquireWeightXMemory( + const phi::DenseTensor* weight_x, const bool origin_mode) { + const std::string wx_key = this->memory_key_ + "@weight_x"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_x_data, + weight_x->data(), + sizeof(U) * this->IC * this->G * this->OC); + + if (origin_mode == false) { + for (int64_t i = 0; i < this->IC; ++i) { + for (int64_t j = 0; j < this->OC; ++j) { + U minus_one(-1.0f); + weight_x_data[j] = minus_one * weight_x_data[j]; + } + weight_x_data += 3 * this->OC; + } + } + + memory_p = std::make_shared( + this->fwd_pd_->weights_layer_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wx_key, memory_p); + } + return memory_p; + } + + template + std::shared_ptr AcquireWeightHMemory( + const phi::DenseTensor* weight_h, const bool origin_mode) { + const std::string wh_key = this->memory_key_ + "@weight_h"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to + // oneDNN format [OC, 3OC] + auto* weight_h_data = reinterpret_cast(user_memory.get_data_handle()); + auto* user_weight_h_data = weight_h->data(); + + auto src1_iter = user_weight_h_data; + auto src2_iter = user_weight_h_data + 2 * this->OC * this->OC; + + for (int64_t c = 0; c < this->OC; ++c) { + memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(U)); + memcpy(weight_h_data + 2 * this->OC, src2_iter, this->OC * sizeof(U)); + + src1_iter += 2 * this->OC; + src2_iter += this->OC; + weight_h_data += 3 * this->OC; + } + + weight_h_data = reinterpret_cast(user_memory.get_data_handle()); + + if (origin_mode == false) { + for (int64_t i = 0; i < this->OC; ++i) { + for (int64_t j = 0; j < this->OC; ++j) { + U minus_one(-1.0f); + weight_h_data[j] = minus_one * weight_h_data[j]; + } + weight_h_data += 3 * this->OC; + } + } + + memory_p = std::make_shared( + this->fwd_pd_->weights_iter_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wh_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireBiasMemory(const phi::DenseTensor* bias, + const bool origin_mode) { + const std::string bias_key = this->memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->bias_desc(), + this->engine_); + auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); + if (bias) { + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); + } else { + // oneDNN always need bias memory, if it's not provided in PP, let + // oneDNN allocate memory and set it to 0 + memset(bias_data, 0, sizeof(float) * this->G * this->OC); + } + + if (origin_mode == false && bias) { + for (int64_t i = 0; i < this->OC; ++i) { + bias_data[i] *= -1; + } + } + this->dev_ctx_.SetBlob(bias_key, memory_p); + } + return memory_p; + } + + protected: + // RNN dimensions + // N - Batch Size + // Ti - Max sentence length + // IC - Input Channels + // OC - Output Channels + // G - Number of gates + const int64_t N, Ti, IC, OC, G; + + // Memory size of weights, bias and h0 does not depend + // on Ti size, thus we need another key to cache them + std::string memory_key_; + dnnl::primitive_attr attr_; +}; + +template +void RunKernel(const phi::OneDNNContext& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto x_dims = x.dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? phi::flatten_to_2d(x_dims, 1) + : x_dims; + + // Get tensor dimensions + const auto x_mat_dims_vec = phi::vectorize(x_mat_dims); + const auto weight_h_dims = phi::vectorize(weight_h.dims()); + const auto& input_lod = x.lod()[0]; + + // Calculate RNN dimensions + const int64_t N = input_lod.size() - 1; // Number of sentences (batches) + const int64_t Ti = // Max length of the sentence in a batch + [&input_lod]() { + size_t res = 0; + for (size_t i = 0; i < (input_lod.size() - 1); ++i) { + res = std::max(res, input_lod[i + 1] - input_lod[i]); + } + return res; + }(); + const int64_t IC = x_mat_dims_vec[1]; // Input channels + const int64_t OC = weight_h_dims[0]; // Output channels + + GRUOneDNNHandler handler(dev_ctx, + onednn_engine, + dev_ctx.GetPlace(), + &x, + &weight_h, + h0.get_ptr(), + is_reverse, + scale_data, + shift_data, + gate_activation, + activation, + scale_weights, + N, + Ti, + IC, + OC); + auto input_memory_p = handler.AcquireInputMemoryWithReorder(&x, is_reverse); + + std::shared_ptr h0_memory_p, weight_h_memory_p, + weight_x_memory_p; + + if (phi::TransToProtoVarType(weight_h.dtype()) == phi::ProtoDataType::FP32) { + h0_memory_p = handler.template AcquireH0Memory(h0.get_ptr()); + weight_x_memory_p = + handler.template AcquireWeightXMemory(&weight_x, origin_mode); + weight_h_memory_p = + handler.template AcquireWeightHMemory(&weight_h, origin_mode); + } else if (phi::TransToProtoVarType(weight_h.dtype()) == + phi::ProtoDataType::BF16) { + h0_memory_p = + handler.template AcquireH0Memory(h0.get_ptr()); + weight_x_memory_p = + handler.template AcquireWeightXMemory( + &weight_x, origin_mode); + weight_h_memory_p = + handler.template AcquireWeightHMemory( + &weight_h, origin_mode); + } else { + h0_memory_p = handler.template AcquireH0Memory(h0.get_ptr()); + weight_x_memory_p = + handler.template AcquireWeightXMemory(&weight_x, origin_mode); + weight_h_memory_p = + handler.template AcquireWeightHMemory(&weight_h, origin_mode); + } + + auto bias_memory_p = handler.AcquireBiasMemory(bias.get_ptr(), origin_mode); + auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); + + std::unordered_map gru_args = { + {DNNL_ARG_SRC_LAYER, *input_memory_p}, + {DNNL_ARG_SRC_ITER, *h0_memory_p}, + {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, + {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, + {DNNL_ARG_BIAS, *bias_memory_p}, + {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; + + auto gru_forward_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + gru_forward_p->execute(astream, gru_args); + astream.wait(); + + auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); + auto* hidden_tmp_data = dev_ctx.template Alloc(hidden); + auto* hidden_data = phi::funcs::to_void_cast(hidden_tmp_data); + if (handler.is_NTC()) { + handler.reorderRNNdata(hidden_onednn_data, + hidden_data, + input_lod, + is_reverse, + RNNReorderType::NTC_PP); + } else { + handler.reorderRNNdata(hidden_onednn_data, + hidden_data, + input_lod, + is_reverse, + RNNReorderType::TNC_PP); + } +} + +template +void FusionGRUKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + const bool is_bf16 = std::is_same::value; + // BF16 does not support force output + if (!is_bf16 && force_fp32_output) { // NOLINT + RunKernel(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } else { + RunKernel(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_gru, + OneDNN, + ONEDNN, + phi::fusion::FusionGRUKernel, + float, + phi::dtype::bfloat16, + uint8_t) {} diff --git a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc index 43caa13698b48..6ba3d84b5eb0b 100644 --- a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "glog/logging.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" namespace phi { namespace fusion { @@ -32,6 +35,8 @@ void Conv2dXPUKernelImpl(const Context& ctx, const paddle::optional& bias, const paddle::optional& branch, const paddle::optional& branch_max, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -66,14 +71,19 @@ void Conv2dXPUKernelImpl(const Context& ctx, int out_c = static_cast(filter_dims[0]); int win_h = static_cast(filter_dims[2]); int win_w = static_cast(filter_dims[3]); - auto* input_data = reinterpret_cast(x.data()); const float* input_max_data = x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data(); auto* filter_data = reinterpret_cast(filter.data()); auto* filter_max_data = filter_max.data(); + auto* scale_max_data = scale_max.get_ptr() == nullptr + ? nullptr + : scale_max.get_ptr()->data(); const XPUTypeOut* branch_data = nullptr; + const float* branch_max_data = branch_max.get_ptr() == nullptr + ? nullptr + : branch_max.get_ptr()->data(); auto* branch_tensor = branch.get_ptr(); xpu::ctx_guard RAII_GUARD(ctx.x_context()); if (branch_tensor != nullptr) { @@ -92,14 +102,15 @@ void Conv2dXPUKernelImpl(const Context& ctx, branch_data = branch_data_temp; } } - const float* branch_max_data = branch_max.get_ptr() == nullptr - ? nullptr - : branch_max.get_ptr()->data(); + const float* bias_data = bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); auto* out_data = reinterpret_cast(ctx.template Alloc(out)); auto* out_max_data = ctx.template Alloc(out_max); + out_max_data = out_max_in.get_ptr() != nullptr + ? const_cast(out_max_in.get_ptr()->data()) + : out_max_data; xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_param; @@ -131,7 +142,7 @@ void Conv2dXPUKernelImpl(const Context& ctx, /* const TY* branch */ branch_data, /* const baidu::xpu::api::Activation_t& act */ act, /* const float* branch_maxptr */ branch_max_data, - /* const float* scale */ nullptr); + /* const float* scale */ scale_max_data); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu"); } @@ -145,6 +156,8 @@ void Conv2dXPUKernelImpl(const Context& ctx, bias, \ branch, \ branch_max, \ + scale_max, \ + out_max_in, \ paddings, \ dilations, \ strides, \ @@ -164,6 +177,8 @@ void Conv2dXPUKernel(const Context& ctx, const paddle::optional& bias, const paddle::optional& branch, const paddle::optional& branch_max, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -174,14 +189,118 @@ void Conv2dXPUKernel(const Context& ctx, DataType out_dtype, DenseTensor* out, DenseTensor* out_max) { - if (out_dtype == DataType::FLOAT32) { - CONV2D_XPU_KERNEL_IMPL(T, int16_t, float, int16_t); - } else if (out_dtype == DataType::FLOAT16) { - CONV2D_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t); - } else { - PADDLE_THROW(phi::errors::Unimplemented("Not support out_dtype is %s.", - DataTypeToString(out_dtype))); + // Dont use template T param + VLOG(4) << "Conv kernel type: " << x.dtype() << " ," << filter.dtype() << " ," + << out_dtype; + if (x.dtype() == DataType::FLOAT32) { + // float32/float16 kernel + if (filter.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(float, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL(float, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else if (filter.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(float, int8_t, float, int8_t); + } else if (out_dtype == DataType::INT8) { + CONV2D_XPU_KERNEL_IMPL(float, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + return; } + + if (x.dtype() == DataType::FLOAT16) { + // float16 kernel + if (filter.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(phi::dtype::float16, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL( + phi::dtype::float16, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else if (filter.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL( + phi::dtype::float16, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + CONV2D_XPU_KERNEL_IMPL(phi::dtype::float16, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + if (x.dtype() == DataType::INT8) { + if (filter.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(int8_t, int8_t, float, int8_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL(int8_t, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + CONV2D_XPU_KERNEL_IMPL(int8_t, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); } } // namespace fusion @@ -192,4 +311,5 @@ PD_REGISTER_KERNEL(conv2d_xpu, ALL_LAYOUT, phi::fusion::Conv2dXPUKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int8_t) {} diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc index 6a6721194e9a8..d6153eff096cb 100644 --- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "glog/logging.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -29,6 +30,8 @@ void FcXPUKernelImpl(const Context& ctx, const DenseTensor& w, const DenseTensor& w_max, const paddle::optional& bias, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -53,7 +56,13 @@ void FcXPUKernelImpl(const Context& ctx, bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + auto* scale_max_data = scale_max.get_ptr() == nullptr + ? nullptr + : scale_max.get_ptr()->data(); auto* out_max_data = ctx.template Alloc(out_max); + out_max_data = out_max_in.get_ptr() != nullptr + ? const_cast(out_max_in.get_ptr()->data()) + : out_max_data; xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_alpha; @@ -80,7 +89,9 @@ void FcXPUKernelImpl(const Context& ctx, alpha, // alpha beta, // beta bias_data, // bias - act); + act, // act + scale_max_data); // scale + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu"); } @@ -92,6 +103,8 @@ void FcXPUKernelImpl(const Context& ctx, w, \ w_max, \ bias, \ + scale_max, \ + out_max_in, \ in_num_col_dims, \ transpose_x, \ alpha, \ @@ -108,6 +121,8 @@ void FcXPUKernel(const Context& ctx, const DenseTensor& w, const DenseTensor& w_max, const paddle::optional& bias, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -117,14 +132,119 @@ void FcXPUKernel(const Context& ctx, DataType out_dtype, DenseTensor* out, DenseTensor* out_max) { - if (out_dtype == DataType::FLOAT32) { - FC_XPU_KERNEL_IMPL(T, int16_t, float, int16_t); - } else if (out_dtype == DataType::FLOAT16) { - FC_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t); - } else { - PADDLE_THROW(phi::errors::Unimplemented("Not support out_dtype is %s.", - DataTypeToString(out_dtype))); + // Dont use template T param + VLOG(4) << "Fc kernel type: " << x.dtype() << " ," << w.dtype() << " ," + << out_dtype; + if (x.dtype() == DataType::FLOAT32) { + // float32/float16 kernel + if (w.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(float, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(float, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else if (w.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(float, int8_t, float, int8_t); + } else if (out_dtype == DataType::INT8) { + FC_XPU_KERNEL_IMPL(float, int8_t, int8_t, int8_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(float, int8_t, dtype::float16, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + if (x.dtype() == DataType::FLOAT16) { + // float16 kernel + if (w.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(phi::dtype::float16, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL( + phi::dtype::float16, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else if (w.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(phi::dtype::float16, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + FC_XPU_KERNEL_IMPL(phi::dtype::float16, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + return; } + + if (x.dtype() == DataType::INT8) { + if (w.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(int8_t, int8_t, float, int8_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(int8_t, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + FC_XPU_KERNEL_IMPL(int8_t, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); } } // namespace fusion @@ -135,4 +255,5 @@ PD_REGISTER_KERNEL(fc_xpu, ALL_LAYOUT, phi::fusion::FcXPUKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int8_t) {} diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index c67864bc13f57..2a1c6759bbc8b 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -495,7 +495,8 @@ PD_REGISTER_KERNEL(cos_triple_grad, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softsign_grad, + SoftsignGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_double_grad, SigmoidDoubleGradKernel) diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 6eeba717ece0d..34bbbfbd11859 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -292,7 +292,7 @@ PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel) -PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel) diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index dc75e1b8da122..3c793e106f049 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -34,11 +34,11 @@ __global__ void Range(T start, T step, int64_t size, OUT_TYPE* out) { } template -void ArangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, - DenseTensor* out) { +void ArangeTensorKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out) { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType start_value = static_cast(GetValue(dev_ctx, start)); @@ -80,16 +80,30 @@ void ArangeNullaryKernel(const Context& dev_ctx, Range<<>>(start_value, step_value, size, out_data); } +template +void ArangeKernel(const Context& dev_ctx, + const Scalar& start, + const Scalar& end, + const Scalar& step, + DenseTensor* out) { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType start_value = start.to(); + MPType end_value = end.to(); + MPType step_value = step.to(); + ArangeNullaryKernel( + dev_ctx, start_value, end_value, step_value, out); +} + template decltype(ArangeNullaryKernel) ArangeNullaryKernel; template decltype(ArangeNullaryKernel) ArangeNullaryKernel; } // namespace phi -PD_REGISTER_KERNEL(arange, +PD_REGISTER_KERNEL(arange_tensor, GPU, ALL_LAYOUT, - phi::ArangeKernel, + phi::ArangeTensorKernel, float, double, int64_t, @@ -100,3 +114,14 @@ PD_REGISTER_KERNEL(arange, kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); } + +PD_REGISTER_KERNEL(arange, + GPU, + ALL_LAYOUT, + phi::ArangeKernel, + float, + double, + int64_t, + int, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 78c3723ceedcb..c3c353859728b 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -22,6 +22,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/batch_norm_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/norm_utils.cu.h" @@ -487,8 +488,8 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData( template void BatchNormGradFunctor(const Context &ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -549,23 +550,41 @@ void BatchNormGradFunctor(const Context &ctx, ctx.template Alloc>(d_bias); } + auto *Scale = scale.get_ptr(); + auto *Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(ctx, {C}, static_cast(0)); + } + PADDLE_ENFORCE_EQ( - scale.dims().size(), + new_scale.dims().size(), 1UL, phi::errors::InvalidArgument( "The size of scale's dimensions must equal to 1. But received: " "the size of scale's dimensions is [%d], the dimensions of scale " "is [%s].", - scale.dims().size(), - scale.dims())); + new_scale.dims().size(), + new_scale.dims())); PADDLE_ENFORCE_EQ( - scale.dims()[0], + new_scale.dims()[0], C, phi::errors::InvalidArgument( "The first dimension of scale must equal to Channels[%d]. But " "received: the first dimension of scale is [%d]", C, - scale.dims()[0])); + new_scale.dims()[0])); auto dtype = phi::backends::gpu::CudnnDataType::type; #ifdef PADDLE_WITH_HIP @@ -713,8 +732,8 @@ void BatchNormGradFunctor(const Context &ctx, if (is_inplace) { inplace_functor(compute_format, transformed_x.data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), saved_mean_data, saved_var_data, epsilon, @@ -735,7 +754,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -750,7 +769,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -880,7 +899,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), dscale, dbias, mean_ptr, @@ -897,7 +916,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -912,7 +931,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -969,7 +988,8 @@ void BatchNormGradFunctor(const Context &ctx, /*dxDesc=*/data_desc_, /*dxData=*/ctx.template Alloc(&transformed_d_x), /*dBnScaleBiasDesc=*/bn_param_desc_, - /*bnScaleData=*/scale.template data>(), + /*bnScaleData=*/ + new_scale.template data>(), /*bnBiasData=*/nullptr, /*dBnScaleData=*/ ctx.template Alloc>(d_scale), @@ -1000,7 +1020,7 @@ void BatchNormGradFunctor(const Context &ctx, data_desc_, ctx.template Alloc(&transformed_d_x), bn_param_desc_, - scale.template data>(), + new_scale.template data>(), ctx.template Alloc>(d_scale), ctx.template Alloc>(d_bias), epsilon, @@ -1023,7 +1043,7 @@ void BatchNormGradFunctor(const Context &ctx, BNBackwardData <<>>( d_y->data(), - scale.data>(), + new_scale.data>(), saved_mean_data, x.data(), saved_var_data, @@ -1051,7 +1071,7 @@ void BatchNormGradFunctor(const Context &ctx, BNBackwardData <<>>( d_y->data(), - scale.data>(), + new_scale.data>(), saved_mean_data, x.data(), saved_var_data, @@ -1080,7 +1100,7 @@ void BatchNormGradFunctor(const Context &ctx, BNBackwardData <<>>( d_y->data(), - scale.data>(), + new_scale.data>(), saved_mean_data, x.data(), saved_var_data, @@ -1134,8 +1154,8 @@ void BatchNormGradFunctor(const Context &ctx, auto px = x; inplace_functor(data_layout, ctx.template Alloc(&px), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), running_mean_data, running_var_data, epsilon, @@ -1152,14 +1172,15 @@ void BatchNormGradFunctor(const Context &ctx, if (data_layout == DataLayout::kNHWC) { if (d_x) { KeBNBackwardData - <<>>(d_y->data(), - scale.data>(), - running_var_data, - epsilon, - C, - H * W, - num, - d_x->data()); + <<>>( + d_y->data(), + new_scale.data>(), + running_var_data, + epsilon, + C, + H * W, + num, + d_x->data()); } if (d_scale && d_bias) { KeBNBackwardScaleBias @@ -1178,14 +1199,15 @@ void BatchNormGradFunctor(const Context &ctx, } else { if (d_x) { KeBNBackwardData - <<>>(d_y->data(), - scale.data>(), - running_var_data, - epsilon, - C, - H * W, - num, - d_x->data()); + <<>>( + d_y->data(), + new_scale.data>(), + running_var_data, + epsilon, + C, + H * W, + num, + d_x->data()); } if (d_scale && d_bias) { KeBNBackwardScaleBias @@ -1205,14 +1227,15 @@ void BatchNormGradFunctor(const Context &ctx, } else { if (d_x) { KeBNBackwardData - <<>>(d_y->data(), - scale.data>(), - running_var_data, - epsilon, - C, - H * W, - num, - d_x->data()); + <<>>( + d_y->data(), + new_scale.data>(), + running_var_data, + epsilon, + C, + H * W, + num, + d_x->data()); } if (d_scale && d_bias) { dim3 block; @@ -1262,8 +1285,8 @@ void BatchNormGradFunctor(const Context &ctx, template void BatchNormGradKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -1305,7 +1328,7 @@ template void BatchNormDoubleGradKernel( const Context &ctx, const DenseTensor &x, - const DenseTensor &scale, + const paddle::optional &scale, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -1338,10 +1361,20 @@ void BatchNormDoubleGradKernel( running_mean = mean.get_ptr(); running_variance = variance.get_ptr(); } + const auto &x_dims = x.dims(); + int N, C, H, W, D; + phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); + auto *Scale = scale.get_ptr(); + phi::DenseTensor new_scale; + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(ctx, {C}, static_cast(1)); + } phi::funcs::NormDoubleGradFunctor(ctx, data_layout, &x, - &scale, + &new_scale, &y_grad, &saved_mean, &saved_variance, diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 3b73935699bab..20aa02a5f2485 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -29,6 +29,7 @@ namespace cub = hipcub; #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/batch_norm_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/norm_utils.cu.h" @@ -515,8 +516,8 @@ void BatchNormKernel(const Context &ctx, const DenseTensor &x, const DenseTensor &mean, const DenseTensor &variance, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, bool is_test, float momentum, float epsilon_f, @@ -551,6 +552,24 @@ void BatchNormKernel(const Context &ctx, auto dtype = phi::backends::gpu::CudnnDataType::type; + auto *Scale = scale.get_ptr(); + auto *Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(ctx, {C}, static_cast(0)); + } + #ifdef PADDLE_WITH_HIP auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW; @@ -722,8 +741,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -735,8 +754,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -779,8 +798,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -803,8 +822,8 @@ void BatchNormKernel(const Context &ctx, est_mean->template data>(), // est_var->template data>(), inv_var_ptr, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -816,8 +835,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -838,8 +857,8 @@ void BatchNormKernel(const Context &ctx, data_desc_, ctx.template Alloc(&transformed_y), bn_param_desc_, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), est_mean->template data>(), est_var->template data>(), epsilon)); @@ -884,8 +903,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -900,8 +919,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1002,8 +1021,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DCompStat <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1021,8 +1040,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DWriteRes<<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1063,8 +1082,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DChannelLastCompStat <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1083,8 +1102,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DChannelLastWriteRes <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1155,8 +1174,8 @@ void BatchNormKernel(const Context &ctx, data_desc_, transformed_y.template data(), bn_param_desc_, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), this_factor, ctx.template Alloc>(mean_out), ctx.template Alloc>(variance_out), @@ -1180,8 +1199,8 @@ void BatchNormKernel(const Context &ctx, data_desc_, ctx.template Alloc(&transformed_y), bn_param_desc_, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), this_factor, ctx.template Alloc>(mean_out), ctx.template Alloc>(variance_out), diff --git a/paddle/phi/kernels/gpu/contiguous_kernel.cu b/paddle/phi/kernels/gpu/contiguous_kernel.cu index 357e104afb01c..49b253effd945 100644 --- a/paddle/phi/kernels/gpu/contiguous_kernel.cu +++ b/paddle/phi/kernels/gpu/contiguous_kernel.cu @@ -31,12 +31,12 @@ __global__ void ContiguousCaseZeroFunc( blockDim.z * blockDim.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; + int64_t coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; #pragma unroll for (int dim = N - 1; dim >= 0; --dim) { diff --git a/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu index a89373c607f7d..f8dc67f5bafe8 100644 --- a/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu @@ -28,7 +28,7 @@ void CummaxGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -36,11 +36,11 @@ void CummaxGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + + if (dtype == DataType::INT32) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } @@ -52,7 +52,7 @@ void CumminGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -60,11 +60,11 @@ void CumminGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + + if (dtype == DataType::INT32) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } diff --git a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu index bf836af72c58f..49903bde6ff99 100644 --- a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu @@ -312,17 +312,16 @@ template void CummaxKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); T init = std::is_floating_point::value ? (-1 * std::numeric_limits::infinity()) : std::numeric_limits::lowest(); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); } @@ -332,16 +331,15 @@ template void CumminKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); T init = std::is_floating_point::value ? std::numeric_limits::infinity() : std::numeric_limits::max(); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); } diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 1f1453a0c6408..3261243c986c0 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -324,8 +324,12 @@ PD_REGISTER_KERNEL(divide_grad, phi::dtype::float16, phi::dtype::bfloat16, double, + int8_t, + uint8_t, + int16_t, int, int64_t, + bool, phi::dtype::complex, phi::dtype::complex) {} @@ -339,6 +343,7 @@ PD_REGISTER_KERNEL(divide_double_grad, double, int, int64_t, + bool, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/expand_grad_kernel.cu b/paddle/phi/kernels/gpu/expand_grad_kernel.cu index 23206f752afde..224e435e58c85 100644 --- a/paddle/phi/kernels/gpu/expand_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_grad_kernel.cu @@ -47,7 +47,13 @@ PD_REGISTER_KERNEL(expand_grad, phi::ExpandGradKernel, float, double, + int, + int64_t, + bool, + int16_t, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, - int, - int64_t) {} + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/expand_kernel.cu b/paddle/phi/kernels/gpu/expand_kernel.cu index 456aa9b3c5a34..dc632ce4d4e63 100644 --- a/paddle/phi/kernels/gpu/expand_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_kernel.cu @@ -84,8 +84,13 @@ PD_REGISTER_KERNEL(expand, phi::ExpandKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, int, int64_t, - bool) {} + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu index f2be0f073a87d..5bb59357bc976 100644 --- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu @@ -24,14 +24,14 @@ namespace phi { template void FrobeniusNormKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { - reduce_all = recompute_reduce_all(x, dims, reduce_all); + reduce_all = recompute_reduce_all(x, dims.GetData(), reduce_all); auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); SqrtKernel(dev_ctx, *out, out); } diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index 8829d32596be1..bd1d7db96cfec 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -44,6 +44,7 @@ void FullKernel(const Context& dev_ctx, out->Resize(phi::make_ddim(shape.GetData())); int numel = out->numel(); dev_ctx.template Alloc(out); + if (numel > 0) { // in transformer model the numel of outpout will be zero. std::vector inputs = {}; diff --git a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu index da1045c27c58d..a40460b67202e 100644 --- a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu @@ -64,5 +64,11 @@ PD_REGISTER_KERNEL(gather_nd_grad, double, int64_t, int, + uint8_t, + int8_t, + int16_t, + bool, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/gather_nd_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_kernel.cu index b8ac4aa263afa..e89fa1fd74ed4 100644 --- a/paddle/phi/kernels/gpu/gather_nd_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_nd_kernel.cu @@ -18,7 +18,6 @@ #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/gather.cu.h" -#include "paddle/phi/kernels/funcs/scatter.cu.h" namespace phi { @@ -53,11 +52,15 @@ PD_REGISTER_KERNEL(gather_nd, GPU, ALL_LAYOUT, phi::GatherNdKernel, + bool, float, double, - int64_t, int, + int8_t, + int64_t, int16_t, - bool, + uint8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 915c7f40fa2cb..d63d670945fba 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -289,4 +289,10 @@ PD_REGISTER_KERNEL(index_put_grad, int, int64_t, bool, - phi::dtype::float16) {} + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index 3af220ce16b31..ee58eab21c53d 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -179,4 +179,10 @@ PD_REGISTER_KERNEL(index_put, int, int64_t, bool, - phi::dtype::float16) {} + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/quant_linear_kernel.cu b/paddle/phi/kernels/gpu/quant_linear_kernel.cu new file mode 100644 index 0000000000000..3fd8b2e429400 --- /dev/null +++ b/paddle/phi/kernels/gpu/quant_linear_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/quant_linear_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(quant_linear, + GPU, + ALL_LAYOUT, + phi::QuantLinearKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/reduce_kernel.cu b/paddle/phi/kernels/gpu/reduce_kernel.cu index 969a3dd1d9ca5..d9714d37febd9 100644 --- a/paddle/phi/kernels/gpu/reduce_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_kernel.cu @@ -370,6 +370,8 @@ PD_REGISTER_KERNEL(sum_grad, double, phi::dtype::float16, phi::dtype::bfloat16, + int8_t, + uint8_t, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc index 73d1c499f7b5d..3e8dfe813cad7 100644 --- a/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc @@ -152,7 +152,7 @@ void RnnGradKernel(const Context &dev_ctx, #endif weight_data = weight_whole.data(); } else { - weight_data = const_cast(weight_list[0]->data()); + weight_data = const_cast(weight_list[0]->data()); // NOLINT } DenseTensor weight_grad = Full(dev_ctx, {weight_numel}, 0); @@ -250,7 +250,7 @@ void RnnGradKernel(const Context &dev_ctx, SequenceLength, &workspace_size, &reserve_size, - const_cast(&dropout_state)); + const_cast(&dropout_state)); // NOLINT DenseTensor workspace_data_ = Empty(dev_ctx, {static_cast(workspace_size)}); @@ -315,7 +315,7 @@ void RnnGradKernel(const Context &dev_ctx, init_c_grad_data, workspace_data_.data(), workspace_size, - const_cast(reserve_data), + const_cast(reserve_data), // NOLINT reserve_size)); #endif } @@ -335,7 +335,7 @@ void RnnGradKernel(const Context &dev_ctx, weight_grad_data, workspace_data_.data(), workspace_size, - const_cast(reserve_data), + const_cast(reserve_data), // NOLINT reserve_size)); // permute weight grad list from weight grad tensor TensorToPermutedWeight( @@ -355,7 +355,7 @@ void RnnGradKernel(const Context &dev_ctx, workspace_size, rnn.weight_desc(), weight_grad_data, - const_cast(reserve_data), + const_cast(reserve_data), // NOLINT reserve_size)); #endif } @@ -393,7 +393,7 @@ void RnnGradKernel(const Context &dev_ctx, nullptr, workspace_data_.data(), workspace_size, - const_cast(reserve_data), + const_cast(reserve_data), // NOLINT reserve_size)); } @@ -411,7 +411,7 @@ void RnnGradKernel(const Context &dev_ctx, workspace_size, rnn.weight_desc(), weight_grad_data, - const_cast(reserve_data), + const_cast(reserve_data), // NOLINT reserve_size)); } #else diff --git a/paddle/phi/kernels/gpu/rnn_kernel.cu.cc b/paddle/phi/kernels/gpu/rnn_kernel.cu.cc index fae2250190b64..82800607bae9d 100644 --- a/paddle/phi/kernels/gpu/rnn_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/rnn_kernel.cu.cc @@ -253,7 +253,7 @@ void RnnKernel(const Context &dev_ctx, for (auto weight_item : weight_list) { size_t len = weight_item->numel(); auto dim = weight_item->dims(); - const_cast(weight_item) + const_cast(weight_item) // NOLINT ->ShareDataWith( weight_whole.Slice(static_cast(offset), static_cast(offset + len))) @@ -263,7 +263,7 @@ void RnnKernel(const Context &dev_ctx, } #endif } else { - w_data = const_cast(weight_list[0]->data()); + w_data = const_cast(weight_list[0]->data()); // NOLINT } RNNDescriptors rnn(seq_length, diff --git a/paddle/phi/kernels/gpu/scale_kernel.cu b/paddle/phi/kernels/gpu/scale_kernel.cu index bb6aca1a1b637..a445784d7822e 100644 --- a/paddle/phi/kernels/gpu/scale_kernel.cu +++ b/paddle/phi/kernels/gpu/scale_kernel.cu @@ -71,6 +71,7 @@ PD_REGISTER_KERNEL(scale, GPU, ALL_LAYOUT, phi::ScaleKernel, + bool, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu index eb0b3189bb29c..66688b417ae30 100644 --- a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu @@ -28,6 +28,9 @@ PD_REGISTER_KERNEL(set_value_grad, int, int64_t, bool, + int16_t, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/set_value_kernel.cu b/paddle/phi/kernels/gpu/set_value_kernel.cu index 0abd07c660af9..e97e01f271390 100644 --- a/paddle/phi/kernels/gpu/set_value_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_kernel.cu @@ -28,6 +28,9 @@ PD_REGISTER_KERNEL(set_value, int, int64_t, bool, + int16_t, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, @@ -41,6 +44,9 @@ PD_REGISTER_KERNEL(set_value_with_tensor, int, int64_t, bool, + int16_t, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/sign_kernel.cu.cc b/paddle/phi/kernels/gpu/sign_kernel.cu.cc index 71cd1d39b687d..bbccc906a06e3 100644 --- a/paddle/phi/kernels/gpu/sign_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/sign_kernel.cu.cc @@ -25,6 +25,10 @@ PD_REGISTER_KERNEL(sign, GPU, ALL_LAYOUT, phi::SignKernel, + int8_t, + int16_t, + int32_t, + int64_t, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc index 89a6fad5df02a..858afa0178938 100644 --- a/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc @@ -23,11 +23,13 @@ PD_REGISTER_KERNEL(slice_grad, ALL_LAYOUT, phi::SliceGradKernel, bool, - uint8_t, int, + uint8_t, int64_t, float, double, + int16_t, + int8_t, phi::dtype::complex, phi::dtype::complex, phi::dtype::bfloat16, @@ -43,6 +45,8 @@ PD_REGISTER_KERNEL(slice_array_grad, int64_t, float, double, + int16_t, + int8_t, phi::dtype::complex, phi::dtype::complex, phi::dtype::bfloat16, @@ -58,6 +62,8 @@ PD_REGISTER_KERNEL(slice_array_dense_grad, int64_t, float, double, + int16_t, + int8_t, phi::dtype::complex, phi::dtype::complex, phi::dtype::bfloat16, diff --git a/paddle/phi/kernels/gpu/slice_kernel.cu.cc b/paddle/phi/kernels/gpu/slice_kernel.cu.cc index 5b011b32169ee..2dc9d6db78a3c 100644 --- a/paddle/phi/kernels/gpu/slice_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/slice_kernel.cu.cc @@ -28,6 +28,8 @@ PD_REGISTER_KERNEL(slice, int64_t, float, double, + int16_t, + int8_t, phi::dtype::complex, phi::dtype::complex, phi::dtype::bfloat16, @@ -43,6 +45,8 @@ PD_REGISTER_KERNEL(slice_array, int64_t, float, double, + int16_t, + int8_t, phi::dtype::complex, phi::dtype::complex, phi::dtype::bfloat16, @@ -58,6 +62,8 @@ PD_REGISTER_KERNEL(slice_array_dense, int64_t, float, double, + int16_t, + int8_t, phi::dtype::complex, phi::dtype::complex, phi::dtype::bfloat16, diff --git a/paddle/phi/kernels/gpu/stack_grad_kernel.cu b/paddle/phi/kernels/gpu/stack_grad_kernel.cu index 6c72a3562e6a7..a33f1db8eb48e 100644 --- a/paddle/phi/kernels/gpu/stack_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/stack_grad_kernel.cu @@ -46,12 +46,15 @@ PD_REGISTER_KERNEL(stack_grad, GPU, ALL_LAYOUT, phi::StackGradKernel, + bool, float, double, - bool, - int64_t, int, - uint8_t, int8_t, + int64_t, + uint8_t, + int16_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu index e1d7d4e6f389c..bb34caedf7638 100644 --- a/paddle/phi/kernels/gpu/stack_kernel.cu +++ b/paddle/phi/kernels/gpu/stack_kernel.cu @@ -34,12 +34,15 @@ PD_REGISTER_KERNEL(stack, GPU, ALL_LAYOUT, phi::StackKernel, + bool, float, double, - bool, - int64_t, int, - uint8_t, int8_t, + int64_t, + int16_t, + uint8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/strided_copy_kernel.cu b/paddle/phi/kernels/gpu/strided_copy_kernel.cu index e72eca2f936e1..65dae3fc89efe 100644 --- a/paddle/phi/kernels/gpu/strided_copy_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_copy_kernel.cu @@ -48,127 +48,6 @@ __global__ void StridedCopyFunc( } } -template -__global__ void StridedCopyCaseZeroFunc( - const T* input_data, - phi::Array input_stride, - T* output_data, - phi::Array output_stride) { - int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - blockDim.z * blockDim.y * blockDim.x + - threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x; - int64_t output_offset = input_offset; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; - -#pragma unroll - for (int dim = RANK - 1; dim >= 0; --dim) { - input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; - output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; - } - - output_data[output_offset] = input_data[input_offset]; -} - -template -__global__ void StridedCopyCaseOneFunc( - const T* input_data, - phi::Array input_stride, - T* out_data, - phi::Array output_stride, - phi::Array dims, - const int64_t x_max) { - int64_t x = blockIdx.x * blockDim.x + threadIdx.x; - if (x < x_max) { - int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; - int64_t output_offset = input_offset; - - int64_t reg_dims[6] = { - dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; - int64_t coordinate[phi::DDim::kMaxRank + 1]; - - switch (N) { - case 1: - coordinate[0] = x % reg_dims[0]; - break; - case 2: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - break; - case 3: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - break; - case 4: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - break; - case 5: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - break; - case 6: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - break; - case 7: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - break; - case 8: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - break; - case 9: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); - break; - } - -#pragma unroll - for (int dim = N - 1; dim >= 0; --dim) { - input_offset += coordinate[N - 1 - dim] * input_stride[dim]; - output_offset += coordinate[N - 1 - dim] * output_stride[dim]; - } - - out_data[output_offset] = input_data[input_offset]; - } -} - template __global__ void Strided2ContiguousFunc( const T* input_data, @@ -192,123 +71,6 @@ __global__ void Strided2ContiguousFunc( } } -template -__global__ void Strided2ContiguousCaseZeroFunc( - const T* input_data, - phi::Array input_stride, - T* output_data) { - int64_t input_offset = 0; - int64_t output_offset = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - blockDim.z * blockDim.y * blockDim.x + - threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; - -#pragma unroll - for (int dim = RANK - 1; dim >= 0; --dim) { - input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; - } - - output_data[output_offset] = input_data[input_offset]; -} - -template -__global__ void Strided2ContiguousCaseOneFunc( - const T* input_data, - phi::Array input_stride, - T* out_data, - phi::Array dims, - const int64_t x_max) { - int64_t x = blockIdx.x * blockDim.x + threadIdx.x; - if (x < x_max) { - int64_t input_offset = 0; - int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; - - int64_t reg_dims[6] = { - dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; - int64_t coordinate[phi::DDim::kMaxRank + 1]; - - switch (N) { - case 1: - coordinate[0] = x % reg_dims[0]; - break; - case 2: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - break; - case 3: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - break; - case 4: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - break; - case 5: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - break; - case 6: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - break; - case 7: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - break; - case 8: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - break; - case 9: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); - break; - } - -#pragma unroll - for (int dim = N - 1; dim >= 0; --dim) { - input_offset += coordinate[N - 1 - dim] * input_stride[dim]; - } - - out_data[output_offset] = input_data[input_offset]; - } -} - template __global__ void Contiguous2StridedFunc( const T* input_data, @@ -332,123 +94,6 @@ __global__ void Contiguous2StridedFunc( } } -template -__global__ void Contiguous2StridedCaseZeroFunc( - const T* input_data, - T* output_data, - phi::Array output_stride) { - int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - blockDim.z * blockDim.y * blockDim.x + - threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x; - int64_t output_offset = 0; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; - -#pragma unroll - for (int dim = RANK - 1; dim >= 0; --dim) { - output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; - } - - output_data[output_offset] = input_data[input_offset]; -} - -template -__global__ void Contiguous2StridedCaseOneFunc( - const T* input_data, - T* out_data, - phi::Array output_stride, - phi::Array dims, - const int64_t x_max) { - int64_t x = blockIdx.x * blockDim.x + threadIdx.x; - if (x < x_max) { - int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; - int64_t output_offset = 0; - - int64_t reg_dims[6] = { - dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; - int64_t coordinate[phi::DDim::kMaxRank + 1]; - - switch (N) { - case 1: - coordinate[0] = x % reg_dims[0]; - break; - case 2: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - break; - case 3: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - break; - case 4: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - break; - case 5: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - break; - case 6: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - break; - case 7: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - break; - case 8: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - break; - case 9: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); - break; - } - -#pragma unroll - for (int dim = N - 1; dim >= 0; --dim) { - output_offset += coordinate[N - 1 - dim] * output_stride[dim]; - } - - out_data[output_offset] = input_data[input_offset]; - } -} - template void StridedCopyKernel(const Context& dev_ctx, const DenseTensor& input, @@ -500,6 +145,8 @@ void StridedCopyKernel(const Context& dev_ctx, } auto numel = input.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; if (numel == 1) { #ifdef PADDLE_WITH_HIP @@ -517,649 +164,1088 @@ void StridedCopyKernel(const Context& dev_ctx, return; } - dim3 grid(1, 1, 1), block(1, 1, 1); - int rank = input_rank; - int tmp = 1; - - for (int i = 0; i < 3 && i < rank; i++) { - tmp *= input_dims[rank - 1 - i]; - } - - if (rank <= 6 && tmp <= 1024 && - (input_dims.size() < 3 || input_dims[rank - 3] <= 64)) { - if (rank >= 1) { - block.x = input_dims[rank - 1]; - } - - if (rank >= 2) { - block.y = input_dims[rank - 2]; - } - - if (rank >= 3) { - block.z = input_dims[rank - 3]; + if (input.meta().is_contiguous()) { + switch (input_rank) { + case 1: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + input_rank)); } - - if (input.meta().is_contiguous()) { - switch (rank) { - case 1: - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 2: - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 3: - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 4: - grid.x = input_dims[rank - 4]; - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 5: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 6: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - grid.z = input_dims[rank - 6]; - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - } - } else if (out->meta().is_contiguous()) { - switch (rank) { - case 1: - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 2: - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 3: - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 4: - grid.x = input_dims[rank - 4]; - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 5: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 6: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - grid.z = input_dims[rank - 6]; - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - } - } else { - switch (rank) { - case 1: - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 2: - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 3: - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 4: - grid.x = input_dims[rank - 4]; - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 5: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 6: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - grid.z = input_dims[rank - 6]; - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - } + } else if (out->meta().is_contiguous()) { + switch (output_rank) { + case 1: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); } } else { - phi::Array cur_input_dims; - block.x = 512; - - if (input.meta().is_contiguous()) { - switch (rank) { - case 1: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1]); - break; - case 2: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Contiguous2StridedCaseOneFunc - <<>>( - input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2]); - break; - case 3: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 4: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 5: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 6: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 7: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 8: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 9: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8] * - input_dims[rank - 9]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - rank)); - } - } else if (out->meta().is_contiguous()) { - switch (rank) { - case 1: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1]); - break; - case 2: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Strided2ContiguousCaseOneFunc - <<>>( - input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2]); - break; - case 3: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 4: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 5: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 6: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 7: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 8: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 9: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8] * - input_dims[rank - 9]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - rank)); - } - } else { - switch (rank) { - case 1: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - StridedCopyCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1]); - break; - case 2: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2]); - break; - case 3: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 4: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 5: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 6: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 7: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 8: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 9: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8] * - input_dims[rank - 9]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - rank)); - } + switch (input_rank) { + case 1: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 2: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 3: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 4: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 5: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 6: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 7: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 8: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 9: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + input_rank)); } } } diff --git a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu index 08ac3da93bb49..b9ef080b97a9c 100644 --- a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu @@ -24,10 +24,13 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, ALL_LAYOUT, phi::StridedSliceRawGradKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, @@ -38,10 +41,13 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, ALL_LAYOUT, phi::StridedSliceArrayGradKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/strided_slice_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_kernel.cu index 9b88322e20a06..1b278c01cb2b0 100644 --- a/paddle/phi/kernels/gpu/strided_slice_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_kernel.cu @@ -24,10 +24,13 @@ PD_REGISTER_KERNEL(strided_slice_raw, ALL_LAYOUT, phi::StridedSliceRawKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, @@ -38,10 +41,13 @@ PD_REGISTER_KERNEL(strided_slice_array, ALL_LAYOUT, phi::StridedSliceArrayKernel, bool, - int, - int64_t, float, double, + int, + int8_t, + int64_t, + int16_t, + uint8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu index 7a88c330673d6..323c228c16039 100644 --- a/paddle/phi/kernels/gpu/transpose_kernel.cu +++ b/paddle/phi/kernels/gpu/transpose_kernel.cu @@ -56,13 +56,13 @@ PD_REGISTER_KERNEL(transpose, ALL_LAYOUT, phi::TransposeKernel, bool, - uint8_t, - int8_t, - int16_t, float, double, + int8_t, + int16_t, int32_t, int64_t, + uint8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu b/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu index 448e6ca38b3f5..9c32bff0ccb80 100644 --- a/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu @@ -29,12 +29,11 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, bool return_inverse, bool return_counts, const std::vector& axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* index, DenseTensor* counts) { - auto data_type = phi::TransToPhiDataType(dtype); - if (data_type == phi::DataType::INT32) { + if (dtype == phi::DataType::INT32) { PADDLE_ENFORCE_LE( x.numel() + 1, INT_MAX, @@ -48,7 +47,7 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, // if 'axis' is not required, flatten the Tensor. if (axis.empty()) { phi::VisitDataTypeTiny( - data_type, + dtype, UniqueConsecutiveFlattenedCUDAFunctor( dev_ctx, x, out, return_inverse, return_counts, index, counts)); } else { @@ -56,7 +55,7 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, int valid_axis = axis[0]; if (valid_axis < 0) valid_axis += x.dims().size(); phi::VisitDataTypeTiny( - data_type, + dtype, UniqueConsecutiveDimsCUDAFunctor(dev_ctx, x, out, diff --git a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h index 385ea68e6e707..7954441f30c2b 100644 --- a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h @@ -25,13 +25,13 @@ void FrobeniusNormGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* dx) { - reduce_all = recompute_reduce_all(x, axis, reduce_all); + reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all); ReduceGradKernel( - ctx, x, out, dout, axis, keep_dim, reduce_all, dx); + ctx, x, out, dout, axis.GetData(), keep_dim, reduce_all, dx); } } // namespace phi diff --git a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h index 7dbc3ab3af7ba..eab028a1caccf 100644 --- a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h +++ b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h @@ -23,13 +23,13 @@ namespace phi { template void FrobeniusNormKernel(const Context& ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* out) { - reduce_all = recompute_reduce_all(x, axis, reduce_all); + reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all); Reduce( - ctx, x, reduce_all, axis, keep_dim, x.dtype(), out); + ctx, x, reduce_all, axis.GetData(), keep_dim, x.dtype(), out); } } // namespace phi diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 64efa8790b4f9..373453d1eefa4 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -1446,6 +1446,10 @@ MatmulJudgeDtypeKernel(const Context& ctx, DenseTensor out_tmp; MatMulFunction( ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y); + if (x.dtype() == phi::DataType::INT8) { + phi::CastKernel(ctx, out_tmp, phi::DataType::INT32, out); + return; + } phi::CastKernel(ctx, out_tmp, x.dtype(), out); } diff --git a/paddle/phi/kernels/impl/quant_linear_kernel_impl.h b/paddle/phi/kernels/impl/quant_linear_kernel_impl.h new file mode 100644 index 0000000000000..dbd548f7af6da --- /dev/null +++ b/paddle/phi/kernels/impl/quant_linear_kernel_impl.h @@ -0,0 +1,98 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" + +namespace phi { + +template +void QuantLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& w, + const paddle::optional& bias, + int in_num_col_dims, + const std::string& activation_type, + bool padding_weights, + float scale_in, + const std::vector& scale_weights, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + DenseTensor* y) { + bool with_relu = activation_type == "relu" ? true : false; + auto w_dims = w.dims(); + + auto input_dims = x.dims(); + std::vector output_dims; + auto in_mat_dims = phi::flatten_to_2d(input_dims, in_num_col_dims); + auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0]; + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; + PADDLE_ENFORCE_EQ( + in_mat_dims[1], + w_dims0, + phi::errors::InvalidArgument( + "The input's second dimension and weight's first dimension is " + "expected to be the same. But received input's second dimension is" + "%d, input's shape is %s; weight's first dimension is %d, weight's" + " shape is %s.", + in_mat_dims[1], + in_mat_dims, + w_dims0, + phi::make_ddim({w_dims0, w_dims1}))); + + output_dims.reserve(static_cast(in_num_col_dims + 1)); + for (int i = 0; i < in_num_col_dims; ++i) { + output_dims.push_back(input_dims[i]); + } + output_dims.push_back(w_dims1); + + y->Resize(phi::make_ddim(output_dims)); + y->set_lod(x.lod()); + + auto out_dims = y->dims(); + int M = phi::product(out_dims) / w_dims1; + + const T* input_data = x.data(); + auto* output_data = dev_ctx.template Alloc(y, y->numel() * sizeof(T)); + auto bias_data = bias ? bias.get_ptr()->data() : NULL; + + PADDLE_ENFORCE_EQ( + w.dtype(), + phi::DataType::INT8, + phi::errors::InvalidArgument( + "The weight's datatype is expected to be int8 when use quant. But " + "received weight's datatype is %d", + static_cast(w.dtype()))); + phi::funcs::FCInt8Functor fc; + fc(dev_ctx, + M, + w_dims1, + w_dims0, + input_data, + &w, + output_data, + scale_in, + scale_weights, + quant_round_type, + quant_max_bound, + quant_min_bound, + bias_data, + with_relu, + padding_weights); + return; +} + +} // namespace phi diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu index 584e026241bde..6de33dd78d2d0 100644 --- a/paddle/phi/kernels/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_kernel.cu @@ -307,8 +307,12 @@ PD_REGISTER_KERNEL(divide, phi::DivideKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, + bool, float16, bfloat16, complex64, diff --git a/paddle/phi/kernels/kps/reduce_kernel.cu b/paddle/phi/kernels/kps/reduce_kernel.cu index 1bc00cf11cbdb..506bd36e828bc 100644 --- a/paddle/phi/kernels/kps/reduce_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_kernel.cu @@ -369,6 +369,8 @@ PD_REGISTER_KERNEL(sum_raw, double, float16, bfloat16, + int8_t, + uint8_t, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/legacy/cpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/legacy/cpu/elementwise_divide_kernel.cc index ad09d6830f974..6f4debdcb216f 100644 --- a/paddle/phi/kernels/legacy/cpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/legacy/cpu/elementwise_divide_kernel.cc @@ -62,5 +62,6 @@ PD_REGISTER_KERNEL(divide_raw, double, int, int64_t, + bool, complex64, complex128) {} diff --git a/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu b/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu index f07164bc16885..ad802ee190861 100644 --- a/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu @@ -77,8 +77,12 @@ PD_REGISTER_KERNEL(divide_raw, phi::DivideRawKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, + bool, float16, bfloat16, complex64, diff --git a/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc b/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc index 00aee2d41b153..2e4bf779d26cd 100644 --- a/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc @@ -121,19 +121,25 @@ PD_REGISTER_KERNEL(floor_divide_raw, ALL_LAYOUT, phi::FloorDivideRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(maximum_raw, XPU, ALL_LAYOUT, phi::MaximumRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(minimum_raw, XPU, ALL_LAYOUT, phi::MinimumRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(remainder_raw, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc index e3e0fef11e913..e648686f3d2e7 100644 --- a/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc @@ -21,8 +21,8 @@ template void phi::BatchNormGradFunctor( \ const ::phi::backend##Context& dev_ctx, \ const DenseTensor& x, \ - const DenseTensor& scale, \ - const DenseTensor& bias, \ + const paddle::optional& scale, \ + const paddle::optional& bias, \ const paddle::optional& mean, \ const paddle::optional& variance, \ const DenseTensor& saved_mean, \ @@ -45,8 +45,8 @@ namespace phi { template void BatchNormGradFunctor(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -63,8 +63,39 @@ void BatchNormGradFunctor(const Context& dev_ctx, DenseTensor* x_grad, DenseTensor* scale_grad, DenseTensor* bias_grad) { - funcs::BatchNormOneDNNHandler handler( - dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad); + auto Scale = scale.get_ptr(); + auto Bias = bias.get_ptr(); + const bool use_scale = scale ? true : false; + const bool use_bias = bias ? true : false; + + std::vector scale_tz; + std::vector bias_tz; + if (use_scale) { + scale_tz = vectorize(Scale->dims()); + PADDLE_ENFORCE_EQ( + scale_tz.size(), + 1, + errors::InvalidArgument( + "Dims of scale tensor must be 1, but received scale's size is %d", + scale_tz.size())); + } + if (use_bias) { + bias_tz = vectorize(Bias->dims()); + PADDLE_ENFORCE_EQ( + bias_tz.size(), + 1, + errors::InvalidArgument( + "Dims of bias tensor must be 1, but received bias's size is %d", + bias_tz.size())); + } + + funcs::BatchNormOneDNNHandler handler(dev_ctx.GetEngine(), + dev_ctx.GetPlace(), + epsilon, + &x, + use_scale, + use_bias, + &y_grad); T* diff_scale_data = dev_ctx.template Alloc(scale_grad); T* diff_shift_data = dev_ctx.template Alloc(bias_grad); @@ -73,24 +104,29 @@ void BatchNormGradFunctor(const Context& dev_ctx, auto mean_memory = handler.AcquireMeanMemory(&saved_mean); auto variance_memory = handler.AcquireVarianceMemory(&saved_variance); auto diff_dst_memory = handler.AcquireDiffDstMemory(&y_grad); - auto scaleshift_mems = handler.AcquireScaleShiftMemory(&scale, &bias); auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad); - auto diff_scaleshift_mems = - handler.AcquireDiffScaleShiftMemory(diff_scale_data, diff_shift_data); auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(); + std::shared_ptr scale_memory(nullptr); + std::shared_ptr diff_scale_memory(nullptr); + std::shared_ptr diff_shift_memory(nullptr); + if (scale) { + scale_memory = handler.AcquireScaleMemory(Scale); + diff_scale_memory = handler.AcquireDiffScaleMemory(diff_scale_data); + } + if (bias) diff_shift_memory = handler.AcquireDiffShiftMemory(diff_shift_data); + auto& astream = OneDNNContext::tls().get_stream(); - batch_norm_bwd_p->execute( - astream, - {{DNNL_ARG_SRC, *src_memory}, - {DNNL_ARG_MEAN, *mean_memory}, - {DNNL_ARG_VARIANCE, *variance_memory}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory}, - {DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory}, - {DNNL_ARG_DIFF_SCALE, *(std::get<0>(diff_scaleshift_mems))}, - {DNNL_ARG_DIFF_SHIFT, *(std::get<1>(diff_scaleshift_mems))}}); + batch_norm_bwd_p->execute(astream, + {{DNNL_ARG_SRC, *src_memory}, + {DNNL_ARG_MEAN, *mean_memory}, + {DNNL_ARG_VARIANCE, *variance_memory}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory}, + {DNNL_ARG_SCALE, *scale_memory}, + {DNNL_ARG_DIFF_SRC, *diff_src_memory}, + {DNNL_ARG_DIFF_SCALE, *diff_scale_memory}, + {DNNL_ARG_DIFF_SHIFT, *diff_shift_memory}}); astream.wait(); // set memory descriptor of out tensor @@ -100,8 +136,8 @@ void BatchNormGradFunctor(const Context& dev_ctx, template void BatchNormGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc index 61172c074e26a..070058062b6f4 100644 --- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc +++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/batch_norm_kernel.h" +#include "glog/logging.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" @@ -28,8 +29,8 @@ void BatchNormKernel(const Context &dev_ctx, const DenseTensor &x, const DenseTensor &mean, const DenseTensor &variance, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, bool is_test, float momentum, float epsilon, @@ -48,17 +49,20 @@ void BatchNormKernel(const Context &dev_ctx, dev_ctx.HasDnnAttr("fuse_with_relu") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("fuse_with_relu")) : false; + const bool use_scale = scale ? true : false; + const bool use_bias = bias ? true : false; funcs::BatchNormOneDNNHandler handler(dev_ctx.GetEngine(), dev_ctx.GetPlace(), &x, epsilon, + use_scale, + use_bias, fuse_with_relu, global_stats, test_mode); auto src_memory = handler.AcquireSrcMemory(&x); - auto scaleshift_mems = handler.AcquireScaleShiftMemory(&scale, &bias); auto dst_memory = handler.AcquireDstMemory(y); auto batch_norm_p = handler.AcquireForwardPrimitive(); @@ -76,18 +80,25 @@ void BatchNormKernel(const Context &dev_ctx, y->set_mem_desc(dst_memory->get_desc()); + std::shared_ptr scale_memory(nullptr); + std::shared_ptr shift_memory(nullptr); + auto Scale = scale.get_ptr(); + auto Bias = bias.get_ptr(); + if (scale) scale_memory = handler.AcquireScaleMemory(Scale); + if (bias) shift_memory = handler.AcquireShiftMemory(Bias); + auto &astream = OneDNNContext::tls().get_stream(); batch_norm_p->execute(astream, {{DNNL_ARG_SRC, *src_memory}, - {DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))}, - {DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))}, + {DNNL_ARG_SCALE, *scale_memory}, + {DNNL_ARG_SHIFT, *shift_memory}, {DNNL_ARG_MEAN, *mean_memory}, {DNNL_ARG_VARIANCE, *variance_memory}, {DNNL_ARG_DST, *dst_memory}}); astream.wait(); if (!global_stats) { - const unsigned int C = phi::vectorize(scale.dims())[0]; + const unsigned int C = phi::vectorize(mean.dims())[0]; // mkldnn only compute stats for current batch // so we need compute momentum stats via Eigen lib diff --git a/paddle/phi/kernels/onednn/conv_handler.h b/paddle/phi/kernels/onednn/conv_handler.h index 1473cb1b5a248..86baabf45afc1 100644 --- a/paddle/phi/kernels/onednn/conv_handler.h +++ b/paddle/phi/kernels/onednn/conv_handler.h @@ -180,7 +180,7 @@ class ConvOneDNNHandlerT weights_md = funcs::OneDNNMemDesc( weights_tz, data_type, funcs::OneDNNMemoryFormat::any); } - if (input->dims().size() == 4 && input->dims()[1] == 3) { + if (input->dims().size() == 4 && input->dims()[1] <= 4) { chosen_memory_format = funcs::OneDNNMemoryFormat::nhwc; } const auto dst_md = funcs::OneDNNMemDesc( diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc index 6f2dc34673f67..0ab3bcca771fb 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/reduce_sum_kernel.cc @@ -48,6 +48,8 @@ PD_REGISTER_KERNEL(sum, int16_t, int, int64_t, + uint8_t, + int8_t, complex64, complex128) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); @@ -66,6 +68,8 @@ PD_REGISTER_KERNEL(sum, int16_t, int, int64_t, + uint8_t, + int8_t, complex64, complex128) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); diff --git a/paddle/phi/kernels/squeeze_grad_kernel.cc b/paddle/phi/kernels/squeeze_grad_kernel.cc index 473acf9d7a1d1..75294557ace25 100644 --- a/paddle/phi/kernels/squeeze_grad_kernel.cc +++ b/paddle/phi/kernels/squeeze_grad_kernel.cc @@ -40,12 +40,14 @@ PD_REGISTER_KERNEL(squeeze_grad, phi::SqueezeGradKernel, float, double, - phi::dtype::bfloat16, bool, int, uint8_t, int8_t, + int16_t, int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -62,6 +64,7 @@ PD_REGISTER_KERNEL(squeeze_grad, int, uint8_t, int8_t, + int16_t, int64_t, phi::dtype::complex, phi::dtype::complex) {} @@ -76,6 +79,7 @@ PD_REGISTER_KERNEL(squeeze_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc index d495b040921b5..684fd0298a3df 100644 --- a/paddle/phi/kernels/squeeze_kernel.cc +++ b/paddle/phi/kernels/squeeze_kernel.cc @@ -49,14 +49,16 @@ PD_REGISTER_KERNEL(squeeze_infer, CPU, ALL_LAYOUT, phi::SqueezeInferKernel, + bool, float, double, - phi::dtype::bfloat16, - bool, int, - uint8_t, int8_t, int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -64,14 +66,16 @@ PD_REGISTER_KERNEL(squeeze, CPU, ALL_LAYOUT, phi::SqueezeKernel, + bool, float, double, - phi::dtype::bfloat16, - bool, int, - uint8_t, int8_t, int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -79,15 +83,16 @@ PD_REGISTER_KERNEL(squeeze_infer, GPU, ALL_LAYOUT, phi::SqueezeInferKernel, + bool, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - bool, int, - uint8_t, int8_t, int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -95,15 +100,16 @@ PD_REGISTER_KERNEL(squeeze, GPU, ALL_LAYOUT, phi::SqueezeKernel, + bool, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - bool, int, - uint8_t, int8_t, int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #endif @@ -116,6 +122,7 @@ PD_REGISTER_KERNEL(squeeze_infer, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, @@ -129,6 +136,7 @@ PD_REGISTER_KERNEL(squeeze, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/strided_slice_kernel.cc b/paddle/phi/kernels/strided_slice_kernel.cc index 68377dbe8468e..0852cc8830e2c 100644 --- a/paddle/phi/kernels/strided_slice_kernel.cc +++ b/paddle/phi/kernels/strided_slice_kernel.cc @@ -38,11 +38,15 @@ PD_REGISTER_KERNEL(strided_slice, CPU, ALL_LAYOUT, phi::StridedSliceKernel, - bool, - int, - int64_t, float, double, + bool, + int64_t, + int16_t, + int, + uint8_t, + int8_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -51,11 +55,14 @@ PD_REGISTER_KERNEL(strided_slice, GPU, ALL_LAYOUT, phi::StridedSliceKernel, - bool, - int, - int64_t, float, double, + bool, + int64_t, + int16_t, + int, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/unique_consecutive_kernel.h b/paddle/phi/kernels/unique_consecutive_kernel.h index ade35d4d49730..6c88f5947fc38 100644 --- a/paddle/phi/kernels/unique_consecutive_kernel.h +++ b/paddle/phi/kernels/unique_consecutive_kernel.h @@ -26,7 +26,7 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, bool return_inverse, bool return_counts, const std::vector& axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* index, DenseTensor* counts); diff --git a/paddle/phi/kernels/unsqueeze_grad_kernel.cc b/paddle/phi/kernels/unsqueeze_grad_kernel.cc index 3c119db2c73d6..a281bb66b4c67 100644 --- a/paddle/phi/kernels/unsqueeze_grad_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_grad_kernel.cc @@ -39,13 +39,14 @@ PD_REGISTER_KERNEL(unsqueeze_grad, phi::UnsqueezeGradKernel, float, double, - phi::dtype::bfloat16, bool, int, int16_t, uint8_t, int8_t, int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -56,14 +57,14 @@ PD_REGISTER_KERNEL(unsqueeze_grad, phi::UnsqueezeGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, bool, int, int16_t, uint8_t, int8_t, int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -77,6 +78,7 @@ PD_REGISTER_KERNEL(unsqueeze_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/unsqueeze_kernel.cc b/paddle/phi/kernels/unsqueeze_kernel.cc index c08c31da4ef0c..1f023a7cfb5f4 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_kernel.cc @@ -27,7 +27,7 @@ void UnsqueezeInferKernel(const Context& dev_ctx, DenseTensor* out) { auto x_dims = x.dims(); auto out_dims = out->dims(); - if (axes.FromTensor()) { + if (axes.FromTensor() && out->dims()[0] == -1) { out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims); } out->Resize(out_dims); @@ -102,16 +102,16 @@ PD_REGISTER_KERNEL(unsqueeze, GPU, ALL_LAYOUT, phi::UnsqueezeKernel, + bool, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - bool, int, - int16_t, - uint8_t, int8_t, int64_t, + int16_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #endif @@ -124,6 +124,7 @@ PD_REGISTER_KERNEL(unsqueeze_infer, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, @@ -137,6 +138,7 @@ PD_REGISTER_KERNEL(unsqueeze, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/xpu/abs_kernel.cc b/paddle/phi/kernels/xpu/abs_kernel.cc index 7abdd1f0715b6..053e641041683 100644 --- a/paddle/phi/kernels/xpu/abs_kernel.cc +++ b/paddle/phi/kernels/xpu/abs_kernel.cc @@ -31,5 +31,12 @@ void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { } } // namespace phi -PD_REGISTER_KERNEL( - abs, XPU, ALL_LAYOUT, phi::AbsKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(abs, + XPU, + ALL_LAYOUT, + phi::AbsKernel, + float, + phi::dtype::float16, + int8_t, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc index d77d84ee5ae3a..a1b05366b56be 100644 --- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -700,6 +700,14 @@ PD_REGISTER_KERNEL(square_grad, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(swish_grad, + XPU, + ALL_LAYOUT, + phi::SwishGradKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} + PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) @@ -710,7 +718,6 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel) diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index efac9b30ae2eb..4ab94cd52a4ea 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -566,8 +566,13 @@ PD_REGISTER_KERNEL( elu, XPU, ALL_LAYOUT, phi::EluKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL( sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(swish, + XPU, + ALL_LAYOUT, + phi::SwishKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(hardsigmoid, XPU, ALL_LAYOUT, @@ -580,8 +585,13 @@ PD_REGISTER_KERNEL(leaky_relu, phi::LeakyReluKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(sqrt, + XPU, + ALL_LAYOUT, + phi::SqrtKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/arange_kernel.cc b/paddle/phi/kernels/xpu/arange_kernel.cc index 0ae1007e91d29..af3abc19aaddc 100644 --- a/paddle/phi/kernels/xpu/arange_kernel.cc +++ b/paddle/phi/kernels/xpu/arange_kernel.cc @@ -20,11 +20,11 @@ limitations under the License. */ namespace phi { template -void ArangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, - DenseTensor* out) { +void ArangeTensorKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out) { T start_value = GetValue(dev_ctx, start); T end_value = GetValue(dev_ctx, end); T step_value = GetValue(dev_ctx, step); @@ -49,8 +49,14 @@ void ArangeKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - arange, XPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int, int64_t) { +PD_REGISTER_KERNEL(arange_tensor, + XPU, + ALL_LAYOUT, + phi::ArangeTensorKernel, + float, + double, + int, + int64_t) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); diff --git a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc index 09e62bbfd4bde..863bc2759b39a 100644 --- a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/batch_norm_grad_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -72,8 +73,8 @@ static int CalculateInvVar(xpu::Context *ctx, template void BatchNormGradKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -133,9 +134,27 @@ void BatchNormGradKernel(const Context &dev_ctx, W = W * D; + auto *Scale = scale.get_ptr(); + auto *Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(dev_ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(dev_ctx, {C}, static_cast(0)); + } + const auto *x_data = reinterpret_cast(x.data()); const auto *d_y_data = reinterpret_cast(y_grad.data()); - const auto *scale_data = scale.data(); + const auto *scale_data = new_scale.data(); // init output XPUType *x_grad_data = nullptr; @@ -151,22 +170,22 @@ void BatchNormGradKernel(const Context &dev_ctx, } PADDLE_ENFORCE_EQ( - scale.dims().size(), + new_scale.dims().size(), 1UL, phi::errors::InvalidArgument( "The size of scale's dimensions must equal to 1. But received: " "the size of scale's dimensions is [%d], the dimensions of scale " "is [%s].", - scale.dims().size(), - scale.dims())); + new_scale.dims().size(), + new_scale.dims())); PADDLE_ENFORCE_EQ( - scale.dims()[0], + new_scale.dims()[0], C, phi::errors::InvalidArgument( "The first dimension of scale must equal to Channels[%d]. But " "received: the first dimension of scale is [%d]", C, - scale.dims()[0])); + new_scale.dims()[0])); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); @@ -203,8 +222,8 @@ void BatchNormGradKernel(const Context &dev_ctx, : saved_mean.data(); r = CalculateInvBNY(dev_ctx.x_context(), x_fp32_data, - scale.data(), - bias.data(), + new_scale.data(), + new_bias.data(), mean_data, inv_std_data, N, diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc index e2f2d28182b67..2abb1686daed9 100644 --- a/paddle/phi/kernels/xpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/batch_norm_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -25,8 +26,8 @@ void BatchNormKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& mean, const DenseTensor& variance, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, bool is_test, float momentum, float epsilon, @@ -69,9 +70,27 @@ void BatchNormKernel(const Context& dev_ctx, W = W * D; + auto* Scale = scale.get_ptr(); + auto* Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(dev_ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(dev_ctx, {C}, static_cast(0)); + } + const auto* x_data = reinterpret_cast(x.data()); - const auto* scale_data = scale.data(); - const auto* bias_data = bias.data(); + const auto* scale_data = new_scale.data(); + const auto* bias_data = new_bias.data(); // alloc memory auto* y_data = reinterpret_cast(dev_ctx.template Alloc(y)); diff --git a/paddle/phi/kernels/xpu/cast_kernel.cc b/paddle/phi/kernels/xpu/cast_kernel.cc index c5fd2d02e3360..bc76e919a6f96 100644 --- a/paddle/phi/kernels/xpu/cast_kernel.cc +++ b/paddle/phi/kernels/xpu/cast_kernel.cc @@ -33,11 +33,41 @@ void CastXPUKernelImpl(const Context& dev_ctx, return; } + if (std::is_same::value) { + int ret = xpu::copy(dev_ctx.x_context(), + reinterpret_cast(in_data), + reinterpret_cast(out_data), + x.numel() * phi::SizeOf(x.dtype())); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy"); + return; + } + + if (std::is_same::value && + !std::is_same::value || + !std::is_same::value && + std::is_same::value) { + // bfloat -> non float, or non float -> bfloat, use float buffer + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + float* cast_buffer = RAII_GUARD.alloc_l3_or_gm(numel); + // step 1: InT to float + int r = xpu::cast(dev_ctx.x_context(), + reinterpret_cast(in_data), + cast_buffer, + numel); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + // step 2: float to OutT + r = xpu::cast(dev_ctx.x_context(), + cast_buffer, + reinterpret_cast(out_data), + numel); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + return; + } + int r = xpu::cast(dev_ctx.x_context(), reinterpret_cast(in_data), reinterpret_cast(out_data), numel); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); } @@ -56,6 +86,9 @@ void CastKernel(const Context& dev_ctx, case DataType::FLOAT16: CastXPUKernelImpl(dev_ctx, x, out); break; + case DataType::BFLOAT16: + CastXPUKernelImpl(dev_ctx, x, out); + break; case DataType::INT64: CastXPUKernelImpl(dev_ctx, x, out); break; @@ -85,6 +118,7 @@ PD_REGISTER_KERNEL(cast, int32_t, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, bool, int8_t, diff --git a/paddle/phi/kernels/xpu/dequantization_kernel.cc b/paddle/phi/kernels/xpu/dequantization_kernel.cc new file mode 100644 index 0000000000000..9dc9868e75fd9 --- /dev/null +++ b/paddle/phi/kernels/xpu/dequantization_kernel.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void DeQuantizeKernelImpl(const Context& ctx, + const DenseTensor& x, + float scale, + DenseTensor* y) { + using XPUInX = typename XPUTypeTrait::Type; + using XPUOutY = typename XPUTypeTrait::Type; + + auto* y_data = ctx.template Alloc(y); + const auto* x_data = x.data(); + int64_t len = x.numel(); + int max_ptr_size = ctx.x_context()->max_ptr_size(); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + auto max_data = RAII_GUARD.alloc_l3_or_gm(max_ptr_size); + int r = xpu::constant(ctx.x_context(), max_data, max_ptr_size, scale); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + r = xpu::dequantization( + ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + len, + max_data); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "dequantization"); +} + +template +void DeQuantizeKernel(const Context& ctx, + const DenseTensor& x, + DataType out_dtype, + float scale, + DenseTensor* y) { + switch (out_dtype) { + case DataType::FLOAT32: + DeQuantizeKernelImpl(ctx, x, scale, y); + break; + case DataType::FLOAT16: + DeQuantizeKernelImpl(ctx, x, scale, y); + break; + default: + PADDLE_THROW(phi::errors::Unavailable( + "Not supported dequantize data type from %d -> %d ", + x.dtype(), + out_dtype)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + dequantize_xpu, XPU, ALL_LAYOUT, phi::DeQuantizeKernel, int16_t, int8_t) {} diff --git a/paddle/phi/kernels/xpu/elementwise_kernel.cc b/paddle/phi/kernels/xpu/elementwise_kernel.cc index 386ad2e13ff0e..83dce5437c9ec 100644 --- a/paddle/phi/kernels/xpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_kernel.cc @@ -82,11 +82,25 @@ PD_REGISTER_KERNEL(floor_divide, ALL_LAYOUT, phi::FloorDivideKernel, float, - phi::dtype::float16) {} -PD_REGISTER_KERNEL( - maximum, XPU, ALL_LAYOUT, phi::MaximumKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - minimum, XPU, ALL_LAYOUT, phi::MinimumKernel, float, phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} +PD_REGISTER_KERNEL(maximum, + XPU, + ALL_LAYOUT, + phi::MaximumKernel, + float, + phi::dtype::float16, + int32_t, + int64_t) {} +PD_REGISTER_KERNEL(minimum, + XPU, + ALL_LAYOUT, + phi::MinimumKernel, + float, + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(remainder, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc index 4adccbb4be813..906078629f488 100644 --- a/paddle/phi/kernels/xpu/full_kernel.cc +++ b/paddle/phi/kernels/xpu/full_kernel.cc @@ -63,13 +63,20 @@ void FullLikeKernel(const Context& dev_ctx, T>::type>::type; auto common_type_value = static_cast(value); + bool is_out_range = true; + if (std::isinf(value) || std::isnan(value)) { + is_out_range = false; + } + if ((common_type_value >= + static_cast(std::numeric_limits::lowest())) && + (common_type_value <= + static_cast(std::numeric_limits::max()))) { + is_out_range = false; + } PADDLE_ENFORCE_EQ( - (common_type_value >= - static_cast(std::numeric_limits::lowest())) && - (common_type_value <= - static_cast(std::numeric_limits::max())), - true, + is_out_range, + false, phi::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " @@ -79,13 +86,6 @@ void FullLikeKernel(const Context& dev_ctx, static_cast(std::numeric_limits::max()), static_cast(value))); - PADDLE_ENFORCE_EQ(std::isnan(value), - false, - phi::errors::InvalidArgument("The filled value is NaN.")); - PADDLE_ENFORCE_EQ(std::isinf(value), - false, - phi::errors::InvalidArgument("The filled value is Inf.")); - auto out_data = reinterpret_cast(out->data()); if (out->numel() > 0) { int r = xpu::constant(dev_ctx.x_context(), diff --git a/paddle/phi/kernels/xpu/gaussian_kernel.cc b/paddle/phi/kernels/xpu/gaussian_kernel.cc index f8058f94e872f..2c4a29b6bfe51 100644 --- a/paddle/phi/kernels/xpu/gaussian_kernel.cc +++ b/paddle/phi/kernels/xpu/gaussian_kernel.cc @@ -50,4 +50,5 @@ PD_REGISTER_KERNEL(gaussian, ALL_LAYOUT, phi::GaussianKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/index_select_kernel.cc b/paddle/phi/kernels/xpu/index_select_kernel.cc index cbe6e99c43ae9..75c19aa028bce 100644 --- a/paddle/phi/kernels/xpu/index_select_kernel.cc +++ b/paddle/phi/kernels/xpu/index_select_kernel.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/phi/kernels/index_select_kernel.h" - #include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" @@ -40,14 +40,33 @@ void IndexSelectKernel(const Context& ctx, index_type, phi::DataType::INT32, phi::DataType::INT64)); - auto* in_data = x.data(); std::vector in_shape = phi::vectorize(input_dim); int index_len = output->dims()[dim]; T* out_data = ctx.template Alloc(output); int r = 0; + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + int8_t* index_ptr = nullptr; // temp xpu buffer + int byte_times = SizeOf(index_type); + if (index.place() == CPUPlace()) { + index_ptr = RAII_GUARD.alloc_l3_or_gm(byte_times * index.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(index_ptr); + const void* cpu_idx_data = nullptr; + if (index_type == phi::DataType::INT64) { + cpu_idx_data = reinterpret_cast(index.data()); + } else if (index_type == phi::DataType::INT32) { + cpu_idx_data = reinterpret_cast(index.data()); + } + memory_utils::Copy(ctx.GetPlace(), + reinterpret_cast(index_ptr), + CPUPlace(), + cpu_idx_data, + byte_times * index.numel()); + } if (index_type == phi::DataType::INT64) { - const int64_t* index_data = index.data(); + const int64_t* index_data = + index_ptr ? reinterpret_cast(index_ptr) + : index.template data(); r = xpu::gather(ctx.x_context(), in_data, index_data, @@ -56,7 +75,8 @@ void IndexSelectKernel(const Context& ctx, index_len, dim); } else { - const int* index_data = index.data(); + const int* index_data = index_ptr ? reinterpret_cast(index_ptr) + : index.template data(); r = xpu::gather(ctx.x_context(), in_data, index_data, diff --git a/paddle/phi/kernels/xpu/nonzero_kernel.cc b/paddle/phi/kernels/xpu/nonzero_kernel.cc index edfdb1e6dfe8b..fe241965fb5c6 100644 --- a/paddle/phi/kernels/xpu/nonzero_kernel.cc +++ b/paddle/phi/kernels/xpu/nonzero_kernel.cc @@ -14,8 +14,7 @@ #include "paddle/phi/kernels/nonzero_kernel.h" -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/backends/xpu/xpu_header.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" @@ -34,13 +33,7 @@ void NonZeroKernel(const Context& dev_ctx, int* true_num = RAII_GUARD.alloc_l3_or_gm(1); int true_num_cpu; int ret = xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel); - PADDLE_ENFORCE_EQ( - ret, - XPU_SUCCESS, - phi::errors::External( - "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex", - ret, - XPUAPIErrorMsg[ret])); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "nonzero_count"); memory_utils::Copy(phi::CPUPlace(), static_cast(&true_num_cpu), @@ -58,17 +51,12 @@ void NonZeroKernel(const Context& dev_ctx, auto condition_shape = phi::vectorize(dims); ret = xpu::where( dev_ctx.x_context(), cond_data, out_data, condition_shape, true_num_cpu); - PADDLE_ENFORCE_EQ(ret, - XPU_SUCCESS, - phi::errors::External( - "XPU masked_select kernel return wrong value[%d %s]", - ret, - XPUAPIErrorMsg[ret])); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "where"); } } // namespace phi PD_REGISTER_KERNEL( - nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) { + nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::INT64); } diff --git a/paddle/phi/kernels/xpu/prod_kernel.cc b/paddle/phi/kernels/xpu/prod_kernel.cc index 12f32959edb31..74e58ee63a7ca 100644 --- a/paddle/phi/kernels/xpu/prod_kernel.cc +++ b/paddle/phi/kernels/xpu/prod_kernel.cc @@ -50,4 +50,5 @@ void ProdKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {} +PD_REGISTER_KERNEL( + prod, XPU, ALL_LAYOUT, phi::ProdKernel, float, int, int64_t) {} diff --git a/paddle/phi/kernels/xpu/quantization_kernel.cc b/paddle/phi/kernels/xpu/quantization_kernel.cc new file mode 100644 index 0000000000000..32b28b034e2da --- /dev/null +++ b/paddle/phi/kernels/xpu/quantization_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void QuantizeKernelImpl(const Context& ctx, + const DenseTensor& x, + float scale, + DenseTensor* y) { + using XPUInX = typename XPUTypeTrait::Type; + using XPUOutY = typename XPUTypeTrait::Type; + + auto* y_data = ctx.template Alloc(y); + const auto* x_data = x.data(); + int64_t len = x.numel(); + int max_ptr_size = ctx.x_context()->max_ptr_size(); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + auto max_data = RAII_GUARD.alloc_l3_or_gm(max_ptr_size); + int r = xpu::constant(ctx.x_context(), max_data, max_ptr_size, scale); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + r = xpu::quantization( + ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + len, + max_data); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "quantization"); +} + +template +void QuantizeKernel(const Context& ctx, + const DenseTensor& x, + DataType out_dtype, + float scale, + DenseTensor* y) { + switch (out_dtype) { + case DataType::INT16: + QuantizeKernelImpl(ctx, x, scale, y); + break; + case DataType::INT8: + QuantizeKernelImpl(ctx, x, scale, y); + break; + default: + PADDLE_THROW(phi::errors::Unavailable( + "Not supported quantize data type from %d -> %d ", + x.dtype(), + out_dtype)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(quantize_xpu, + XPU, + ALL_LAYOUT, + phi::QuantizeKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc new file mode 100644 index 0000000000000..a0fd86fcc3208 --- /dev/null +++ b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void ScatterNdAddGradKernel(const Context &ctx, + const DenseTensor &index, + const DenseTensor &updates, + const DenseTensor &out_grad, + DenseTensor *x_grad, + DenseTensor *updates_grad) { + using XPUT = typename XPUTypeTrait::Type; + int ret = xpu::SUCCESS; + const T *out_grad_data = out_grad.data(); + if (x_grad) { + auto *x_grad_data = ctx.template Alloc(x_grad); + ret = xpu::copy(ctx.x_context(), + reinterpret_cast(out_grad_data), + reinterpret_cast(x_grad_data), + out_grad.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy"); + } + + if (updates_grad) { + auto *updates_grad_data = ctx.template Alloc(updates_grad); + if (updates_grad->numel() == 0) { + return; + } + if (index.numel() == 0) { + auto index_dims = index.dims(); + auto index_dims_size = index_dims.size(); + int64_t end_size = index_dims[index_dims_size - 1]; + PADDLE_ENFORCE_EQ( + end_size, + 0, + errors::InvalidArgument( + "Size of the last dim of the index tensor [%d] should be 0", + end_size)); + auto remain_dims = phi::slice_ddim(index_dims, 0, index_dims_size - 1); + int64_t remain_numel = phi::product(remain_dims); + int64_t updates_grad_numel = updates_grad->numel(); + int64_t out_grad_numel = out_grad.numel(); + PADDLE_ENFORCE_EQ( + remain_numel * out_grad_numel, + updates_grad_numel, + errors::InvalidArgument("out_grad numel[%d] * remain numel[%d] " + "should math updates_grad numel[%d]", + out_grad_numel, + remain_numel, + updates_grad_numel)); + ret = xpu::broadcast(ctx.x_context(), + reinterpret_cast(out_grad_data), + reinterpret_cast(updates_grad_data), + {1, out_grad_numel}, + {remain_numel, out_grad_numel}); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast"); + return; + } + + auto index_shape_vec = vectorize(index.dims()); + if (index_shape_vec.size() == 1) { + index_shape_vec.insert(index_shape_vec.begin(), 1); + } + auto out_grad_shape_vec = vectorize(out_grad.dims()); + xpu::VectorParam out_grad_shape_param = { + out_grad_shape_vec.data(), + static_cast(out_grad_shape_vec.size()), + nullptr}; + + if (index.dtype() == DataType::INT32) { + ret = xpu::gather_nd( + ctx.x_context(), + reinterpret_cast(out_grad_data), + index.data(), + reinterpret_cast(updates_grad_data), + out_grad_shape_param, + index_shape_vec); + } else { + ret = xpu::gather_nd( + ctx.x_context(), + reinterpret_cast(out_grad_data), + index.data(), + reinterpret_cast(updates_grad_data), + out_grad_shape_param, + index_shape_vec); + } + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd"); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(scatter_nd_add_grad, + XPU, + ALL_LAYOUT, + phi::ScatterNdAddGradKernel, + float, + phi::dtype::float16, + int, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc index c760a2d0166c9..69e40994eb92d 100644 --- a/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc +++ b/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc @@ -34,8 +34,11 @@ void ScatterNdAddKernel(const Context &ctx, if (updates.numel() == 0) return; if (index.numel() == 0) { - int loop_time = - static_cast(index.dims().size() == 0 ? 1 : index.dims()[0]); + int64_t index_dims_size = index.dims().size(); + int loop_time = static_cast( + index_dims_size == 0 ? 1 + : phi::product(phi::slice_ddim( + index.dims(), 0, index_dims_size - 1))); for (int i = 0; i < loop_time; i++) { r = xpu::broadcast_add(ctx.x_context(), diff --git a/paddle/phi/kernels/xpu/set_value_kernel.cc b/paddle/phi/kernels/xpu/set_value_kernel.cc index dc154657c729e..a706ef00b9a41 100644 --- a/paddle/phi/kernels/xpu/set_value_kernel.cc +++ b/paddle/phi/kernels/xpu/set_value_kernel.cc @@ -18,6 +18,7 @@ #include #include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/slice_utils.h" @@ -386,20 +387,31 @@ void SetValueKernel(const Context& dev_ctx, const std::vector& shape, const std::vector& values, DenseTensor* out) { - std::vector assign_values; - assign_values.reserve(values.size()); - for (const auto& val : values) { - assign_values.push_back(val.to()); + // avoid using vector if T is bool or phi::dtype::float16 + int value_size = sizeof(T); + int values_size = values.size(); + int values_length = values_size * value_size; + std::vector assign_values(values_length); + uint8_t* value_data_uint8_cpu = assign_values.data(); + for (int i = 0; i < values_size; i++) { + T value = values[i].to(); + memcpy(value_data_uint8_cpu + i * value_size, &value, value_size); } + using XPUType = typename XPUTypeTrait::Type; + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + T* value_data = + reinterpret_cast(RAII_GUARD.alloc_l3_or_gm(values_size)); + memory_utils::Copy(dev_ctx.GetPlace(), + value_data, + phi::CPUPlace(), + value_data_uint8_cpu, + values_length); auto value_dims = phi::make_ddim(shape); - DenseTensor value_tensor; - TensorFromVector(assign_values, dev_ctx, &value_tensor); - SetValueKernelImpl(dev_ctx, x, - value_tensor.data(), + value_data, value_dims, starts, ends, diff --git a/paddle/phi/kernels/xpu/slice_grad_kernel.cc b/paddle/phi/kernels/xpu/slice_grad_kernel.cc index 3e054f3d8f342..ff5a49610fc54 100644 --- a/paddle/phi/kernels/xpu/slice_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/slice_grad_kernel.cc @@ -85,4 +85,5 @@ PD_REGISTER_KERNEL(slice_grad, phi::SliceGradKernel, float, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/slice_kernel.cc b/paddle/phi/kernels/xpu/slice_kernel.cc index a9bdf477d7e13..d3c114db2411b 100644 --- a/paddle/phi/kernels/xpu/slice_kernel.cc +++ b/paddle/phi/kernels/xpu/slice_kernel.cc @@ -120,4 +120,5 @@ PD_REGISTER_KERNEL(slice, float, int, phi::dtype::float16, + phi::dtype::bfloat16, int64_t) {} diff --git a/paddle/phi/kernels/xpu/split_kernel.cc b/paddle/phi/kernels/xpu/split_kernel.cc index 11a20f6f17946..e3aeb7ffdfbe3 100644 --- a/paddle/phi/kernels/xpu/split_kernel.cc +++ b/paddle/phi/kernels/xpu/split_kernel.cc @@ -74,7 +74,8 @@ PD_REGISTER_KERNEL(split, float, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(split_with_num, XPU, ALL_LAYOUT, @@ -82,4 +83,5 @@ PD_REGISTER_KERNEL(split_with_num, float, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc index fbc7a0bf6abcb..709eeaac49546 100644 --- a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc @@ -163,4 +163,5 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, int, int16_t, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/stride_slice_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_kernel.cc index a2de8c2c8ffc1..2f026bae02fe4 100644 --- a/paddle/phi/kernels/xpu/stride_slice_kernel.cc +++ b/paddle/phi/kernels/xpu/stride_slice_kernel.cc @@ -171,4 +171,5 @@ PD_REGISTER_KERNEL(strided_slice_raw, int16_t, int64_t, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc index 043d2c8e3df5a..71b2187bddce1 100644 --- a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc @@ -65,6 +65,7 @@ PD_REGISTER_KERNEL(transpose_grad, phi::TransposeGradKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int, bool) {} diff --git a/paddle/phi/kernels/xpu/transpose_kernel.cc b/paddle/phi/kernels/xpu/transpose_kernel.cc index 398a2281dcea8..dd985ddc7ebc5 100644 --- a/paddle/phi/kernels/xpu/transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_kernel.cc @@ -60,6 +60,7 @@ PD_REGISTER_KERNEL(transpose, phi::TransposeKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int, bool) {} diff --git a/paddle/phi/kernels/xpu/uniform_kernel.cc b/paddle/phi/kernels/xpu/uniform_kernel.cc index 99388e31e5881..ead65b65a8466 100644 --- a/paddle/phi/kernels/xpu/uniform_kernel.cc +++ b/paddle/phi/kernels/xpu/uniform_kernel.cc @@ -14,12 +14,9 @@ limitations under the License. */ #include "paddle/phi/kernels/uniform_kernel.h" -#include - -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/generator.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/uniform_real_distribution.h" namespace phi { @@ -31,49 +28,31 @@ void UniformKernel(const Context &dev_ctx, const Scalar &max, int seed, DenseTensor *out) { - int diag_num = 0; - int diag_step = 0; - float diag_val = 0.0f; out->Resize(phi::make_ddim(shape.GetData())); T *data = dev_ctx.template Alloc(out); - int64_t size = out->numel(); - - std::unique_ptr data_cpu(new T[size]); - - std::shared_ptr engine; - if (seed) { - engine = std::make_shared(); - engine->seed(seed); - } else { - engine = dev_ctx.GetGenerator()->GetCPUEngine(); - } - UniformRealDistribution( - data_cpu.get(), size, min.to(), max.to(), engine); - if (diag_num > 0) { - PADDLE_ENFORCE_GT( - size, - (diag_num - 1) * (diag_step + 1), - phi::errors::InvalidArgument( - "ShapeInvalid: the diagonal's elements is equal (num-1) " - "* (step-1) with num %d, step %d," - "It should be smaller than %d, but received %d", - diag_num, - diag_step, - (diag_num - 1) * (diag_step + 1), - size)); - for (int64_t i = 0; i < diag_num; ++i) { - int64_t pos = i * diag_step + i; - data_cpu[pos] = diag_val; - } + if (out->numel() == 0) { + return; } - memory_utils::Copy(dev_ctx.GetPlace(), - data, - phi::CPUPlace(), - reinterpret_cast(data_cpu.get()), - size * sizeof(T)); + using XPUType = typename XPUTypeTrait::Type; + int64_t real_seed = seed != 0 ? seed : dev_ctx.GetGenerator()->Random64(); + + // int random(Context* ctx, T* x, int64_t len, T min, T max, int64_t seed); + int r = xpu::random(dev_ctx.x_context(), + reinterpret_cast(data), + out->numel(), + static_cast(min.to()), + static_cast(max.to()), + real_seed); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "random"); } } // namespace phi -PD_REGISTER_KERNEL(uniform, XPU, ALL_LAYOUT, phi::UniformKernel, float) {} +PD_REGISTER_KERNEL(uniform, + XPU, + ALL_LAYOUT, + phi::UniformKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/where_kernel.cc b/paddle/phi/kernels/xpu/where_kernel.cc index 8b644d0cf7f88..4c5a7fbf5cc09 100644 --- a/paddle/phi/kernels/xpu/where_kernel.cc +++ b/paddle/phi/kernels/xpu/where_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -57,4 +57,5 @@ PD_REGISTER_KERNEL(where, float, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/pir/CMakeLists.txt b/paddle/pir/CMakeLists.txt index 1f87a16ff36a6..b5850bbbae703 100644 --- a/paddle/pir/CMakeLists.txt +++ b/paddle/pir/CMakeLists.txt @@ -1,44 +1,7 @@ add_definitions(-DIR_LIBRARY) set_property(GLOBAL PROPERTY IR_TARGETS "") -set_property(GLOBAL PROPERTY IR_MODULES "") -function(ir_library TARGET_NAME) - set(options STATIC static SHARED shared INTERFACE interface) - set(oneValueArgs "") - set(multiValueArgs SRCS DEPS) - cmake_parse_arguments(ir_library "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN}) - set(OBJ_LIB ir_${TARGET_NAME}) - add_library(${OBJ_LIB} OBJECT ${ir_library_SRCS}) - if(ir_library_SHARED OR ir_library_shared) # build *.so - cc_library( - ${TARGET_NAME} SHARED - SRCS $ - DEPS ${ir_library_DEPS}) - elseif(ir_library_INTERFACE OR ir_library_interface) - cc_library( - ${TARGET_NAME} INTERFACE - SRCS $ - DEPS ${ir_library_DEPS}) - else() - cc_library( - ${TARGET_NAME} - SRCS $ - DEPS ${ir_library_DEPS}) - set_property(GLOBAL APPEND PROPERTY IR_MODULES $) - - get_property(ir_targets GLOBAL PROPERTY IR_TARGETS) - set(ir_targets ${ir_targets} ${TARGET_NAME}) - set_property(GLOBAL PROPERTY IR_TARGETS "${ir_targets}") - - endif() -endfunction() - -add_subdirectory(core) -add_subdirectory(pass) -add_subdirectory(pattern_rewrite) -add_subdirectory(transforms) -add_subdirectory(dialect) +file(GLOB_RECURSE PIR_CPP_SOURCES "*.cc") if(WIN32) if(WITH_SHARED_IR) @@ -76,9 +39,15 @@ set(IR_LIB "${CMAKE_CURRENT_BINARY_DIR}/${IR_NAME}" CACHE FILEPATH "IR Library" FORCE) -get_property(ir_modules GLOBAL PROPERTY IR_MODULES) if(WITH_SHARED_IR) - add_library(pir SHARED ${ir_modules}) + add_library(pir SHARED ${PIR_CPP_SOURCES}) + target_link_libraries(pir ddim) else() - add_library(pir STATIC ${ir_modules}) + cc_library( + pir + SRCS ${PIR_CPP_SOURCES} + DEPS ddim) + get_property(ir_targets GLOBAL PROPERTY IR_TARGETS) + set(ir_targets pir) + set_property(GLOBAL PROPERTY IR_TARGETS "${ir_targets}") endif() diff --git a/paddle/pir/core/CMakeLists.txt b/paddle/pir/core/CMakeLists.txt deleted file mode 100644 index 0fffc4285e376..0000000000000 --- a/paddle/pir/core/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -set(NEWIR_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/pir") -set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/pir") - -file(GLOB IR_SRCS "*.cc") - -file(GLOB IR_PARSER_SRCS "parser/*.cc") -list(APPEND IR_SRCS ${IR_PARSER_SRCS}) - -ir_library(pir_core SRCS ${IR_SRCS} DEPS ddim) diff --git a/paddle/pir/core/builder.cc b/paddle/pir/core/builder.cc index 6a1608c84ab85..2484e02f5156e 100644 --- a/paddle/pir/core/builder.cc +++ b/paddle/pir/core/builder.cc @@ -73,6 +73,9 @@ DoubleAttribute Builder::double_attr(double value) { Int32Attribute Builder::int32_attr(int32_t value) { return Int32Attribute::get(context_, value); } +IndexAttribute Builder::index_attr(int64_t value) { + return IndexAttribute::get(context_, value); +} Int64Attribute Builder::int64_attr(int64_t value) { return Int64Attribute::get(context_, value); } diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index 72c8494cf8906..ae1887230c666 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -39,6 +39,7 @@ class BoolAttribute; class FloatAttribute; class DoubleAttribute; class Int32Attribute; +class IndexAttribute; class Int64Attribute; class ArrayAttribute; class PointerAttribute; @@ -131,6 +132,7 @@ class Builder { IR_API FloatAttribute float_attr(float value); IR_API DoubleAttribute double_attr(double value); IR_API Int32Attribute int32_attr(int32_t value); + IR_API IndexAttribute index_attr(int64_t value); IR_API Int64Attribute int64_attr(int64_t value); IR_API ArrayAttribute array_attr(const std::vector &value); IR_API PointerAttribute pointer_attr(void *value); diff --git a/paddle/pir/core/builtin_attribute.cc b/paddle/pir/core/builtin_attribute.cc index e14a424c32c8e..0958e24798414 100644 --- a/paddle/pir/core/builtin_attribute.cc +++ b/paddle/pir/core/builtin_attribute.cc @@ -24,6 +24,8 @@ double DoubleAttribute::data() const { return storage()->data(); } int32_t Int32Attribute::data() const { return storage()->data(); } +int64_t IndexAttribute::data() const { return storage()->data(); } + int64_t Int64Attribute::data() const { return storage()->data(); } void* PointerAttribute::data() const { return storage()->data(); } @@ -86,6 +88,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::BoolAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::FloatAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::DoubleAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int32Attribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::IndexAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int64Attribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::PointerAttribute) diff --git a/paddle/pir/core/builtin_attribute.h b/paddle/pir/core/builtin_attribute.h index 7d3f86144915c..b09bff8750c40 100644 --- a/paddle/pir/core/builtin_attribute.h +++ b/paddle/pir/core/builtin_attribute.h @@ -55,6 +55,15 @@ class IR_API Int32Attribute : public Attribute { int32_t data() const; }; +class IR_API IndexAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(IndexAttribute, IndexAttributeStorage); + + int64_t data() const; +}; + class IR_API Int64Attribute : public Attribute { public: using Attribute::Attribute; @@ -123,6 +132,7 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::FloatAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::DoubleAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int32Attribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int64Attribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::IndexAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PointerAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TypeAttribute) diff --git a/paddle/pir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h index fd9dd6eb87128..2ab13326d3ebc 100644 --- a/paddle/pir/core/builtin_attribute_storage.h +++ b/paddle/pir/core/builtin_attribute_storage.h @@ -52,6 +52,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(IndexAttributeStorage, int64_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(TypeAttributeStorage, Type); diff --git a/paddle/pir/core/builtin_dialect.cc b/paddle/pir/core/builtin_dialect.cc index 60575da6d9472..0fef066ec4727 100644 --- a/paddle/pir/core/builtin_dialect.cc +++ b/paddle/pir/core/builtin_dialect.cc @@ -46,6 +46,7 @@ void BuiltinDialect::initialize() { DoubleAttribute, PointerAttribute, Int32Attribute, + IndexAttribute, Int64Attribute, ArrayAttribute, TypeAttribute>(); diff --git a/paddle/pir/core/builtin_op.h b/paddle/pir/core/builtin_op.h index 19ca96b052692..64649f29175e6 100644 --- a/paddle/pir/core/builtin_op.h +++ b/paddle/pir/core/builtin_op.h @@ -204,7 +204,7 @@ class IR_API ConstantOp : public Op { Type output_type); void VerifySig() const; - + OpResult out() { return result(0); } Attribute value() const; }; diff --git a/paddle/pir/core/builtin_type.cc b/paddle/pir/core/builtin_type.cc index 54fbde1f5adf7..fb168a9a051cc 100644 --- a/paddle/pir/core/builtin_type.cc +++ b/paddle/pir/core/builtin_type.cc @@ -19,19 +19,19 @@ std::vector VectorType::data() const { return storage()->GetAsKey(); } pir::Type DenseTensorType::dtype() const { return storage()->dtype_; } -const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const { +const DenseTensorType::Dim& DenseTensorType::dims() const { return storage()->dims_; } -const DenseTensorTypeStorage::DataLayout& DenseTensorType::data_layout() const { +DenseTensorType::DataLayout DenseTensorType::data_layout() const { return storage()->layout_; } -const DenseTensorTypeStorage::LoD& DenseTensorType::lod() const { +const DenseTensorType::LoD& DenseTensorType::lod() const { return storage()->lod_; } -const size_t& DenseTensorType::offset() const { return storage()->offset_; } +size_t DenseTensorType::offset() const { return storage()->offset_; } } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::UInt8Type) diff --git a/paddle/pir/core/builtin_type.h b/paddle/pir/core/builtin_type.h index d43626f2e1546..d151f80d3e79c 100644 --- a/paddle/pir/core/builtin_type.h +++ b/paddle/pir/core/builtin_type.h @@ -58,16 +58,23 @@ class DenseTensorType : public Type::TypeBase { public: using Base::Base; + using Dim = DenseTensorTypeStorage::Dim; + using DataLayout = DenseTensorTypeStorage::DataLayout; + using LoD = DenseTensorTypeStorage::LoD; Type dtype() const; - - const DenseTensorTypeStorage::Dim &dims() const; - - const DenseTensorTypeStorage::DataLayout &data_layout() const; - - const DenseTensorTypeStorage::LoD &lod() const; - - const size_t &offset() const; + const Dim &dims() const; + DataLayout data_layout() const; + const LoD &lod() const; + size_t offset() const; + static DenseTensorType get(IrContext *ctx, + Type dtype, + const Dim &dims, + DataLayout layout = DataLayout::kNCHW, + const LoD &lod = {}, + size_t offset = 0u) { + return Base::get(ctx, dtype, dims, layout, lod, offset); + } }; #define DECLARE_BUILTIN_TYPE(__name) \ diff --git a/paddle/pir/core/builtin_type_interfaces.h b/paddle/pir/core/builtin_type_interfaces.h index f1df893f89e3f..40ad58313a0d3 100644 --- a/paddle/pir/core/builtin_type_interfaces.h +++ b/paddle/pir/core/builtin_type_interfaces.h @@ -40,27 +40,17 @@ class ShapedTypeInterface : public TypeInterfaceBase { template struct Model : public Concept { - static inline DataType getElementType(Type type) { + static inline DataType GetElementType(Type type) { return pir::cast(type).dtype(); } - static inline DDim getShape(Type type) { + static inline DDim GetShape(Type type) { return pir::cast(type).dims(); } - Model() : Concept(getElementType, getShape) {} + Model() : Concept(GetElementType, GetShape) {} }; - /// Constructor - ShapedTypeInterface(std::nullptr_t) // NOLINT - : TypeInterfaceBase(Type()), impl_(nullptr) {} - - explicit ShapedTypeInterface(Type type = Type()) - : TypeInterfaceBase(type), - impl_(type - ? type.abstract_type().GetInterfaceImpl() - : nullptr) {} - ShapedTypeInterface(Type type, Concept *impl) : TypeInterfaceBase(type), impl_(impl) {} diff --git a/paddle/pir/core/builtin_type_storage.h b/paddle/pir/core/builtin_type_storage.h index b8b18d09ddd26..10063963df633 100644 --- a/paddle/pir/core/builtin_type_storage.h +++ b/paddle/pir/core/builtin_type_storage.h @@ -53,11 +53,11 @@ struct DenseTensorTypeStorage : public pir::TypeStorage { using DataLayout = phi::DataLayout; using Dim = phi::DDim; using LoD = std::vector>; - using ParamKey = std::tuple; + using ParamKey = std::tuple; - DenseTensorTypeStorage(const pir::Type& dtype, + DenseTensorTypeStorage(Type dtype, const Dim& dims, - const DataLayout& layout, + DataLayout layout, const LoD& lod, size_t offset) : dtype_(dtype), diff --git a/paddle/pir/core/enforce.h b/paddle/pir/core/enforce.h index a3b1401b64d25..e8624b8bbe4e1 100644 --- a/paddle/pir/core/enforce.h +++ b/paddle/pir/core/enforce.h @@ -19,6 +19,13 @@ #include "paddle/utils/string/printf.h" +#if defined(_WIN32) +#define UNUSED +#define __builtin_expect(EXP, C) (EXP) +#else +#define UNUSED __attribute__((unused)) +#endif + #if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #else diff --git a/paddle/pir/core/infer_type_op_interface.cc b/paddle/pir/core/infer_type_op_interface.cc new file mode 100644 index 0000000000000..b238daca2045f --- /dev/null +++ b/paddle/pir/core/infer_type_op_interface.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/core/infer_type_op_interface.h" + +namespace pir { + +bool InferShapedTypeOpInterface::ReifyReturnTypeShapes( + Builder& builder, + std::vector operands, + std::vector& reified_return_shapes) { + return impl_->reify_return_type_shapes( + builder, operands, reified_return_shapes); +} +} // namespace pir + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::InferShapedTypeOpInterface) diff --git a/paddle/pir/core/infer_type_op_interface.h b/paddle/pir/core/infer_type_op_interface.h new file mode 100644 index 0000000000000..6acef20c02340 --- /dev/null +++ b/paddle/pir/core/infer_type_op_interface.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/op_base.h" + +// Type inference is currently modelled executionally for operation creation +// using the `InferMetaInterface`. While `InferShapedTypeOpInterface` is used to +// implement the shape and element type inference. The return type can often be +// deduced from the deduced return shape and elemental type (queryable from +// `InferShapedTypeOpInterface`) and so type inference for tensor types can be +// implemented with `InferShapedTypeOpInterface`. + +namespace pir { + +class InferShapedTypeOpInterface + : public pir::OpInterfaceBase { + public: + /// Defined these methods with the interface. + struct Concept { + explicit Concept(bool (*reify_return_type_shapes)( + Builder& builder, // NOLINT + std::vector operands, // NOLINT + std::vector& reified_return_shapes)) // NOLINT + : reify_return_type_shapes(reify_return_type_shapes) {} + bool (*reify_return_type_shapes)( + Builder& builder, + std::vector operands, + std::vector& reified_return_shapes); // NOLINT + }; + + template + struct Model : public Concept { + static inline bool ReifyReturnTypeShapes( + Builder& builder, // NOLINT + std::vector operands, // NOLINT + std::vector& reified_return_shapes) { // NOLINT + return ConcreteOp::ReifyReturnTypeShapes( + builder, operands, reified_return_shapes); + } + + Model() : Concept(ReifyReturnTypeShapes) {} + }; + + /// Constructor + InferShapedTypeOpInterface(Operation* op, Concept* impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + bool ReifyReturnTypeShapes( + Builder& builder, // NOLINT + std::vector operands, // NOLINT + std::vector& reified_return_shapes); // NOLINT + + private: + Concept* impl_; +}; + +} // namespace pir + +IR_DECLARE_EXPLICIT_TYPE_ID(pir::InferShapedTypeOpInterface) diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index 81cb3b4bcf224..37c74111d00e1 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -110,6 +110,8 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { os << "(Int32)" << i.data(); } else if (auto i = attr.dyn_cast()) { os << "(Int64)" << i.data(); + } else if (auto i = attr.dyn_cast()) { + os << "(Index)" << i.data(); } else if (auto p = attr.dyn_cast()) { os << "(Pointer)" << p.data(); } else if (auto arr = attr.dyn_cast()) { @@ -198,7 +200,7 @@ void IrPrinter::PrintValue(Value v) { os << "<>"; return; } - const void* key = static_cast(v.impl()); + const void* key = v.impl(); auto ret = aliases_.find(key); if (ret != aliases_.end()) { os << ret->second; @@ -308,6 +310,11 @@ void IrPrinter::PrintOpReturnType(Operation* op) { [this]() { this->os << ", "; }); } +void IrPrinter::AddValueAlias(Value v, const std::string& alias) { + const void* key = v.impl(); + IR_ENFORCE(aliases_.find(key) == aliases_.end(), "Value already has alias"); + aliases_[key] = alias; +} void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { printer.PrintGeneralOperation(op); } diff --git a/paddle/pir/core/ir_printer.h b/paddle/pir/core/ir_printer.h index e4d821c01911b..cb7135fb484de 100644 --- a/paddle/pir/core/ir_printer.h +++ b/paddle/pir/core/ir_printer.h @@ -70,6 +70,8 @@ class IR_API IrPrinter : public BasicIrPrinter { void PrintOpReturnType(Operation* op); + void AddValueAlias(Value value, const std::string& alias); + private: size_t cur_result_number_{0}; size_t cur_block_argument_number_{0}; diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index f0710ff5ec629..217a34a631536 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -141,6 +141,7 @@ class Op : public OpBase { using InterfaceList = typename Filter>::Type; + // TODO(zhangbopd): Use classof static ConcreteOp dyn_cast(Operation *op) { if (op && op->info().id() == TypeId::get()) { return ConcreteOp(op); diff --git a/paddle/pir/core/op_result.cc b/paddle/pir/core/op_result.cc index d14a3c830c8d2..8249872593652 100644 --- a/paddle/pir/core/op_result.cc +++ b/paddle/pir/core/op_result.cc @@ -26,8 +26,7 @@ bool OpResult::classof(Value value) { } Operation *OpResult::owner() const { - CHECK_OPRESULT_NULL_IMPL(owner); - return IMPL_->owner(); + return impl_ ? static_cast(impl_)->owner() : nullptr; } uint32_t OpResult::index() const { diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h index 8860473fe3339..5ca9164a04a23 100644 --- a/paddle/pir/core/op_result.h +++ b/paddle/pir/core/op_result.h @@ -30,6 +30,7 @@ class IR_API OpResult : public Value { public: OpResult(std::nullptr_t ptr = nullptr) : Value(ptr){}; // NOLINT Operation *owner() const; + // Return the result index of this op result. uint32_t index() const; bool operator==(const OpResult &other) const; @@ -38,8 +39,18 @@ class IR_API OpResult : public Value { OpResult(detail::OpResultImpl *impl); // NOLINT // Access classof annd dyn_cast_from. friend Value; + friend struct std::hash; static bool classof(Value value); static OpResult dyn_cast_from(Value value); }; } // namespace pir + +namespace std { +template <> +struct hash { + std::size_t operator()(const pir::OpResult &obj) const { + return std::hash()(obj); + } +}; +} // namespace std diff --git a/paddle/pir/core/op_trait.cc b/paddle/pir/core/op_trait.cc index ccea4e3f06d9b..94d800e2944f2 100644 --- a/paddle/pir/core/op_trait.cc +++ b/paddle/pir/core/op_trait.cc @@ -16,9 +16,9 @@ #include "paddle/pir/core/enforce.h" #include "paddle/pir/core/type_util.h" -namespace pir::detail { +namespace { -void VerifySameOperandsShapeTrait(Operation *op) { +void VerifySameOperandsShapeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsShapeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -39,7 +39,7 @@ void VerifySameOperandsShapeTrait(Operation *op) { op->name()); } -void VerifySameOperandsAndResultShapeTrait(Operation *op) { +void VerifySameOperandsAndResultShapeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsAndResultShapeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -73,7 +73,7 @@ void VerifySameOperandsAndResultShapeTrait(Operation *op) { op->name()); } -void VerifySameOperandsElementTypeTrait(Operation *op) { +void VerifySameOperandsElementTypeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsElementTypeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -91,7 +91,7 @@ void VerifySameOperandsElementTypeTrait(Operation *op) { } } -void VerifySameOperandsAndResultElementTypeTrait(Operation *op) { +void VerifySameOperandsAndResultElementTypeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsAndResultElementTypeTrait for : " << op->name(); @@ -126,7 +126,7 @@ void VerifySameOperandsAndResultElementTypeTrait(Operation *op) { } } -void VerifySameOperandsAndResultTypeTrait(Operation *op) { +void VerifySameOperandsAndResultTypeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsAndResultTypeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -169,7 +169,7 @@ void VerifySameOperandsAndResultTypeTrait(Operation *op) { } } -void VerifySameTypeOperandsTrait(Operation *op) { +void VerifySameTypeOperandsTrait(pir::Operation *op) { VLOG(4) << "Verify SameTypeOperandsTrait for : " << op->name(); // For zero or only one operand. @@ -186,7 +186,40 @@ void VerifySameTypeOperandsTrait(Operation *op) { } } -} // namespace pir::detail +void VerifyOneResultTrait(pir::Operation *op) { + IR_ENFORCE(op->num_results() == 1, + "Op %s with OneResultTrait requires 1 result, but got %u results.", + op->name(), + op->num_results()); +} +} // namespace + +namespace pir { +void SameOperandsShapeTrait::Verify(Operation *op) { + return VerifySameOperandsShapeTrait(op); +} + +void SameOperandsAndResultShapeTrait::Verify(Operation *op) { + return VerifySameOperandsAndResultShapeTrait(op); +} + +void SameOperandsElementTypeTrait::Verify(Operation *op) { + return VerifySameOperandsElementTypeTrait(op); +} + +void SameOperandsAndResultElementTypeTrait::Verify(Operation *op) { + return VerifySameOperandsAndResultElementTypeTrait(op); +} + +void SameOperandsAndResultTypeTrait::Verify(Operation *op) { + return VerifySameOperandsAndResultTypeTrait(op); +} +void SameTypeOperandsTrait::Verify(Operation *op) { + return VerifySameTypeOperandsTrait(op); +} + +void OneResultTrait::Verify(Operation *op) { return VerifyOneResultTrait(op); } +} // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultShapeTrait) @@ -194,3 +227,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::OneResultTrait) diff --git a/paddle/pir/core/op_trait.h b/paddle/pir/core/op_trait.h index 760799fd16165..b55352c164765 100644 --- a/paddle/pir/core/op_trait.h +++ b/paddle/pir/core/op_trait.h @@ -18,15 +18,6 @@ namespace pir { -namespace detail { -void VerifySameOperandsShapeTrait(Operation *op); -void VerifySameOperandsAndResultShapeTrait(Operation *op); -void VerifySameOperandsElementTypeTrait(Operation *op); -void VerifySameOperandsAndResultElementTypeTrait(Operation *op); -void VerifySameOperandsAndResultTypeTrait(Operation *op); -void VerifySameTypeOperandsTrait(Operation *op); -} // namespace detail - /// /// \brief Provides verification for ops that are known to have the /// same operand shape. @@ -35,9 +26,7 @@ class SameOperandsShapeTrait : public pir::OpTraitBase { public: explicit SameOperandsShapeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsShapeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -49,9 +38,7 @@ class SameOperandsAndResultShapeTrait public: explicit SameOperandsAndResultShapeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsAndResultShapeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -63,9 +50,7 @@ class SameOperandsElementTypeTrait public: explicit SameOperandsElementTypeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsElementTypeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -77,9 +62,7 @@ class SameOperandsAndResultElementTypeTrait public: explicit SameOperandsAndResultElementTypeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsAndResultElementTypeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -93,9 +76,7 @@ class SameOperandsAndResultTypeTrait explicit SameOperandsAndResultTypeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsAndResultTypeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -106,9 +87,26 @@ class SameTypeOperandsTrait : public pir::OpTraitBase { public: explicit SameTypeOperandsTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameTypeOperandsTrait(op); + static void Verify(Operation *op); +}; + +/// +/// \brief This trait provides return value APIs for ops that are known to have +/// a single result returned by GetType(). +/// +class OneResultTrait : public OpTraitBase { + public: + // Replace all uses of 'this' value with the new value, updating anything + // in the IR that uses 'this' to use the other value instead. + void ReplaceAllUsesWith(Value new_value) { + this->operation()->result(0).ReplaceAllUsesWith(new_value); + } + + // Replace all uses of 'this' value with the result of 'op'. + void ReplaceAllUsesWith(Operation *op) { + this->operation()->ReplaceAllUsesWith(op->result(0)); } + static void Verify(Operation *op); }; } // namespace pir @@ -119,3 +117,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait) IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait) IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait) IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::OneResultTrait) diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 0dedeafc9ae71..1a6666fcc2a9b 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -123,7 +123,12 @@ Operation *Operation::Create(const std::vector &inputs, // 0. Verify if (op_info) { - op_info.VerifySig(op); + try { + op_info.VerifySig(op); + } catch (const pir::IrNotMetException &e) { + op->Destroy(); + throw e; + } } return op; } @@ -292,6 +297,21 @@ std::string Operation::name() const { auto p_name = info_.name(); return p_name ? p_name : ""; } + +void Operation::Erase() { + if (auto *parent = GetParent()) + parent->erase(*this); + else + Destroy(); +} + +bool Operation::use_empty() { + auto res = results(); + return std::all_of(res.begin(), res.end(), [](OpResult result) { + return result.use_empty(); + }); +} + void Operation::ReplaceAllUsesWith(const std::vector &values) { IR_ENFORCE(num_results_ == values.size(), "the num of result should be the same."); diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index d45fc368d2804..a41e648e7e279 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -83,6 +83,7 @@ class IR_API alignas(8) Operation final { /// uint32_t num_results() const { return num_results_; } OpResult result(uint32_t index) { return op_result_impl(index); } + Type result_type(uint32_t index) { return result(index).type(); } std::vector results(); /// @@ -125,6 +126,16 @@ class IR_API alignas(8) Operation final { pir::OpInfo info() const { return info_; } std::string name() const; + /// + /// \brief Remove this operation from its parent block and delete it. + /// + void Erase(); + + /// + /// \brief Returns true if this operation has no uses. + /// + bool use_empty(); + template T dyn_cast() { return CastUtil::call(this); diff --git a/paddle/pir/core/operation_utils.h b/paddle/pir/core/operation_utils.h index d46890569ead3..77a64a358365d 100644 --- a/paddle/pir/core/operation_utils.h +++ b/paddle/pir/core/operation_utils.h @@ -75,6 +75,9 @@ struct OperationArgument { template void AddOutputs(InputIt first, InputIt last); + void AddOutputs(std::initializer_list type_list) { + AddOutputs(std::begin(type_list), std::end(type_list)); + } template void AddOutputs(const TypeContainer& type_container) { AddOutputs(std::begin(type_container), std::end(type_container)); diff --git a/paddle/pir/core/type.cc b/paddle/pir/core/type.cc index 91933019fb835..a200a07325bc0 100644 --- a/paddle/pir/core/type.cc +++ b/paddle/pir/core/type.cc @@ -31,4 +31,6 @@ bool Type::IsIntOrIndex() const { isa() || isa() || isa(); } +bool Type::IsIndex() const { return isa(); } + } // namespace pir diff --git a/paddle/pir/core/type.h b/paddle/pir/core/type.h index c1b2f155e8d5a..b48da12c12b31 100644 --- a/paddle/pir/core/type.h +++ b/paddle/pir/core/type.h @@ -120,6 +120,7 @@ class IR_API Type { /// type. /// bool IsIntOrIndex() const; + bool IsIndex() const; protected: const Storage *storage_{nullptr}; diff --git a/paddle/pir/core/type_base.cc b/paddle/pir/core/type_base.cc index aec0d93d9fa69..3676d4099be81 100644 --- a/paddle/pir/core/type_base.cc +++ b/paddle/pir/core/type_base.cc @@ -30,7 +30,7 @@ void *AbstractType::GetInterfaceImpl(TypeId interface_id) const { VLOG(6) << "Find no interface!"; return nullptr; } - // TODO(zhangbo63): Add LookUp method like: + // TODO(zhangbopd): Add LookUp method like: // return ir::detail::LookUp( // interface_id, num_interfaces_, num_traits_, this); } diff --git a/paddle/pir/core/value.cc b/paddle/pir/core/value.cc index 13b0b4a5cfee8..cb694eaa7be8f 100644 --- a/paddle/pir/core/value.cc +++ b/paddle/pir/core/value.cc @@ -46,6 +46,8 @@ Value::operator bool() const { return impl_; } pir::Type Value::type() const { return impl_ ? impl_->type() : nullptr; } +Operation *Value::defining_op() const { return dyn_cast().owner(); } + void Value::set_type(pir::Type type) { CHECK_VALUE_NULL_IMPL(set_type); impl_->set_type(type); diff --git a/paddle/pir/core/value.h b/paddle/pir/core/value.h index 96787b973b81a..50d8265f29884 100644 --- a/paddle/pir/core/value.h +++ b/paddle/pir/core/value.h @@ -59,6 +59,16 @@ class IR_API Value { Type type() const; + /// If this value is the result of an operation, return the operation that + /// defines it, else return nullptr; + Operation *defining_op() const; + + template + OpTy defining_op() const { + /// It is safety even if defining_op() return nullptr. + return OpTy::dyn_cast(defining_op()); + } + void set_type(Type type); std::string PrintUdChain(); @@ -93,7 +103,6 @@ class IR_API Value { protected: detail::ValueImpl *impl_{nullptr}; }; - } // namespace pir namespace std { diff --git a/paddle/pir/dialect/CMakeLists.txt b/paddle/pir/dialect/CMakeLists.txt deleted file mode 100644 index 064d328fc53d6..0000000000000 --- a/paddle/pir/dialect/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(control_flow) -add_subdirectory(shape) diff --git a/paddle/pir/dialect/control_flow/CMakeLists.txt b/paddle/pir/dialect/control_flow/CMakeLists.txt deleted file mode 100644 index b30eb7fa567d7..0000000000000 --- a/paddle/pir/dialect/control_flow/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -file(GLOB_RECURSE CONTROL_FLOW_SRCS "*.cc") -ir_library(pir_control_flow SRCS ${CONTROL_FLOW_SRCS} DEPS pir_core) diff --git a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc index ed36c0c81cca6..b10df41168a27 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc @@ -12,9 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/dialect/control_flow/ir/cf_type.h" namespace pir { -void ControlFlowDialect::initialize() { RegisterOps(); } +void ControlFlowDialect::initialize() { + RegisterTypes(); + RegisterOps(); +} + +void ControlFlowDialect::PrintType(pir::Type type, std::ostream &os) const { + os << name(); + os << '.'; + if (type.isa()) { + os << "stack"; + } else if (type.isa()) { + os << "inlet"; + } else if (type.isa()) { + os << "outlet"; + } else { + os << "unknown type"; + } +} + +void ControlFlowDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + if (auto create_op = op->dyn_cast()) { + create_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect) diff --git a/paddle/pir/dialect/control_flow/ir/cf_dialect.h b/paddle/pir/dialect/control_flow/ir/cf_dialect.h index c195ba9638984..a319bd888a65f 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_dialect.h +++ b/paddle/pir/dialect/control_flow/ir/cf_dialect.h @@ -24,7 +24,9 @@ class ControlFlowDialect : public Dialect { initialize(); } static const char *name() { return "cf"; } - + void PrintType(pir::Type type, std::ostream &os) const override; + void PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const override; // NOLINT private: void initialize(); }; diff --git a/paddle/pir/dialect/control_flow/ir/cf_op.cc b/paddle/pir/dialect/control_flow/ir/cf_op.cc new file mode 100644 index 0000000000000..621ca0e775b84 --- /dev/null +++ b/paddle/pir/dialect/control_flow/ir/cf_op.cc @@ -0,0 +1,194 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/dialect/control_flow/ir/cf_type.h" + +namespace pir { + +void YieldOp::Build(Builder &builder, + OperationArgument &argument, + const std::vector &inputs) { + argument.AddInputs(inputs); +} + +void CreateStackOp::Build(Builder &builder, OperationArgument &argument) { + auto stack_type = StackType::get(builder.ir_context()); + auto inlet_type = InletType::get(builder.ir_context()); + auto outlet_type = OutletType::get(builder.ir_context()); + argument.AddOutputs({stack_type, inlet_type, outlet_type}); +} +void CreateStackOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: CreateStackOp."; + // Verify inputs: + IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); + + // No attributes should be verify. + + // Verify outputs: + IR_ENFORCE(num_results() == 3u, "The size of outputs must be equal to 3."); + + IR_ENFORCE(result(0).type().isa(), + "The first outputs of cf.create_stack must be stack_type."); + IR_ENFORCE(result(1).type().isa(), + "The first outputs of cf.create_stack must be inlet_type."); + IR_ENFORCE(result(2).type().isa(), + "The first outputs of cf.create_stack must be outlet_type."); + + VLOG(4) << "End Verifying for CreateStackOp."; +} +size_t CreateStackOp::stack_size() { return push_op().stack_size(); } +Value CreateStackOp::inlet_element(size_t index) { + return push_op().inlet_element(index); +} +Value CreateStackOp::outlet_element(size_t index) { + return pop_op().outlet_element(index); +} +PushBackOp CreateStackOp::push_op() { + auto inlet_value = inlet(); + IR_ENFORCE(inlet_value.HasOneUse(), "The inlet value must has one use."); + return inlet_value.first_use().owner()->dyn_cast(); +} +PopBackOp CreateStackOp::pop_op() { + auto outlet_value = outlet(); + IR_ENFORCE(outlet_value.HasOneUse(), "The outlet value must has one use."); + return outlet_value.first_use().owner()->dyn_cast(); +} + +void CreateStackOp::Print(IrPrinter &printer) { // NOLINT + static std::unordered_map> + kConunters; + auto &counter = kConunters[&printer]; + auto iter = counter.insert({*this, counter.size()}); + auto index = iter.first->second; + if (iter.second) { + printer.AddValueAlias(stack(), "%stack_" + std::to_string(index)); + printer.AddValueAlias(inlet(), "%inlet_" + std::to_string(index)); + printer.AddValueAlias(outlet(), "%outlet_" + std::to_string(index)); + } + printer.PrintGeneralOperation(*this); +} + +void PushBackOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value inlet, + const std::vector &elements) { + argument.AddInput(inlet); + argument.AddInputs(elements); +} + +void PushBackOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value inlet, + std::initializer_list element_list) { + argument.AddInput(inlet); + argument.AddInputs(element_list); +} + +void PushBackOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs ,attributes for: PushBackOp."; + // Verify inputs: + IR_ENFORCE(num_operands() >= 2u, "The size of inputs must no less than 2."); + IR_ENFORCE(operand_source(0).type().isa(), + "The first input of cf.push_back must be inlet_type."); + + // No attributes should be verify. + + // Verify outputs: + IR_ENFORCE(num_results() == 0u, "The size of outputs must be equal to 0."); + VLOG(4) << "End Verifying for PushBackOp."; +} + +size_t PushBackOp::stack_size() { + auto operands_size = num_operands(); + IR_ENFORCE(operands_size >= 2u, + "The operands of push op must no less than 2."); + return operands_size - 1u; +} + +PopBackOp PushBackOp::pop_op() { return create_op().pop_op(); } +void PopBackOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value outlet) { + argument.AddInput(outlet); + + auto push_back_op = outlet.defining_op().push_op(); + + auto elements_size = push_back_op.stack_size(); + + for (size_t index = 0; index < elements_size; ++index) { + argument.AddOutput(push_back_op.inlet_element(index).type()); + } +} + +void PopBackOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs ,attributes and stack validity for: " + "PopBackOp."; + // Verify inputs: + IR_ENFORCE(num_operands() == 1u, "The size of inputs must equal to 1."); + IR_ENFORCE(operand_source(0).type().isa(), + "The first input of cf.pop_back must be outlet_type."); + + // No attributes should be verify. + + // Verify outputs: + IR_ENFORCE(num_results() >= 1u, + "The size of outputs must no less than to 1."); + // Verify stack validity: + auto pop_back_op = create_op().pop_op(); + IR_ENFORCE(*this == pop_back_op, + "The pop_op of stack_op must be this pop_op self."); + + auto inlet_size = push_op().stack_size(); + IR_ENFORCE(inlet_size == stack_size(), + "The pop elements size must equal to push elements size."); + for (size_t index = 0; index < inlet_size; ++index) { + IR_ENFORCE(outlet_element(index).type() == inlet_element(index).type(), + "The %d element's push type isn't equal to pop type", + index); + } + VLOG(4) << "End Verifying for PopBackOp."; +} + +void HasElementsOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value stack) { + argument.AddInput(stack); + argument.AddOutput(builder.bool_type()); +} +void HasElementsOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs ,attributes for: HasElementsOp."; + // Verify inputs: + IR_ENFORCE(num_operands() == 1u, "The size of inputs must equal to 1."); + IR_ENFORCE(operand_source(0).type().isa(), + "The first input of cf.has_elements must be stack_type."); + + // No attributes should be verify. + + // Verify outputs: + IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); + IR_ENFORCE((*this)->result_type(0) == BoolType::get(ir_context()), + "The type of cf.has_elements' output is not correct."); +} + +} // namespace pir + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::CreateStackOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::PushBackOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::PopBackOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::HasElementsOp) diff --git a/paddle/pir/dialect/control_flow/ir/cf_op.h b/paddle/pir/dialect/control_flow/ir/cf_op.h new file mode 100644 index 0000000000000..b85aa14181845 --- /dev/null +++ b/paddle/pir/dialect/control_flow/ir/cf_op.h @@ -0,0 +1,131 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/op_base.h" + +namespace pir { +class IR_API YieldOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.yield"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::vector &Value); + void VerifySig() {} +}; +class PushBackOp; +class PopBackOp; +class IR_API CreateStackOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.create_stack"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(Builder &builder, // NOLINT + OperationArgument &argument); // NOLINT + void VerifySig(); + + Value stack() { return result(0); } + Value inlet() { return result(1); } + Value outlet() { return result(2); } + std::tuple out() { return {stack(), inlet(), outlet()}; } + + size_t stack_size(); + Value inlet_element(size_t index); + Value outlet_element(size_t index); + PushBackOp push_op(); + PopBackOp pop_op(); + + void Print(pir::IrPrinter &printer); // NOLINT +}; + +class IR_API PushBackOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.push_back"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value inlet, + const std::vector &elements); + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value inlet, + std::initializer_list element_list); + void VerifySig(); + + Value stack() { return create_op().stack(); } + Value inlet() { return operand_source(0); } + Value outlet() { return create_op().outlet(); } + size_t stack_size(); + Value inlet_element(size_t index) { return operand_source(index + 1u); } + Value outlet_element(size_t index) { + return create_op().outlet_element(index); + } + CreateStackOp create_op() { return inlet().defining_op(); } + PopBackOp pop_op(); +}; + +class IR_API PopBackOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.pop_back"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value outlet); + void VerifySig(); + + Value stack() { return create_op().stack(); } + Value inlet() { return create_op().inlet(); } + Value outlet() { return operand_source(0); } + + size_t stack_size() { return num_results(); } + Value inlet_element(size_t index) { return push_op().inlet_element(index); } + Value outlet_element(size_t index) { return result(index); } + CreateStackOp create_op() { return outlet().defining_op(); } + PushBackOp push_op() { return create_op().push_op(); } +}; + +class IR_API HasElementsOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.has_elements"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value stack); + void VerifySig(); + Value out() { return result(0); } +}; + +} // namespace pir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CreateStackOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PushBackOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PopBackOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::HasElementsOp); diff --git a/paddle/pir/dialect/control_flow/ir/cf_type.cc b/paddle/pir/dialect/control_flow/ir/cf_type.cc new file mode 100644 index 0000000000000..19ec9af3864e3 --- /dev/null +++ b/paddle/pir/dialect/control_flow/ir/cf_type.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/dialect/control_flow/ir/cf_type.h" + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::StackType) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::InletType) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::OutletType) diff --git a/paddle/pir/dialect/control_flow/ir/cf_type.h b/paddle/pir/dialect/control_flow/ir/cf_type.h new file mode 100644 index 0000000000000..6a954490e959c --- /dev/null +++ b/paddle/pir/dialect/control_flow/ir/cf_type.h @@ -0,0 +1,41 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/type_base.h" + +namespace pir { +class IR_API StackType : public Type::TypeBase { + public: + using Base::Base; +}; + +class IR_API InletType : public Type::TypeBase { + public: + using Base::Base; +}; + +class IR_API OutletType : public Type::TypeBase { + public: + using Base::Base; +}; + +} // namespace pir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::StackType) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::InletType) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::OutletType) diff --git a/paddle/pir/dialect/shape/CMakeLists.txt b/paddle/pir/dialect/shape/CMakeLists.txt deleted file mode 100644 index 0798e78f2b15a..0000000000000 --- a/paddle/pir/dialect/shape/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -file(GLOB_RECURSE SHAPE_SRCS "*.cc") -ir_library(pir_shape SRCS ${SHAPE_SRCS} DEPS pir_core) diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 4367670156efc..0353a7610d2b3 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -15,20 +15,24 @@ #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" -namespace pir { -namespace dialect { +namespace pir::shape { ShapeDialect::ShapeDialect(IrContext *context) : Dialect(name(), context, TypeId::get()) { initialize(); } void ShapeDialect::initialize() { - RegisterOps(); + TensorDimOp, + ShapeOfOp, + FromElementsOp, + ExtractOp, + ConstantOp, + IndexCastOp>(); } void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { @@ -39,7 +43,6 @@ void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { } } -} // namespace dialect -} // namespace pir +} // namespace pir::shape -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::ShapeDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ShapeDialect) diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.h b/paddle/pir/dialect/shape/ir/shape_dialect.h index b8fe39bd8d500..4be71aa0127ce 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.h +++ b/paddle/pir/dialect/shape/ir/shape_dialect.h @@ -16,8 +16,7 @@ #include "paddle/pir/core/dialect.h" -namespace pir { -namespace dialect { +namespace pir::shape { /// /// \brief Shape Dialect: /// @@ -32,7 +31,6 @@ class IR_API ShapeDialect : public Dialect { void initialize(); }; -} // namespace dialect -} // namespace pir +} // namespace pir::shape -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::ShapeDialect) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ShapeDialect) diff --git a/paddle/pir/dialect/shape/ir/shape_op.cc b/paddle/pir/dialect/shape/ir/shape_op.cc index 885f50d080143..d7acec75c0897 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.cc +++ b/paddle/pir/dialect/shape/ir/shape_op.cc @@ -18,9 +18,9 @@ #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/enforce.h" -namespace pir::dialect { +namespace pir::shape { -const char *SymbolicDim::attributes_name[attributes_num] = { +const char *SymbolicDimOp::attributes_name[attributes_num] = { "known_negative_one", // value = -1 "known_non_negative", // value >= 0 "known_non_size_one", // value != 1 @@ -28,14 +28,14 @@ const char *SymbolicDim::attributes_name[attributes_num] = { "sym_name", "value"}; // NOLINT -void SymbolicDim::Build(Builder &builder, - OperationArgument &argument, - const std::string &sym_name, - int64_t value, - bool known_non_negative, - bool known_negative_one, - bool known_non_size_one, - bool known_non_size_zero) { +void SymbolicDimOp::Build(Builder &builder, + OperationArgument &argument, + const std::string &sym_name, + int64_t value, + bool known_non_negative, + bool known_negative_one, + bool known_non_size_one, + bool known_non_size_zero) { IrContext *ctx = IrContext::Instance(); auto attr_sym_name = StrAttribute::get(ctx, sym_name); auto attr_value = Int64Attribute::get(ctx, value); @@ -52,57 +52,66 @@ void SymbolicDim::Build(Builder &builder, argument.AddAttribute("known_non_size_zero", attr_known_non_size_zero); } -const std::string SymbolicDim::GetSymName() { +const std::string SymbolicDimOp::GetSymName() { return attribute("sym_name").AsString(); } -int64_t SymbolicDim::GetDimSize() { + +int64_t SymbolicDimOp::GetDimSize() { return attribute("value").data(); } -bool SymbolicDim::GetKnownNonNegative() { + +bool SymbolicDimOp::GetKnownNonNegative() { return attribute("known_non_negative").data(); } -bool SymbolicDim::GetKnownNegativeOne() { + +bool SymbolicDimOp::GetKnownNegativeOne() { return attribute("known_negative_one").data(); } -bool SymbolicDim::GetKnownNonSizeOne() { + +bool SymbolicDimOp::GetKnownNonSizeOne() { return attribute("known_non_size_one").data(); } -bool SymbolicDim::GetKnownNonSizeZero() { + +bool SymbolicDimOp::GetKnownNonSizeZero() { return attribute("known_non_size_zero").data(); } -void SymbolicDim::SetSymName(const std::string &attr_value) { +void SymbolicDimOp::SetSymName(const std::string &attr_value) { operation()->set_attribute( "sym_name", StrAttribute::get(IrContext::Instance(), attr_value)); } -void SymbolicDim::SetDimSize(int64_t attr_value) { + +void SymbolicDimOp::SetDimSize(int64_t attr_value) { operation()->set_attribute( "value", Int64Attribute::get(IrContext::Instance(), attr_value)); } -void SymbolicDim::UpdateKnownNonNegative(bool flag) { +void SymbolicDimOp::UpdateKnownNonNegative(bool flag) { operation()->set_attribute("known_non_negative", BoolAttribute::get(IrContext::Instance(), flag)); } -void SymbolicDim::UpdateKnownNegativeOne(bool flag) { + +void SymbolicDimOp::UpdateKnownNegativeOne(bool flag) { operation()->set_attribute("known_negative_one", BoolAttribute::get(IrContext::Instance(), flag)); } -void SymbolicDim::UpdateKnownNonSizeOne(bool flag) { + +void SymbolicDimOp::UpdateKnownNonSizeOne(bool flag) { operation()->set_attribute("known_non_size_one", BoolAttribute::get(IrContext::Instance(), flag)); } -void SymbolicDim::UpdateKnownNonSizeZero(bool flag) { + +void SymbolicDimOp::UpdateKnownNonSizeZero(bool flag) { operation()->set_attribute("known_non_size_zero", BoolAttribute::get(IrContext::Instance(), flag)); } -bool SymbolicDim::IsDynamic() { +bool SymbolicDimOp::IsDynamic() { return GetDimSize() == ShapedTypeInterface::kDynamic; } -bool SymbolicDim::Merge(SymbolicDim other) { - VLOG(4) << "Try to merge two SymbolicDim ops."; +bool SymbolicDimOp::Merge(SymbolicDimOp other) { + VLOG(4) << "Try to merge two SymbolicDimOp."; if (!IsDynamic() && !other.IsDynamic() && GetDimSize() != other.GetDimSize()) return false; @@ -145,11 +154,11 @@ void DimOp::Build(Builder &builder, argument.output_types.emplace_back(IndexType::get(IrContext::Instance())); } -const std::string DimOp::getName() { +const std::string DimOp::GetName() { return attribute("name").AsString(); } -void DimOp::setName(std::string attrName) { +void DimOp::SetName(std::string attrName) { operation()->set_attribute( "name", StrAttribute::get(IrContext::Instance(), attrName)); } @@ -192,6 +201,7 @@ std::vector TieProductEqualOp::lhs() { } return res; } + std::vector TieProductEqualOp::rhs() { int64_t lhs_len = attribute("lhs_len").data(); int64_t rhs_len = attribute("rhs_len").data(); @@ -203,13 +213,14 @@ std::vector TieProductEqualOp::rhs() { } const char *TieShapeOp::attributes_name[attributes_num] = { - SymbolicDim::GetSymbolicDimAttrName().c_str()}; // NOLINT + SymbolicDimOp::GetSymbolicDimAttrName().c_str()}; // NOLINT void TieShapeOp::Build(Builder &builder, OperationArgument &argument, Value input) { argument.AddInput(input); } + void TieShapeOp::Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT Value input, @@ -218,8 +229,6 @@ void TieShapeOp::Build(Builder &builder, // NOLINT argument.AddInputs(dims); } -Value TieShapeOp::value() { return operand_source(0); } - std::vector TieShapeOp::dims() { std::vector res; for (uint32_t i = 1; i < num_operands(); i++) { @@ -261,23 +270,82 @@ void TensorDimOp::Build(Builder &builder, OperationArgument &argument, Value source, int64_t index) { - OpResult indexValue = + OpResult index_value = builder .Build(Int64Attribute::get(IrContext::Instance(), index), IndexType::get(IrContext::Instance())) ->result(0); - argument.AddInputs({source, indexValue}); + argument.AddInputs({source, index_value}); argument.output_types.emplace_back(IndexType::get(IrContext::Instance())); } -Value TensorDimOp::source() { return operand_source(0); } +std::optional TensorDimOp::GetConstantIndex() { + auto op = index().dyn_cast().owner(); + int64_t index = + op->dyn_cast().value().dyn_cast().data(); + return index; +} + +void ShapeOfOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value input) { + argument.AddInput(input); +} + +void FromElementsOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::vector &elements) { + argument.AddInputs(elements); +} -Value TensorDimOp::index() { return operand_source(1); } -} // namespace pir::dialect +std::vector FromElementsOp::elements() { + std::vector elements; + for (uint32_t idx = 0; idx < num_operands(); idx++) { + elements.push_back(operand_source(static_cast(idx))); + } + return elements; +} + +void ExtractOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value tensor, + std::vector indices) { + argument.AddInput(tensor); + argument.AddInputs(indices); +} + +std::vector ExtractOp::indices() { + std::vector indices; + for (uint32_t idx = 1; idx < num_operands(); idx++) { + indices.push_back(operand_source(static_cast(idx))); + } + return indices; +} + +void ConstantIndexOp::Build(Builder &builder, + OperationArgument &argument, + int64_t value) { + ConstantOp::Build( + builder, argument, builder.index_attr(value), builder.index_type()); +} + +void IndexCastOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Type out, + Value in) { + argument.AddInput(in); +} -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::DimOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp) +} // namespace pir::shape + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::SymbolicDimOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::DimOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TieProductEqualOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TieShapeOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::FuncOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TensorDimOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ShapeOfOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::FromElementsOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ExtractOp); +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ConstantIndexOp); +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::IndexCastOp); diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index c838624d2566d..31e35f376c55f 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -14,14 +14,16 @@ #pragma once +#include #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/op_trait.h" -namespace pir::dialect { +namespace pir::shape { -class IR_API SymbolicDim : public Op { +class IR_API SymbolicDimOp : public Op { public: using Op::Op; static const char *name() { return "shape.symbolic_dim"; } @@ -61,11 +63,11 @@ class IR_API SymbolicDim : public Op { // Sets `known_non_size_zero` to the value of `flag` void UpdateKnownNonSizeZero(bool flag); - // Returns true if this SymbolicDim is not known at compile-time. + // Returns true if this SymbolicDimOp is not known at compile-time. bool IsDynamic(); - // Try to merge two SymbolicDim ops. - bool Merge(SymbolicDim other); + // Try to merge two SymbolicDimOp. + bool Merge(SymbolicDimOp other); static const std::string GetSymbolicDimAttrName() { return "kSymbolicDimAttr"; @@ -86,8 +88,8 @@ class IR_API DimOp : public Op { OperationArgument &argument, // NOLINT const std::string &name); - const std::string getName(); - void setName(std::string attrValue); + const std::string GetName(); + void SetName(std::string attrValue); OpResult out() { return result(0); } void VerifySig() {} }; @@ -130,7 +132,7 @@ class IR_API TieShapeOp : public Op { OperationArgument &argument, // NOLINT Value input, const std::vector &dims); - Value value(); + Value input() { return operand_source(0); } std::vector dims(); void VerifySig() {} }; @@ -150,7 +152,7 @@ class IR_API FuncOp : public Op { void VerifySig() {} }; -class IR_API TensorDimOp : public Op { +class IR_API TensorDimOp : public Op { public: using Op::Op; static const char *name() { return "shape.tensor_dim"; } @@ -166,17 +168,106 @@ class IR_API TensorDimOp : public Op { OperationArgument &argument, // NOLINT Value source, int64_t index); - Value index(); - Value source(); + + Value source() { return operand_source(0); } + Value index() { return operand_source(1); } + OpResult out() { return result(0); } + void VerifySig() {} + std::optional GetConstantIndex(); +}; + +class IR_API ShapeOfOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.shape_of"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value input); + + Value input() { return operand_source(0); } + OpResult out() { return result(0); } + void VerifySig() {} +}; + +class IR_API FromElementsOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.from_elements"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::vector &elements); + + std::vector elements(); + OpResult out() { return result(0); } + void VerifySig() {} +}; + +class IR_API ExtractOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.extract"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value tensor, + std::vector indices); + + Value tensor() { return operand_source(0); } + std::vector indices(); OpResult out() { return result(0); } void VerifySig() {} }; -} // namespace pir::dialect +// Specialization of `constant` op that returns an integer of index type. +class IR_API ConstantIndexOp : public ConstantOp { + public: + using ConstantOp::ConstantOp; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + int64_t value); +}; + +// Cast between index and integer types. +class IR_API IndexCastOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.index_cast"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Type out, + Value in); + + Value in() { return operand_source(0); } + OpResult out() { return result(0); } + void VerifySig() {} +}; -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::DimOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp); +} // namespace pir::shape + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::SymbolicDimOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::DimOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::TieProductEqualOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::TieShapeOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::FuncOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::TensorDimOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ShapeOfOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::FromElementsOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ExtractOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ConstantIndexOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::IndexCastOp); diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.cc b/paddle/pir/dialect/shape/transforms/shape_optimization.cc index 767353efdbc5f..df21e6112a7a3 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization.cc @@ -13,41 +13,41 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/dialect/shape/ir/shape_op.h" - #include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/infer_type_op_interface.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace pir { namespace { -using PassPipelineRunner = - std::function; bool InsertTieShapeOnValue(pir::Value value, pir::Builder& builder) { // NOLINT - auto ty = value.type().dyn_cast(); + auto type = value.type().dyn_cast(); - if (!ty || ty.dims().size() == 0) return true; - std::vector dimSizes; - for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) { - auto dimOp = builder.Build(value, dim); - dimSizes.push_back(dimOp.out()); + if (!type || type.dims().size() == 0) return true; + std::vector dim_sizes; + for (int64_t dim = 0, rank = type.dims().size(); dim < rank; ++dim) { + auto dim_op = builder.Build(value, dim); + dim_sizes.push_back(dim_op.out()); } - builder.Build(value, dimSizes); + builder.Build(value, dim_sizes); return true; } +// Forward declaration bool InsertTieShapeOnRegion(pir::Region* region); bool InsertTieShapeOnOperation(pir::Operation* op, pir::Builder& builder) { // NOLINT - // TODO(zhangbo63): skip more specialized Ops. - if (op->isa() || op->isa()) - return true; + // TODO(zhangbopd): skip more specialized Ops. + if (op->isa() || op->isa()) return true; for (size_t i = 0; i < op->num_regions(); ++i) { if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; @@ -63,7 +63,7 @@ bool InsertTieShapeOnOperation(pir::Operation* op, bool InsertTieShapeOnBlock(pir::Block* block) { pir::Builder builder = pir::Builder(pir::IrContext::Instance(), block, block->begin()); - // TODO(liujinnan): mapping block arguments + // TODO(zhangbopd): mapping block arguments std::vector op_list; for (pir::Operation* op : *block) op_list.push_back(op); @@ -74,18 +74,108 @@ bool InsertTieShapeOnBlock(pir::Block* block) { } bool InsertTieShapeOnRegion(pir::Region* region) { - for (pir::Block* block : *region) { + for (Block* block : *region) { if (!InsertTieShapeOnBlock(block)) return false; } return true; } +struct ExpandShapeOfOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(shape::ShapeOfOp op, + PatternRewriter& rewriter) const override { + // TODO(zhangbopd): Uncomment + // auto type = op.out().type().dyn_cast(); + + // if (!type || !type.dyn_cast().HasStaticShape() || + // !type.dyn_cast().GetElementType().IsIndex()) + // return false; + + // std::vector dim_sizes; + // for (int dim = 0, rank = + // type.dyn_cast().GetShape()[0]; + // dim < rank; + // ++dim) { + // dim_sizes.push_back( + // rewriter.Build(op.input(), dim).out()); + // } + // rewriter.ReplaceOpWithNewOp(op, dim_sizes); + return true; + } +}; + +// Fold dim of an operation that implements the InferShapedTypeOpInterface +template +struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(OpTy dim_op, PatternRewriter& rewriter) const override { + OpResult dim_value = dim_op.source().template dyn_cast(); + if (!dim_value) return false; + + auto shaped_type_op = + dim_value.owner()->dyn_cast(); + + if (!shaped_type_op) return false; + // TODO(zhangbopd): Uncomment + // std::optional dim_index = dim_op.GetConstantIndex(); + // if (!dim_index) return false; + + // std::vector reified_result_shapes; + // if (!shaped_type_op.ReifyReturnTypeShapes( + // rewriter, shaped_type_op->operands(), reified_result_shapes)) + // return false; + + // if (reified_result_shapes.size() != shaped_type_op->num_results()) + // return false; + + // Value result_shape = reified_result_shapes[dim_value.index()]; + // auto result_shape_type = result_shape.type().dyn_cast(); + // auto shaped_type = result_shape_type.dyn_cast(); + // if (!result_shape_type || !shaped_type.GetElementType().IsIntOrIndex()) + // return false; + + // // TODO(zhangbopd): BuildOrFold required. + // std::vector indices; + // indices.push_back(rewriter.Build(*dim_index).out()); + // Value new_value = + // rewriter.Build(result_shape, indices).out(); + + // if (!new_value.type().isa()) + // new_value = + // rewriter.Build(rewriter.index_type(), + // new_value) + // .out(); + + // rewriter.ReplaceOp(dim_op, {new_value}); + return true; + } +}; + bool MaterializeShapeComputation(pir::ModuleOp m) { if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; - // TODO(liujinnan): add rewitter pattern for reifyInferShape. + // TODO(zhangbopd): add rewitter pattern for reifyInferShape. + RewritePatternSet patterns(m.ir_context()); + + patterns.Add>( + patterns.ir_context()); + + IR_ENFORCE(ApplyPatternsGreedily(m, std::move(patterns)), + "fail to materialize shape computation\n"); return true; } +using PassPipelineRunner = + std::function; + +// Returns true if the type is possible to be a shape tensor type. +// Shape tensor type : +// - rank-1 static-shaped tensor type +// - element type of the tensor is int or index +// - number of elements of the tensor < 32, supposing that the +// higiest possible rank is smaller than 32. bool IsCandidateShapeTensorType(Type type) { auto tensor_type = type.dyn_cast(); auto shaped_type = tensor_type.dyn_cast(); @@ -119,21 +209,16 @@ class ShapeComputationIRAnalysis { ModuleOp m_; SymbolicDimMgr& mgr_; - std::unordered_map value_to_sym_dim_; + std::unordered_map value_to_sym_dim_; // shape tensor is the 1D ranked tensor with int/index dtype. - std::unordered_map> shape_tensor_to_sym_dims_; + std::unordered_map> + shape_tensor_to_sym_dims_; - std::unordered_map> dense_tensor_to_sym_dims_; + std::unordered_map> + dense_tensor_to_sym_dims_; }; -// Returns true if the type is possible to be a shape tensor type. -// Shape tensor type : -// - rank-1 static-shaped tensor type -// - element type of the tensor is int or index -// - number of elements of the tensor < 32, supposing that the -// higiest possible rank is smaller than 32. - ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m, SymbolicDimMgr& mgr) : m_(m), mgr_(mgr) {} @@ -163,7 +248,7 @@ bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) { } bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) { - // TODO(liujinnan): mapping block arguments + // TODO(zhangbopd): mapping block arguments std::vector op_list; for (Operation* op : *block) op_list.push_back(op); @@ -181,37 +266,37 @@ bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) { } bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { - if (op->isa()) return true; - if (op->isa()) { + if (op->isa()) return true; + if (op->isa()) { Value value = op->operand_source(0); - std::vector symbols; - if (op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) { + std::vector symbols; + if (op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) { auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()) + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()) .AsVector(); for (Attribute attr : attrs) { - auto sym = mgr_.symbolTable().Lookup( + auto sym = mgr_.symbolTable().Lookup( attr.dyn_cast().AsString()); - assert(sym); - SymbolicDim root = mgr_.GetRootSymbolicDim(sym); + IR_ENFORCE(sym); + SymbolicDimOp root = mgr_.GetRootSymbolicDim(sym); symbols.push_back(root); } } else { symbols = mgr_.CreateSymbolicDimsForRankedValue(value); std::vector attrs; - for (SymbolicDim sym : symbols) { + for (SymbolicDimOp sym : symbols) { Attribute rootSymbol = StrAttribute::get(m_->ir_context(), sym.GetSymName()); attrs.push_back(rootSymbol); } - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), ArrayAttribute::get(m_->ir_context(), attrs)); } dense_tensor_to_sym_dims_[value] = std::move(symbols); return true; } - for (size_t i = 0; i < op->num_results(); ++i) { - if (!BuildShapeOnValue(op->result(i))) return false; + for (auto& result : op->results()) { + if (!BuildShapeOnValue(result)) return false; } return true; } @@ -219,11 +304,11 @@ bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { Type type = value.type(); if (type.IsIntOrIndex()) { - SymbolicDim sym = mgr_.NewSymbolicDim(); + SymbolicDimOp sym = mgr_.NewSymbolicDim(); value_to_sym_dim_[value] = sym; } else if (IsCandidateShapeTensorType(type)) { auto shaped_type = type.dyn_cast(); - std::vector symbols; + std::vector symbols; for (size_t i = 0, d = shaped_type.GetShape()[0]; i < d; ++i) symbols.push_back(mgr_.NewSymbolicDim()); shape_tensor_to_sym_dims_[value] = std::move(symbols); @@ -237,7 +322,7 @@ bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) { IR_ENFORCE(ApplyTieShapeOpConstraint(op), "Fail to apply constraint for tie_shape op"); - // TODO(zhangbo63): add more constraints + // TODO(zhangbopd): add more constraints return true; } @@ -247,7 +332,7 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { Type type = op->result(0).type(); if (!type.IsIntOrIndex()) return true; - if (auto dim_op = op->dyn_cast()) { + if (auto dim_op = op->dyn_cast()) { int64_t dim_index = dim_op.index() .dyn_cast() .owner() @@ -267,12 +352,12 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { return false; } } - // TODO(zhangbo63): add support for reifyInferShape. (e.g. mul/add) + // TODO(zhangbopd): add support for reifyInferShape. (e.g. mul/add) return true; } bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { - if (auto tie_shape = op->dyn_cast()) { + if (auto tie_shape = op->dyn_cast()) { auto& value = dense_tensor_to_sym_dims_[op->operand_source(0)]; for (size_t idx = 0; idx < tie_shape.dims().size(); ++idx) { if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[tie_shape.dims()[idx]], @@ -285,7 +370,7 @@ bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { } bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { - // TODO(liujinnan): Do some Canonicalizer. + // TODO(zhangbopd): Do some Canonicalizer. pir::SymbolicDimMgr mgr(m); IR_ENFORCE(mgr.Load(), "SymbolicDimMgr Load failed in OptimizeShapeComputation."); @@ -300,7 +385,7 @@ bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { class ShapeOptimizationPass : public pir::Pass { public: - ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} + ShapeOptimizationPass() : pir::Pass("shape_optimization_pass", 0) {} void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); @@ -328,4 +413,4 @@ std::unique_ptr CreateShapeOptimizationPass() { } // namespace pir -REGISTER_IR_PASS(shape_optimization, pir::ShapeOptimizationPass); +REGISTER_IR_PASS(shape_optimization_pass, pir::ShapeOptimizationPass); diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc index 07f7cf4129a4d..6954858bc8956 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc @@ -50,22 +50,22 @@ bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { for (auto op : *(m.block())) { - if (op->isa()) { + if (op->isa()) { symbol_table_ = SymbolTable(op); return; } } Builder builder = Builder(m_.ir_context(), m_.block(), m_.block()->begin()); - dialect::FuncOp func = builder.Build(); + shape::FuncOp func = builder.Build(); symbol_table_ = SymbolTable(func); } bool SymbolicDimMgr::Load() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); for (auto op : *(func_op.block())) { symbol_table_.insert(op); - if (SymbolicDim sym_dim_op = op->dyn_cast()) { + if (SymbolicDimOp sym_dim_op = op->dyn_cast()) { symbol_dim_union_set_[sym_dim_op] = sym_dim_op; symbol_name_set_.insert(sym_dim_op.GetSymName()); } @@ -74,10 +74,10 @@ bool SymbolicDimMgr::Load() { } bool SymbolicDimMgr::LoadShapeConstraintGraph() { - // TODO(liujinnan): add more constraint function. currently, only support + // TODO(zhangbopd): add more constraint function. currently, only support // tie_product_equal. auto constraint_vec = - symbol_table_.Lookup("tie_product_equal"); + symbol_table_.Lookup("tie_product_equal"); if (!constraint_vec.size()) return true; @@ -88,8 +88,8 @@ bool SymbolicDimMgr::LoadShapeConstraintGraph() { if (auto constOp = defining_op->dyn_cast()) { product.factor *= constOp.value().dyn_cast().data(); continue; - } else if (auto dimOp = defining_op->dyn_cast()) { - auto sym = symbol_table_.Lookup(dimOp.getName()); + } else if (auto dim_op = defining_op->dyn_cast()) { + auto sym = symbol_table_.Lookup(dim_op.GetName()); if (!sym) return false; product.symbols.push_back(sym); continue; @@ -139,17 +139,17 @@ bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, SymbolicDimProduct SymbolicDimMgr::SimplifySymbolicDimProduct( const SymbolicDimProduct& x) { - std::vector copied; + std::vector copied; copied.reserve(x.symbols.size()); - for (SymbolicDim op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); + for (SymbolicDimOp op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); std::sort( - copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { + copied.begin(), copied.end(), [&](SymbolicDimOp lhs, SymbolicDimOp rhs) { return CompareSymbolicDimNames(lhs.GetSymName(), rhs.GetSymName()); }); SymbolicDimProduct new_x; new_x.factor = x.factor; - for (SymbolicDim op : copied) { + for (SymbolicDimOp op : copied) { if (!op.IsDynamic()) { new_x.factor *= op.GetDimSize(); } else { @@ -186,13 +186,13 @@ SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, new_lhs.factor = lhs.factor / gcd_factor; new_rhs.factor = rhs.factor / gcd_factor; - std::unordered_map lhs_symbol_map; - std::unordered_map rhs_symbol_map; + std::unordered_map lhs_symbol_map; + std::unordered_map rhs_symbol_map; - for (SymbolicDim op : lhs.symbols) ++lhs_symbol_map[op]; - for (SymbolicDim op : rhs.symbols) ++rhs_symbol_map[op]; + for (SymbolicDimOp op : lhs.symbols) ++lhs_symbol_map[op]; + for (SymbolicDimOp op : rhs.symbols) ++rhs_symbol_map[op]; - for (SymbolicDim op : lhs.symbols) { + for (SymbolicDimOp op : lhs.symbols) { auto it = rhs_symbol_map.find(op); if (it != rhs_symbol_map.end() && op.GetKnownNonSizeZero()) { if (--it->second == 0) rhs_symbol_map.erase(it); @@ -201,7 +201,7 @@ SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, new_lhs.symbols.push_back(op); } - for (SymbolicDim op : rhs.symbols) { + for (SymbolicDimOp op : rhs.symbols) { auto it = lhs_symbol_map.find(op); if (it != lhs_symbol_map.end() && op.GetKnownNonSizeZero()) { if (--it->second == 0) lhs_symbol_map.erase(it); @@ -224,24 +224,24 @@ const std::string SymbolicDimMgr::GetNextName() { return name; } -SymbolicDim SymbolicDimMgr::NewSymbolicDim(const std::string& name) { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); +SymbolicDimOp SymbolicDimMgr::NewSymbolicDim(const std::string& name) { + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); Builder builder = Builder(m_.ir_context(), func_op.block()); // default settting dim != 0 - dialect::SymbolicDim symbol = - builder.Build(name.empty() ? GetNextName() : name, - ShapedTypeInterface::kDynamic, - false, - false, - false, - true); + SymbolicDimOp symbol = + builder.Build(name.empty() ? GetNextName() : name, + ShapedTypeInterface::kDynamic, + false, + false, + false, + true); symbol_dim_union_set_[symbol] = symbol; symbol_table_.insert(symbol); return symbol; } -SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { +SymbolicDimOp SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { auto it = constant_symbolic_dim_map_.find(val); if (it == constant_symbolic_dim_map_.end()) { auto name = "C" + std::to_string(val); @@ -257,9 +257,9 @@ SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { return GetRootSymbolicDim(it->second); } -std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( +std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( Value value) { - std::vector symbols; + std::vector symbols; auto dims = value.type().dyn_cast().dims(); for (int idx = 0; idx < dims.size(); ++idx) { symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic @@ -269,26 +269,26 @@ std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( return symbols; } -SymbolicDim SymbolicDimMgr::GetRootSymbolicDim(SymbolicDim symbol) { - SymbolicDim current = symbol; - std::vector path; +SymbolicDimOp SymbolicDimMgr::GetRootSymbolicDim(SymbolicDimOp symbol) { + SymbolicDimOp current = symbol; + std::vector path; while (symbol_dim_union_set_[current] != current) { path.push_back(current); current = symbol_dim_union_set_[current]; } - for (SymbolicDim sym : path) symbol_dim_union_set_[sym] = current; + for (SymbolicDimOp sym : path) symbol_dim_union_set_[sym] = current; return current; } -bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhs_root = GetRootSymbolicDim(lhs); - SymbolicDim rhs_root = GetRootSymbolicDim(rhs); +bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs) { + SymbolicDimOp lhs_root = GetRootSymbolicDim(lhs); + SymbolicDimOp rhs_root = GetRootSymbolicDim(rhs); return lhs_root == rhs_root; } -bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhs_root = GetRootSymbolicDim(lhs); - SymbolicDim rhs_root = GetRootSymbolicDim(rhs); +bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs) { + SymbolicDimOp lhs_root = GetRootSymbolicDim(lhs); + SymbolicDimOp rhs_root = GetRootSymbolicDim(rhs); if (lhs_root != rhs_root) { if (CompareSymbolicDimNames(lhs_root.GetSymName(), rhs_root.GetSymName())) { @@ -315,10 +315,10 @@ SymbolicDimProduct* SymbolicDimMgr::SymbolicDimProductDivide( SymbolicDimProduct* result = new SymbolicDimProduct(); result->factor = new_lhs.factor / new_rhs.factor; - std::unordered_map sym_proc_map; - for (SymbolicDim sym : new_rhs.symbols) ++sym_proc_map[sym]; + std::unordered_map sym_proc_map; + for (SymbolicDimOp sym : new_rhs.symbols) ++sym_proc_map[sym]; - for (SymbolicDim sym : new_lhs.symbols) { + for (SymbolicDimOp sym : new_lhs.symbols) { auto it = sym_proc_map.find(sym); if (it == sym_proc_map.end()) { result->symbols.push_back(sym); @@ -457,13 +457,13 @@ bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, } bool SymbolicDimMgr::Save() { - using Name2SymbolFn = std::function; + using Name2SymbolFn = std::function; auto update_attrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { std::vector new_attrs; for (Attribute attr : attrs.AsVector()) { auto sym = fn(attr.dyn_cast().AsString()); - assert(sym); - SymbolicDim root = GetRootSymbolicDim(sym); + IR_ENFORCE(sym); + SymbolicDimOp root = GetRootSymbolicDim(sym); Attribute root_symbol = StrAttribute::get(m_->ir_context(), root.GetSymName()); new_attrs.push_back(root_symbol); @@ -471,41 +471,41 @@ bool SymbolicDimMgr::Save() { return ArrayAttribute::get(m_->ir_context(), new_attrs); }; - // TODO(liujinnan): update attributes attached in DenseTensorType + // TODO(zhangbopd): update attributes attached in DenseTensorType for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + if (!op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()); auto symbolic_shape_attr = update_attrs(attrs, [&](const std::string& name) { - return symbol_table_.Lookup(name); + return symbol_table_.Lookup(name); }); - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), symbolic_shape_attr); } if (!UpdateProductEqualityMap()) { return false; } - std::unordered_set used_symbolic_ops; + std::unordered_set used_symbolic_ops; std::vector used_symbol_names; - // TODO(liujinnan): collect uses in value. + // TODO(zhangbopd): collect uses in value. auto collect_used_symbols = [&](ArrayAttribute attrs) { for (Attribute attr : attrs.AsVector()) { - auto sym = symbol_table_.Lookup( + auto sym = symbol_table_.Lookup( attr.dyn_cast().AsString()); - assert(sym); + IR_ENFORCE(sym); if (used_symbolic_ops.insert(sym).second) used_symbol_names.push_back(sym.GetSymName()); } }; for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + if (!op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()); collect_used_symbols(attrs); } - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); for (auto& p : symbol_dim_union_set_) { if (!used_symbolic_ops.count(p.first)) { func_op.block()->erase(*(p.first.operation())); @@ -514,10 +514,11 @@ bool SymbolicDimMgr::Save() { std::vector candidates; for (auto& outter : product_equality_map_) { - if (std::any_of( - outter.first.symbols.begin(), - outter.first.symbols.end(), - [&](SymbolicDim sym) { return used_symbolic_ops.count(sym) == 0; })) + if (std::any_of(outter.first.symbols.begin(), + outter.first.symbols.end(), + [&](SymbolicDimOp sym) { + return used_symbolic_ops.count(sym) == 0; + })) candidates.push_back(outter.first); } @@ -527,7 +528,7 @@ bool SymbolicDimMgr::Save() { for (auto& inner : outter.second) { if (std::any_of(inner.first.symbols.begin(), inner.first.symbols.end(), - [&](SymbolicDim sym) { + [&](SymbolicDimOp sym) { return used_symbolic_ops.count(sym) == 0; })) candidates.push_back(outter.first); @@ -550,35 +551,35 @@ bool SymbolicDimMgr::Save() { } } - std::unordered_map name_to_symbol; - for (SymbolicDim op : used_symbolic_ops) { + std::unordered_map name_to_symbol; + for (SymbolicDimOp op : used_symbolic_ops) { auto name = op.GetSymName(); op.SetSymName(name_mapping[name]); name_to_symbol[name] = op; } for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + if (!op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()); auto symbolic_shape_attr = update_attrs( attrs, [&](const std::string& name) { return name_to_symbol[name]; }); - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), symbolic_shape_attr); } - // TODO(liujinnan): update attributes attached to values. + // TODO(zhangbopd): update attributes attached to values. return SaveShapeConstraintGraph(); } bool SymbolicDimMgr::SaveShapeConstraintGraph() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); auto op_it = func_op.block()->rbegin(); while (op_it != func_op.block()->rend()) { - if (((*op_it)->isa()) || - ((*op_it)->isa())) + if (((*op_it)->isa()) || + ((*op_it)->isa())) op_it++; else op_it = decltype(op_it)(func_op.block()->erase(*(*op_it))); @@ -597,8 +598,8 @@ bool SymbolicDimMgr::SaveShapeConstraintGraph() { Int32Type::get(m_->ir_context())) ->result(0)); } - for (SymbolicDim sym : prod.symbols) { - values.push_back(builder.Build(sym.GetSymName()).out()); + for (SymbolicDimOp sym : prod.symbols) { + values.push_back(builder.Build(sym.GetSymName()).out()); } return values; }; @@ -613,7 +614,7 @@ bool SymbolicDimMgr::SaveShapeConstraintGraph() { if (!product_equality_map_[x][y]) continue; auto lhs_operands = build_operands(x); auto rhs_operands = build_operands(y); - builder.Build(lhs_operands, rhs_operands); + builder.Build(lhs_operands, rhs_operands); } } return true; diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h index 5541e8a8ee2f1..9bce073244124 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h @@ -17,13 +17,13 @@ #include "paddle/pir/dialect/shape/utils/symbol_table.h" namespace pir { -using dialect::SymbolicDim; +using shape::SymbolicDimOp; // Represents a product of symbolic and concrete factors. // Used to prove product equalities symbolically. struct SymbolicDimProduct { // List all symbolic factors that can not be aggregated. - std::vector symbols; + std::vector symbols; // Product of all const factors. int64_t factor = 1; @@ -43,7 +43,7 @@ inline bool operator!=(const SymbolicDimProduct& lhs, } struct SymDimHasher { - size_t operator()(const dialect::SymbolicDim& symbol) const noexcept { + size_t operator()(const SymbolicDimOp& symbol) const noexcept { return std::hash{}(symbol.operation()); } }; @@ -64,29 +64,29 @@ class SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); - // Loads pre-defined SymbolicDim ops from the module this mgr runs on. + // Loads pre-defined SymbolicDimOp ops from the module this mgr runs on. bool Load(); // Create a new symbolicDim instance owned by this mgr. - SymbolicDim NewSymbolicDim(const std::string& name = {}); + SymbolicDimOp NewSymbolicDim(const std::string& name = {}); // Create a symbolicDim with static dim size == `val`. - SymbolicDim NewConstantSymbolicDim(int64_t val); + SymbolicDimOp NewConstantSymbolicDim(int64_t val); // Create a symbolicDim with given value. - std::vector CreateSymbolicDimsForRankedValue(Value value); + std::vector CreateSymbolicDimsForRankedValue(Value value); // All symbolic-equal dims form a group. - // Returns the root SymbolicDim of the symbolic-equal symbolic dim group which - // this SymbolicDim belongs to. - SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); + // Returns the root SymbolicDimOp of the symbolic-equal symbolic dim group + // which this SymbolicDimOp belongs to. + SymbolicDimOp GetRootSymbolicDim(SymbolicDimOp symbol); // Returns true if lhs and rhs are known to be equal. - bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool IsSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs); // Marks lhs and rhs have same size and try to merge lhs & rhs static known // info. Returns false if failed to merge lhs & rhs. - bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool MapSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs); // Returns the simplified version of SymbolicDimProduct. // This will try to fold some symbolicDim ops with const values. @@ -139,10 +139,10 @@ class SymbolicDimMgr { std::unordered_set symbol_name_set_; - std::unordered_map + std::unordered_map symbol_dim_union_set_; - std::unordered_map constant_symbolic_dim_map_; + std::unordered_map constant_symbolic_dim_map_; // product_equality_map_[A][B] == true : Product[A] == Product[B] using SymbolicDimProductMap = std::unordered_map< diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index d746831835ed8..79f270afdba50 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -50,16 +50,16 @@ ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) : m_(m), mgr_(m) { mgr_.Load(); for (auto op : *(m_.block())) { - auto tie_shape_op = op->dyn_cast(); + auto tie_shape_op = op->dyn_cast(); if (!tie_shape_op) continue; - Value result = tie_shape_op.value(); + Value result = tie_shape_op.input(); auto& symbols = value_to_sym_dims_[result]; auto attrs = tie_shape_op - .attribute(SymbolicDim::GetSymbolicDimAttrName()) + .attribute(SymbolicDimOp::GetSymbolicDimAttrName()) .AsVector(); for (const auto& attr : attrs) { - auto sym_op = mgr_.symbolTable().Lookup( + auto sym_op = mgr_.symbolTable().Lookup( attr.dyn_cast().AsString()); if (!sym_op) continue; symbols.push_back(sym_op); @@ -90,8 +90,8 @@ bool ShapeConstraintIRAnalysis::IsShapeEqual(Value lhs, Value rhs) { lhs_it->second.size() != rhs_it->second.size()) return false; - std::vector lhs_syms; - std::vector rhs_syms; + std::vector lhs_syms; + std::vector rhs_syms; for (auto sym : lhs_it->second) { lhs_syms.push_back(mgr_.GetRootSymbolicDim(sym)); } diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 0842313962d36..9ac479548465d 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -76,7 +76,7 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis { SymbolicDimMgr mgr_; // Map a ranked memref value to an array of symbolicDims, each represents one // dimension size of the memref value. - std::unordered_map> + std::unordered_map> value_to_sym_dims_; }; diff --git a/paddle/pir/dialect/shape/utils/symbol_table.cc b/paddle/pir/dialect/shape/utils/symbol_table.cc index c4ed0807b0b43..74c60f3f6b163 100644 --- a/paddle/pir/dialect/shape/utils/symbol_table.cc +++ b/paddle/pir/dialect/shape/utils/symbol_table.cc @@ -18,13 +18,13 @@ namespace pir { const std::string SymbolTable::insert(Operation* symbol) { std::string name; - if (symbol->isa()) { - name = symbol->dyn_cast().GetSymName(); + if (symbol->isa()) { + name = symbol->dyn_cast().GetSymName(); symbol_table_map_.insert({name, symbol}); } - // TODO(liujinnan): add more constraint_func name branch. - if (symbol->isa()) { + // TODO(zhangbopd): add more constraint_func name branch. + if (symbol->isa()) { name = "tie_product_equal"; symbol_func_map_[name].emplace_back(symbol); } diff --git a/paddle/pir/dialect/shape/utils/symbol_table.h b/paddle/pir/dialect/shape/utils/symbol_table.h index f85ba2cfb8099..2c71a142c78d1 100644 --- a/paddle/pir/dialect/shape/utils/symbol_table.h +++ b/paddle/pir/dialect/shape/utils/symbol_table.h @@ -28,22 +28,22 @@ namespace pir { -using dialect::SymbolicDim; +using shape::SymbolicDimOp; class SymbolTable { public: explicit SymbolTable(Operation* symbol_table_op) : symbol_table_op_(symbol_table_op) {} SymbolTable() = default; template - typename std::enable_if::value, - SymbolicDim>::type + typename std::enable_if::value, + SymbolicDimOp>::type Lookup(const std::string& name) const { auto it = symbol_table_map_.find(name); - return it != symbol_table_map_.end() ? it->second->dyn_cast() - : SymbolicDim(nullptr); + return it != symbol_table_map_.end() ? it->second->dyn_cast() + : SymbolicDimOp(nullptr); } template - typename std::enable_if::value, + typename std::enable_if::value, std::vector>::type Lookup(const std::string& name) const { std::vector res; diff --git a/paddle/pir/pass/CMakeLists.txt b/paddle/pir/pass/CMakeLists.txt deleted file mode 100644 index 92f7de3531cf4..0000000000000 --- a/paddle/pir/pass/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB NEW_PASS_SRCS "*.cc") - -ir_library(pir_pass SRCS ${NEW_PASS_SRCS} DEPS pir_core) diff --git a/paddle/pir/pass/ir_printing.cc b/paddle/pir/pass/ir_printing.cc index 6171b71c090fc..901c8bdd89da7 100644 --- a/paddle/pir/pass/ir_printing.cc +++ b/paddle/pir/pass/ir_printing.cc @@ -31,12 +31,8 @@ void PrintIR(Operation *op, bool print_module, std::ostream &os) { return; } - // Find the top-level operation. - auto *top_op = op; - while (auto *parent_op = top_op->GetParentOp()) { - top_op = parent_op; - } - top_op->Print(os); + auto *program = op->GetParentProgram(); + program->Print(os); } } // namespace diff --git a/paddle/pir/pass/pass.h b/paddle/pir/pass/pass.h index f916fcbb1e354..955947896ff32 100644 --- a/paddle/pir/pass/pass.h +++ b/paddle/pir/pass/pass.h @@ -18,10 +18,8 @@ #include #include -#include "paddle/phi/core/enforce.h" #include "paddle/pir/core/enforce.h" #include "paddle/pir/pass/analysis_manager.h" -#include "paddle/pir/pass/pass_registry.h" namespace pir { @@ -85,7 +83,7 @@ class IR_API Pass { protected: virtual void Run(Operation* op) = 0; - virtual inline bool CanApplyOn(Operation* op) const; + virtual bool CanApplyOn(Operation* op) const; virtual bool Initialize(IrContext* context) { return true; } diff --git a/paddle/pir/pass/pass_manager.h b/paddle/pir/pass/pass_manager.h index f606be139c42f..92faed24f1f5d 100644 --- a/paddle/pir/pass/pass_manager.h +++ b/paddle/pir/pass/pass_manager.h @@ -20,13 +20,13 @@ #include #include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" namespace pir { class IrContext; class Operation; class Program; -class Pass; class PassInstrumentation; class PassInstrumentor; diff --git a/paddle/pir/pass/pass_registry.h b/paddle/pir/pass/pass_registry.h index 71140810b0324..88dbfa443ddc3 100644 --- a/paddle/pir/pass/pass_registry.h +++ b/paddle/pir/pass/pass_registry.h @@ -21,9 +21,8 @@ #include "paddle/pir/core/enforce.h" #include "paddle/pir/core/macros.h" #include "paddle/pir/pass/pass.h" -namespace pir { -class Pass; +namespace pir { using PassCreator = std::function()>; @@ -79,26 +78,26 @@ class PassRegistrar { msg) // Register a new pass that can be applied on the IR. -#define REGISTER_IR_PASS(pass_type, pass_class) \ - STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ - __reg_pass__##pass_type, \ - "REGISTER_IR_PASS must be called in global namespace"); \ - static ::pir::PassRegistrar __pass_registrar_##pass_type##__( \ - #pass_type); \ - int TouchPassRegistrar_##pass_type() { \ - __pass_registrar_##pass_type##__.Touch(); \ - return 0; \ - } \ - static ::pir::PassRegistrar \ - &__pass_tmp_registrar_##pass_type##__ UNUSED = \ - __pass_registrar_##pass_type##__ - -#define USE_PASS(pass_type) \ - STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ - __use_pass_itself_##pass_type, \ - "USE_PASS must be called in global namespace"); \ - extern int TouchPassRegistrar_##pass_type(); \ - static int use_pass_itself_##pass_type##_ UNUSED = \ - TouchPassRegistrar_##pass_type() +#define REGISTER_IR_PASS(pass_type, pass_class) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __reg_pir_pass__##pass_type, \ + "REGISTER_IR_PASS must be called in global namespace"); \ + static ::pir::PassRegistrar \ + __pir_pass_registrar_##pass_type##__(#pass_type); \ + int TouchPirPassRegistrar_##pass_type() { \ + __pir_pass_registrar_##pass_type##__.Touch(); \ + return 0; \ + } \ + static ::pir::PassRegistrar \ + &__pir_ass_tmp_registrar_##pass_type##__ UNUSED = \ + __pir_pass_registrar_##pass_type##__ + +#define USE_PIR_PASS(pass_type) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __use_pir_pass_itself_##pass_type, \ + "USE_PASS must be called in global namespace"); \ + extern int TouchPirPassRegistrar_##pass_type(); \ + static int use_pir_pass_itself_##pass_type##_ UNUSED = \ + TouchPirPassRegistrar_##pass_type() } // namespace pir diff --git a/paddle/pir/pattern_rewrite/CMakeLists.txt b/paddle/pir/pattern_rewrite/CMakeLists.txt deleted file mode 100644 index 27e939f5d05b9..0000000000000 --- a/paddle/pir/pattern_rewrite/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB PATTERN_SRCS "*.cc") - -ir_library(pir_pattern_rewrite SRCS ${PATTERN_SRCS} DEPS pir_core) diff --git a/paddle/pir/pattern_rewrite/pattern_match.cc b/paddle/pir/pattern_rewrite/pattern_match.cc index 028d0779dbf94..7b775ba498581 100644 --- a/paddle/pir/pattern_rewrite/pattern_match.cc +++ b/paddle/pir/pattern_rewrite/pattern_match.cc @@ -116,46 +116,60 @@ void RewriterBase::ReplaceOpWithIf( void RewriterBase::ReplaceOp(Operation* op, const std::vector& new_values) { + // Notify that the rewriter subclass we're about to replace this root. NotifyRootReplaced(op, new_values); + IR_ENFORCE(op->num_results() == new_values.size(), "incorrect # of replacement values"); op->ReplaceAllUsesWith(new_values); + NotifyOperationRemoved(op); - op->GetParent()->erase(*op); + op->Erase(); } void RewriterBase::EraseOp(Operation* op) { - // TODO(wilber): Operation support use_empty. - // IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses"); + IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses"); NotifyOperationRemoved(op); - op->GetParent()->erase(*op); + op->Erase(); } -/// Find uses of `from` and replace it with `to` +// Find uses of `from` and replace it with `to`. void RewriterBase::ReplaceAllUsesWith(Value from, Value to) { - // TODO(wilber): Substitue a low level impl. - from.ReplaceAllUsesWith(to); + for (auto it = from.use_begin(); it != from.use_end();) + UpdateRootInplace(it.owner(), [&]() { (it++)->set_source(to); }); } -// TODO(wilber): iterator maybe should support modify inplace. +// Find uses of `from` and replace them with `to` if the `functor` returns true. void RewriterBase::ReplaceUseIf(Value from, Value to, std::function functor) { - // for (auto it = from.begin(); it != from.end(); ++it) { - // // // TODO: need a lvalue. - // if (functor(*it)) { - // UpdateRootInplace(it.owner(), [&](){it.get().set(to)}); - // } + // Use post-increment operator for iterator since set_source() will change + // `it`. + // TODO(zhangbopd): Uncomment + // for (auto it = from.use_begin(); it != from.use_end();) { + // if (functor(*it)) + // UpdateRootInplace(it.owner(), [&]() { (it++)->set_source(to); }); // } } +// Replace theuses of op with uses of new_op. +// 'op' and 'new_op' are known to have the same number of results void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op) { IR_ENFORCE(op->num_results() == new_op->num_results(), "replacement op doesn't match results of original op"); - // TODO(wilber): Op support results method. - // if (op->num_results() == 1) return ReplaceOp(op, - // new_op->result(0)); return ReplaceOp(op, new_op->GetResults()); + // TODO(zhangbopd): Uncomment + // if (op->num_results() == 1) { + // std::vector new_values; + // new_values.push_back(new_op->result(0)); + // return ReplaceOp(op, new_values); + // } + + // std::vector new_values; + // for (auto res : new_op->results()) { + // new_values.push_back(res); + // } + // return ReplaceOp(op, new_values); } } // namespace pir diff --git a/paddle/pir/pattern_rewrite/pattern_match.h b/paddle/pir/pattern_rewrite/pattern_match.h index 9e7553f4217ca..c1415606c3b24 100644 --- a/paddle/pir/pattern_rewrite/pattern_match.h +++ b/paddle/pir/pattern_rewrite/pattern_match.h @@ -272,9 +272,16 @@ class RewriterBase : public Builder { virtual void ReplaceOp(Operation* op, const std::vector& new_values); - // template - // OpTy ReplaceOpWithNewOp(Operation *op, Args &&...args); + // Replaces the result op with a new op. + // The result values of the two ops must be the same types. + template + OpTy ReplaceOpWithNewOp(Operation* op, Args&&... args) { + auto new_op = Build(std::forward(args)...); + ReplaceOpWithResultsOfAnotherOp(op, new_op.operation()); + return new_op; + } + // This method erases an operation that is known to have no uses. virtual void EraseOp(Operation* op); IR_API void ReplaceAllUsesWith(Value from, Value to); @@ -327,6 +334,7 @@ class RewritePatternSet { public: explicit RewritePatternSet(IrContext* context) : context_(context) {} + // Construct a RewritePatternSet with the given patterns. RewritePatternSet(IrContext* context, std::unique_ptr pattern) : context_(context) { native_patterns_.emplace_back(std::move(pattern)); @@ -344,7 +352,7 @@ class RewritePatternSet { typename... ConstructorArgs, typename = std::enable_if_t> RewritePatternSet& Add(ConstructorArg&& arg, ConstructorArgs&&... args) { - std::initializer_list{ + (void)std::initializer_list{ (AddImpl({}, std::forward(arg), std::forward(args)...), @@ -359,7 +367,7 @@ class RewritePatternSet { RewritePatternSet& AddWithLabel(const std::vector& debug_labels, ConstructorArg&& arg, ConstructorArgs&&... args) { - std::initializer_list{ + (void)std::initializer_list{ (AddImpl(debug_labels, std::forward(arg), std::forward(args)...), diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc index 00d6cb2f4d306..ff75f86d6da55 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc @@ -131,6 +131,7 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter { for (uint32_t i = 0; i < op->num_operands(); ++i) { AddOperandToWorklist(op->operand_source(i)); } + if (op->num_regions() == 0) { RemoveFromWorklist(op); } else { diff --git a/paddle/pir/transforms/CMakeLists.txt b/paddle/pir/transforms/CMakeLists.txt deleted file mode 100644 index 4f9f0fa196e9a..0000000000000 --- a/paddle/pir/transforms/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -file(GLOB PATTERN_SRCS "*.cc") - -ir_library( - pir_builtin_transforms - SRCS - ${PATTERN_SRCS} - DEPS - pir_core - pir_pattern_rewrite - pir_pass) diff --git a/paddle/pir/transforms/dead_code_elimination_pass.cc b/paddle/pir/transforms/dead_code_elimination_pass.cc deleted file mode 100644 index 6216fca5037e1..0000000000000 --- a/paddle/pir/transforms/dead_code_elimination_pass.cc +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/program.h" -#include "paddle/pir/pass/pass.h" -#include "paddle/pir/pass/pass_registry.h" - -namespace { - -// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be -// removed by dce pass. -// Now just a naive implementation. -class DeadCodeEliminationPass : public pir::Pass { - public: - DeadCodeEliminationPass() : pir::Pass("dead_code_elimination", 0) {} - - void Run(pir::Operation *op) override { - auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "DcePass should run on module op."); - auto *block = module_op.block(); - std::vector erased_op; - for (auto &op : *block) { - // TODO(wilber): Support NoSideEffect trait. - // if (!op->HasTrait()) continue; - - bool use_empty = true; - for (uint32_t i = 0; i < op->num_results(); ++i) { - use_empty &= op->result(i).use_empty(); - } - // TODO(wilber): Support Terminator trait. - if (use_empty && op->name() != "pd_op.fetch") { - erased_op.push_back(op); - } - } - - for (auto *op : erased_op) { - if (op->dyn_cast()) { - // Delete parameter from program. - pir::GetParameterOp get_parameter_op = - op->dyn_cast(); - get_parameter_op->GetParentProgram()->parameters().erase( - get_parameter_op->attributes() - .at(get_parameter_op.attributes_name[0]) - .dyn_cast() - .AsString()); - } - block->erase(*op); - } - } - - bool CanApplyOn(pir::Operation *op) const override { - return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; - } -}; - -} // namespace - -namespace pir { - -std::unique_ptr CreateDeadCodeEliminationPass() { - return std::make_unique(); -} - -} // namespace pir - -REGISTER_IR_PASS(dead_code_elimination, DeadCodeEliminationPass); diff --git a/paddle/pir/transforms/reorder_block_ops_pass.cc b/paddle/pir/transforms/reorder_block_ops_pass.cc deleted file mode 100644 index db2d29fe9b0a7..0000000000000 --- a/paddle/pir/transforms/reorder_block_ops_pass.cc +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pir/transforms/reorder_block_ops_pass.h" - -#include - -#include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/program.h" -#include "paddle/pir/pass/pass.h" - -namespace { - -class ReorderBlockOpsPass : public pir::Pass { - public: - ReorderBlockOpsPass() : pir::Pass("ReorderBlockOpsPass", 0) {} - - void Run(pir::Operation *op) override { - IR_ENFORCE(op->num_regions() > 0, - "ReorderBlockOpsPass should run on Operation which regions " - "number greater than 0."); - for (size_t i = 0; i < op->num_regions(); ++i) { - for (auto *block : op->region(i)) { - std::list res_op_list; - std::unordered_map - reorder_op_dep_cnt; // op -> dependent input count - std::unordered_set visited_values; - std::queue op_que; - - auto update_op_que = [&](pir::Operation *op) { - for (size_t i = 0; i < op->results().size(); ++i) { - auto result = op->result(i); - visited_values.insert(result); - for (auto it = result.use_begin(); it != result.use_end(); ++it) { - if (reorder_op_dep_cnt.count(it->owner())) { - reorder_op_dep_cnt[it->owner()]--; - if (reorder_op_dep_cnt[it->owner()] == 0) { - op_que.push(it->owner()); - } - } - } - } - }; - - for (auto &op : *block) { - bool has_dependency = false; - if (op->num_operands() > 0) { - for (size_t i = 0; i < op->num_operands(); ++i) { - auto operand = op->operand_source(i); - if (operand && visited_values.count(op->operand_source(i)) == 0) { - reorder_op_dep_cnt[op]++; - has_dependency = true; - } - } - } - if (!has_dependency) { - res_op_list.push_back(op); - update_op_que(op); - } - } - - if (reorder_op_dep_cnt.empty()) { - return; - } - - while (!op_que.empty()) { - auto *op = op_que.front(); - op_que.pop(); - res_op_list.push_back(op); - update_op_que(op); - } - VLOG(4) << "ReorderBlockOpsPass is applied."; - block->ResetOpListOrder(res_op_list); - } - } - } - - bool CanApplyOn(pir::Operation *op) const override { - return op->num_regions() > 0; - } -}; - -} // namespace - -namespace pir { - -std::unique_ptr CreateReorderBlockOpsPass() { - return std::make_unique(); -} - -} // namespace pir diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 268fb5f0a482c..29db4cd02e4d9 100644 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -626,7 +626,7 @@ EOF function run_mac_test() { - export FLAGS_NEW_IR_OPTEST=True + export FLAGS_PIR_OPTEST=True export FLAGS_CI_PIPELINE=mac mkdir -p ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build @@ -776,7 +776,7 @@ EOF } function run_linux_cpu_test() { - export FLAGS_NEW_IR_OPTEST=True + export FLAGS_PIR_OPTEST=True export FLAGS_CI_PIPELINE=py3 mkdir -p ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build @@ -988,6 +988,7 @@ function run_sot_test() { export COST_MODEL=False export MIN_GRAPH_SIZE=0 export SOT_LOG_LEVEL=0 + export FLAGS_cudnn_deterministic=True # Install PaddlePaddle $PYTHON_WITH_SPECIFY_VERSION -m pip install ${PADDLE_ROOT}/dist/paddlepaddle-0.0.0-cp${PY_VERSION_NO_DOT}-cp${PY_VERSION_NO_DOT}-linux_x86_64.whl @@ -1253,7 +1254,7 @@ EOF if [ "${APPROVALS}" == "FALSE" ]; then echo "==========================================================================================" echo "This PR make the release inference library size growth exceeds 20 M." - echo "Then you must have one RD (vivienfanghuagood (Recommend), Aurelius84 (For NewIR) qingqing01 or yuanlehome) approval for this PR.\n" + echo "Then you must have one RD (vivienfanghuagood (Recommend), Aurelius84 (ForPir) qingqing01 or yuanlehome) approval for this PR.\n" echo "==========================================================================================" exit 6 fi @@ -1406,6 +1407,8 @@ function get_quickly_disable_ut() { echo ${disable_ut_quickly} echo "=========================================" else + + exit 102 disable_ut_quickly='disable_ut' fi } @@ -3379,14 +3382,16 @@ function build_pr_and_develop() { mkdir ${PADDLE_ROOT}/build/dev_whl && wget -q -P ${PADDLE_ROOT}/build/dev_whl ${dev_url} cp ${PADDLE_ROOT}/build/dev_whl/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl ${PADDLE_ROOT}/build/python/dist else + cp -r ${PADDLE_ROOT}/build /tmp/ if [[ ${cmake_change} ]];then rm -rf ${PADDLE_ROOT}/build/Makefile ${PADDLE_ROOT}/build/CMakeCache.txt ${PADDLE_ROOT}/build/build.ninja rm -rf ${PADDLE_ROOT}/build/third_party fi - git checkout -b develop_base_pr upstream/$BRANCH git submodule update --init run_setup ${PYTHON_ABI:-""} "rerun-cmake bdist_wheel" ${parallel_number} + rm -rf ${PADDLE_ROOT}/build + mv /tmp/build ${PADDLE_ROOT} if [ ! -d "${PADDLE_ROOT}/build/python/dist/" ]; then mkdir ${PADDLE_ROOT}/build/python/dist/ fi @@ -4088,7 +4093,7 @@ function main() { check_coverage_build ;; gpu_cicheck_coverage) - export FLAGS_NEW_IR_OPTEST=True + export FLAGS_PIR_OPTEST=True parallel_test check_coverage ;; diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 667045eaebf97..8e615f7a6cb11 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" -#include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/framework/init_default_kernel_signature_map.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/platform/init.h" #include "paddle/phi/core/flags.h" diff --git a/paddle/utils/pybind.h b/paddle/utils/pybind.h index 67927031594e0..065cd49297ab4 100644 --- a/paddle/utils/pybind.h +++ b/paddle/utils/pybind.h @@ -15,6 +15,9 @@ #pragma once #include "paddle/phi/api/include/tensor.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#endif #include "paddle/utils/optional.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -74,8 +77,16 @@ struct type_caster { static handle cast(const paddle::Tensor& src, return_value_policy /* policy */, handle /* parent */) { + // TODO(GhostScreaming): pipeline parallel may return a uninitialized + // DistTensor, it should not return None. +#ifdef PADDLE_WITH_DISTRIBUTE + bool return_none = + phi::distributed::DistTensor::classof(src.impl().get()) ? false : true; +#else + bool return_none = true; +#endif return handle(paddle::pybind::ToPyObject( - src, true /* return_py_none_if_not_initialize */)); + src, return_none /* return_py_none_if_not_initialize */)); } }; diff --git a/pyproject.toml b/pyproject.toml index 393d46b6f8a5f..86b8ee2c80403 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,22 +103,3 @@ ignore = [ "test/dygraph_to_static/test_loop.py" = ["C416", "F821"] # Ignore unnecessary lambda in dy2st unittest test_lambda "test/dygraph_to_static/test_lambda.py" = ["PLC3002"] - -# Temporarily ignored -"python/paddle/base/**" = [ - "UP030", - "B019", # Confirmation required - "C416", - "F821", -] - -# B017 -"test/auto_parallel/spmd_rules/test_reshape_rule.py" = ["B017"] -"test/dygraph_to_static/test_assert.py" = ["B017"] -"test/legacy_test/test_cuda_max_memory_allocated.py" = ["B017"] -"test/legacy_test/test_cuda_max_memory_reserved.py" = ["B017"] -"test/legacy_test/test_cuda_memory_allocated.py" = ["B017"] -"test/legacy_test/test_cuda_memory_reserved.py" = ["B017"] -"test/legacy_test/test_eigvals_op.py" = ["B017"] -"test/legacy_test/test_tensordot.py" = ["B017"] -"test/legacy_test/test_top_k_v2_op.py" = ["B017"] diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5b3e806c3f947..92ac2fcbb5c34 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -33,11 +33,12 @@ # the illogical implement in the monkey-patch methods later. from .framework import monkey_patch_variable from .framework import monkey_patch_math_tensor -from .pir import monkey_patch_opresult +from .pir import monkey_patch_opresult, monkey_patch_program monkey_patch_variable() monkey_patch_math_tensor() monkey_patch_opresult() +monkey_patch_program() from .framework import ( disable_signal_handler, @@ -71,7 +72,6 @@ Tensor.__qualname__ = 'Tensor' import paddle.distributed.fleet # noqa: F401 - from paddle import ( # noqa: F401 distributed, sysconfig, @@ -113,6 +113,7 @@ create_parameter, to_tensor, diag, + diag_embed, diagflat, eye, linspace, @@ -252,6 +253,10 @@ view, view_as, unfold, + masked_fill, + masked_fill_, + index_fill, + index_fill_, ) from .tensor.math import ( # noqa: F401 @@ -407,6 +412,8 @@ i1e, polygamma, polygamma_, + hypot, + hypot_, ) from .tensor.random import ( @@ -566,6 +573,7 @@ 'subtract', 'diag', 'diagflat', + 'diag_embed', 'isnan', 'scatter_nd_add', 'unstack', @@ -904,4 +912,10 @@ 'i1e', 'polygamma', 'polygamma_', + 'masked_fill', + 'masked_fill_', + 'hypot', + 'hypot_', + 'index_fill', + "index_fill_", ] diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index 81f5a579eeb44..f12ad697b049c 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -245,6 +245,7 @@ def check_models(models): def _is_valid_optimizer(optimizer): from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( DygraphShardingOptimizer, + DygraphShardingOptimizerV2, ) return isinstance( @@ -252,6 +253,7 @@ def _is_valid_optimizer(optimizer): ( paddle.optimizer.Optimizer, DygraphShardingOptimizer, + DygraphShardingOptimizerV2, ), ) @@ -359,37 +361,38 @@ def amp_guard( % tracer._expected_place ) enable = False - # For xpu: - if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): - warnings.warn('XPUPlace only support float16 amp.') - enable = False - # For custom device: - if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): - warnings.warn('CustomPlace only support float16 amp.') - enable = False - # For gpu float16: Compute Capability should >= 7. - # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. - if tracer._expected_place.is_gpu_place(): - if (dtype == 'float16') and not _is_gpu_float16_supported(): - prop = paddle.device.cuda.get_device_capability() - warnings.warn( - "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." - % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) - ) + if enable: + # For xpu: + if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): + warnings.warn('XPUPlace only support float16 amp.') enable = False - elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): - prop = paddle.device.cuda.get_device_capability() - cuda_version = paddle.version.cuda() - warnings.warn( - "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s." - % ( - paddle.device.cuda.get_device_name(), - prop[0], - prop[1], - cuda_version, - ) - ) + # For custom device: + if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): + warnings.warn('CustomPlace only support float16 amp.') enable = False + # For gpu float16: Compute Capability should >= 7. + # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. + if tracer._expected_place.is_gpu_place(): + if (dtype == 'float16') and not _is_gpu_float16_supported(): + prop = paddle.device.cuda.get_device_capability() + warnings.warn( + "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." + % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) + ) + enable = False + elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): + prop = paddle.device.cuda.get_device_capability() + cuda_version = paddle.version.cuda() + warnings.warn( + "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s." + % ( + paddle.device.cuda.get_device_name(), + prop[0], + prop[1], + cuda_version, + ) + ) + enable = False amp_dtype = dtype amp_global_state().amp_dtype = amp_dtype @@ -482,11 +485,14 @@ def __call__(self, state_dict): def _set_multi_precision(optimizer, multi_precision): from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( DygraphShardingOptimizer, + DygraphShardingOptimizerV2, ) optimizer = ( optimizer._inner_opt - if isinstance(optimizer, DygraphShardingOptimizer) + if isinstance( + optimizer, (DygraphShardingOptimizer, DygraphShardingOptimizerV2) + ) else optimizer ) if hasattr(optimizer, "_multi_precision"): diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index ddfba7a22b12b..39d73cc54e9ac 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -105,12 +105,12 @@ def prepare_grad_outputs(grad_outputs, outputs, state): if output.shape != grad.shape: raise ValueError( "The shape of grad_output[%d] %s should be the same as the shape of output[%d] %s" - % (i, str(output.shape), i, str(grad.shape)) + % (i, str(grad.shape), i, str(output.shape)) ) if output.dtype != grad.dtype: raise ValueError( "The dtype of grad_output[%d] %s should be the same as the dtype of output[%d] %s" - % (i, str(output.dtype), i, str(grad.dtype)) + % (i, str(grad.dtype), i, str(output.dtype)) ) feedop = grad.get_defining_op() update_bwdop_structure( @@ -328,7 +328,7 @@ def append_backward_ops( if op has grad_op, prepare its grad_op's inputs by value_to_valuegrad, eg: value_to_valuegrad[v3] = [[v3_g]]; - v2_g = call_vjp(op3, [v3_g], [v2_stopgradient]) + v2_g = call_vjp(op3, [[v2]], [[v3]],[[v3_g]], [[v2_stopgradient]]) special pattern 1: @@ -339,7 +339,7 @@ def append_backward_ops( v1 is inside python api, we don't describe it in backward process(state) so v1_grad is inside vjp, we don't describe it in backward process(state) - [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient) + [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [[v11, v12]], [[v3]],[[v3_g]], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient]) op_vjp is: @@ -358,14 +358,16 @@ def append_backward_ops( else continue to next op. ''' - def make_output_grad(op): + def make_output_with_output_grad(op): zero_flag = [False] * op.num_results() + outputs = [] output_grads = [] for i, value in enumerate(op.results()): + new_value = [value] if ( value in state.value_to_valuegrad - and len(state.value_to_valuegrad[value]) - ) > 1: + and len(state.value_to_valuegrad[value]) > 1 + ): # one value is input of more than one fwd_op, # so more than one bwd_op create input_grad, # need add sum op to accumulate gradient @@ -396,12 +398,15 @@ def make_output_grad(op): # pattern case: # this fwd_op's output is vectorType, it will split to # Type by builtin.split op, so need get from split op's ouput - split_zero_flag, split_output_grad = make_output_grad( - value.first_use().owner() - ) + ( + split_zero_flag, + split_outputs, + split_output_grad, + ) = make_output_with_output_grad(value.first_use().owner()) zero_flag[i] = all(split_zero_flag) grad_values = [value[0] for value in split_output_grad] state.value_to_valuegrad[value] = [grad_values] + new_value = [info[0] for info in split_outputs] else: # first case: # this fwd_op's output didn't used by other fwd_op, @@ -424,35 +429,54 @@ def make_output_grad(op): state.value_to_valuegrad[value] = [[grad_value]] + outputs.append(new_value) output_grads.append(state.value_to_valuegrad[value][0]) - return zero_flag, output_grads + return zero_flag, outputs, output_grads - def make_input_stopgradient(op): + def make_input_with_input_stopgradient(op): + inputs = [] input_grad_stopgradients = [] if op.name() == "builtin.combine": grad_semantic_info = [True for _ in range(op.num_operands())] else: grad_semantic_info = op.get_input_grad_semantics() + for input, grad_semantic in zip( op.operands_source(), grad_semantic_info ): if not grad_semantic: + if ( + input.get_defining_op() is not None + and input.get_defining_op().name() == "builtin.combine" + ): + inputs.append( + list(input.get_defining_op().operands_source()) + ) + else: + inputs.append([input]) continue + if ( input.get_defining_op() is not None and input.get_defining_op().name() == "builtin.combine" ): - stop_gradient = make_input_stopgradient(input.get_defining_op()) + ( + combine_inputs, + combine_stop_gradient, + ) = make_input_with_input_stopgradient(input.get_defining_op()) + inputs.append([info[0] for info in combine_inputs]) input_grad_stopgradients.append( - [info[0] for info in stop_gradient] + [info[0] for info in combine_stop_gradient] ) else: + inputs.append([input]) if input.get_defining_op() is None or input in no_grad_set: input_grad_stopgradients.append([True]) else: input_grad_stopgradients.append([False]) - return input_grad_stopgradients + + return inputs, input_grad_stopgradients def update_input_grad_map(op, input_grads): i = 0 @@ -494,7 +518,7 @@ def update_input_grad_map(op, input_grads): for op in clear_effective_forward_ops: if paddle.framework.core.has_vjp(op): # prepare output_grad - zero_flag, output_grads = make_output_grad(op) + zero_flag, outputs, output_grads = make_output_with_output_grad(op) # all(zero_flag) support this op has no contribution for grad # should be delete (prune sub_graph) @@ -502,12 +526,15 @@ def update_input_grad_map(op, input_grads): continue # prepare input_grad stop_gradient info. - input_grad_stopgradients = make_input_stopgradient(op) + ( + inputs, + input_grad_stopgradients, + ) = make_input_with_input_stopgradient(op) # create grad_op before_ops_num = len(block.ops) input_grads = paddle.framework.core.call_vjp( - op, output_grads, input_grad_stopgradients + op, inputs, outputs, output_grads, input_grad_stopgradients ) after_ops_num = len(block.ops) diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 5bab0d5cf84f0..7e5ac9c1d92c4 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -209,6 +209,7 @@ def remove_flag_if_exists(name): # NOTE(Aurelius84): clean up ExecutorCacheInfo in advance manually. atexit.register(core.clear_executor_cache) +atexit.register(core.pir.clear_pir_compiler_manager) # NOTE(Aganlengzi): clean up KernelFactory in advance manually. # NOTE(wangran16): clean up DeviceManager in advance manually. diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index 876db0abc3aa7..e62a5b9245a1b 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -2348,7 +2348,7 @@ def _find_op_path_( # If block is while block, dealing with op specifically again. # TODO(liym27): Consider special types of ops. for i, op in reversed(list(enumerate(block.ops))): - if relevant_op_flags[i] == False and _some_in_set_( + if relevant_op_flags[i] is False and _some_in_set_( op.desc.output_arg_names(), output_names ): relevant_op_flags[i] = True diff --git a/python/paddle/base/data_feeder.py b/python/paddle/base/data_feeder.py index 2449f456fdc66..1be2509c19bbe 100644 --- a/python/paddle/base/data_feeder.py +++ b/python/paddle/base/data_feeder.py @@ -17,6 +17,7 @@ import numpy as np from ..pir import OpResult +from ..pir.core import ParameterMeta from . import core from .framework import ( Variable, @@ -44,7 +45,7 @@ core.VarDesc.VarType.COMPLEX128: 'complex128', } -_PADDLE_NEW_IR_DTYPE_2_NUMPY_DTYPE = { +_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE = { core.DataType.BOOL: 'bool', core.DataType.FLOAT16: 'float16', core.DataType.BFLOAT16: 'uint16', @@ -91,8 +92,8 @@ def convert_dtype(dtype): if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE: return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] if isinstance(dtype, core.DataType): - if dtype in _PADDLE_NEW_IR_DTYPE_2_NUMPY_DTYPE: - return _PADDLE_NEW_IR_DTYPE_2_NUMPY_DTYPE[dtype] + if dtype in _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE: + return _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[dtype] elif isinstance(dtype, type): # This branch is for NumPy scalar types if dtype in [ @@ -147,7 +148,9 @@ def check_variable_and_dtype( input, input_name, expected_dtype, op_name, extra_message='' ): if in_pir_mode(): - check_type(input, input_name, OpResult, op_name, extra_message) + check_type( + input, input_name, (OpResult, ParameterMeta), op_name, extra_message + ) else: check_type(input, input_name, Variable, op_name, extra_message) check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message) @@ -177,9 +180,7 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''): elif isinstance(input, core.eager.Tensor): raise TypeError( "Please use `with base.dygraph.guard()` as context or `base.enable_dygraph()` to switch to imperative mode firstly. " - "Because received '{}' in {} is a imperative Variable.".format( - input_name, op_name - ) + f"Because received '{input_name}' in {op_name} is a imperative Variable." ) if not isinstance(input, expected_type): raise TypeError( diff --git a/python/paddle/base/device_worker.py b/python/paddle/base/device_worker.py index 755f7257b735a..c20677f6acd5e 100644 --- a/python/paddle/base/device_worker.py +++ b/python/paddle/base/device_worker.py @@ -450,7 +450,7 @@ def _gen_worker_desc(self, trainer_desc): if ( opt_info["use_cvm"] or "no_cvm" in opt_info - and opt_info["no_cvm"] == True + and opt_info["no_cvm"] is True ): sparse_table.emb_dim = self._fleet_desc.server_param.downpour_server_param.downpour_table_param[ i @@ -560,7 +560,7 @@ def _gen_worker_desc(self, trainer_desc): if ( opt_info["use_cvm"] or "no_cvm" in opt_info - and opt_info["no_cvm"] == True + and opt_info["no_cvm"] is True ): sparse_table.emb_dim = self._fleet_desc.server_param.downpour_server_param.downpour_table_param[ i diff --git a/python/paddle/base/dygraph/base.py b/python/paddle/base/dygraph/base.py index 5fad89935d4c7..69ee71395b9aa 100644 --- a/python/paddle/base/dygraph/base.py +++ b/python/paddle/base/dygraph/base.py @@ -595,7 +595,7 @@ def guard(place=None): if place is not None: expected_place = _get_paddle_place(place) else: - expected_place = framework._current_expected_place() + expected_place = framework._current_expected_place_() with framework.program_guard(train, startup): with framework.unique_name.guard(): @@ -928,7 +928,7 @@ def to_variable(value, name=None, zero_copy=None, dtype=None): # (2): when used in flask framework, it may result in hang. # Details: https://github.com/PaddlePaddle/Paddle/issues/26635 # So, we temporally diable the zero_copy strategy. - if zero_copy == True: + if zero_copy is True: warnings.warn( "Currently, zero_copy is not supported, and it will be discarded." ) diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index 5972b545f93e2..172f73bf7f531 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -150,7 +150,7 @@ def _index_(var): return int(np.array(var)) @property - def _ndim_(var): + def _ndim(var): return len(var.shape) def ndimension(var): @@ -183,7 +183,7 @@ def _T_(var): ('astype', astype), ('dim', dim), ('ndimension', ndimension), - ('ndim', _ndim_), + ('ndim', _ndim), ('size', _size_), ('T', _T_), # for logical compare diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py index b01c7a70e4406..1f5b414ebb559 100644 --- a/python/paddle/base/dygraph/tensor_patch_methods.py +++ b/python/paddle/base/dygraph/tensor_patch_methods.py @@ -869,7 +869,7 @@ def cuda(self, device_id=None, blocking=True): if self.place._equals(res_place): return self else: - res = self._copy_to(res_place, True) + res = self._copy_to(res_place, blocking) res.stop_gradient = self.stop_gradient res.persistable = self.persistable return res diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index f3185fb277ed7..4bc2312e225cf 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -21,7 +21,11 @@ import numpy as np +from paddle import pir + from ..pir import OpResult +from ..pir import Program as PirProgram +from ..pir import Value, translate_to_pir from . import compiler, core, framework, get_flags, set_flags, unique_name from .data_feeder import convert_dtype from .framework import ( @@ -513,7 +517,7 @@ def _add_pir_fetch_ops(program, fetch_list, fetch_var_name): with paddle.static.program_guard(program): for i, fetch_input in enumerate(fetch_list): assert isinstance( - fetch_input, OpResult + fetch_input, (OpResult, Value) ), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}" paddle._pir_ops.fetch(fetch_input, fetch_var_name + str(i), i) @@ -594,6 +598,10 @@ def _to_str(var): return str(var) elif isinstance(var, Operator): return str(id(var)) + elif isinstance(var, OpResult): + return str(var) + elif isinstance(var, Value): + return str(var) else: raise TypeError(str(var) + " should be Variable, Operator or str") @@ -628,11 +636,18 @@ def _prepare_fleet_executor(): def _get_strong_program_cache_key_for_new_exe(program, scope, feed, fetch_list): - return ( - program.desc.cached_hash_str() - + str(scope.raw_address()) - + _get_program_cache_key(feed, fetch_list) - ) + if isinstance(program, PirProgram): + return ( + str(program) + + str(scope.raw_address()) + + _get_program_cache_key(feed, fetch_list) + ) + else: + return ( + program.desc.cached_hash_str() + + str(scope.raw_address()) + + _get_program_cache_key(feed, fetch_list) + ) def _get_strong_program_cache_key(program, feed, fetch_list): @@ -744,6 +759,11 @@ def _can_use_interpreter_core(program, place): return True +@lru_cache() +def _warning_once(msg): + logging.warning(msg) + + class FetchHandler: def __init__(self, var_dict=None, period_secs=60): assert var_dict is not None @@ -793,7 +813,11 @@ def run(self, feed_names, return_numpy=True): tensors = self._new_exe.run(feed_names)._move_to_list() if return_numpy: tensors = as_numpy(tensors, copy=True) - return _merge_tensors(tensors, self._plan.micro_batch_num()) + if not get_flags("FLAGS_enable_pir_in_executor")[ + 'FLAGS_enable_pir_in_executor' + ]: + return _merge_tensors(tensors, self._plan.micro_batch_num()) + return tensors else: if self._plan.micro_batch_num() > 1: raise RuntimeError( @@ -867,6 +891,9 @@ def __init__(self): self._get_cached_program_and_executor = lru_cache(maxsize=8)( self._get_program_and_executor ) + self._get_cached_program_and_executor_pir_mode = lru_cache(maxsize=8)( + self._get_pir_program_and_executor + ) def clear(self): self._get_cached_program_and_executor.cache_clear() @@ -971,7 +998,9 @@ def _get_program_and_executor(self, cached_data): else False ) - if os.getenv("FLAGS_enable_new_ir_in_executor"): + if get_flags('FLAGS_enable_pir_in_executor')[ + 'FLAGS_enable_pir_in_executor' + ]: # todo(phlrain), skip inplace add addto pass in new IR enable_inplace = False enable_addto = False @@ -999,9 +1028,28 @@ def _get_program_and_executor(self, cached_data): ) else: default_job = core.Job("default") - type_to_program = {"default": new_program.desc} + if get_flags("FLAGS_enable_pir_in_executor")[ + 'FLAGS_enable_pir_in_executor' + ]: + type_to_program = { + "default": translate_to_pir(new_program.desc) + } + else: + type_to_program = {"default": new_program.desc} plan = core.Plan([default_job], type_to_program) + if ( + new_program._pass_opt + and "pass_list" in new_program._pass_opt + and len(new_program._pass_opt['pass_list']) > 0 + ): + pm = pir.PassManager() + for p in new_program._pass_opt['pass_list']: + pm.add_pass(p) + for job_type in plan.job_types(): + ir_program = plan.ir_program(job_type) + pm.run(ir_program) + new_exe = _StandaloneExecutor(place, plan, scope) return new_program, new_exe @@ -1015,6 +1063,27 @@ def get_pir_program_and_executor( place, scope, ): + return self._get_cached_program_and_executor_pir_mode( + self._CachedData( + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + place, + scope, + ) + ) + + def _get_pir_program_and_executor(self, cached_data): + program = cached_data.program + feed = cached_data.feed + fetch_list = cached_data.fetch_list + feed_var_name = cached_data.feed_var_name + fetch_var_name = cached_data.fetch_var_name + place = cached_data.place + scope = cached_data.scope + _add_pir_fetch_ops( program, fetch_list=fetch_list, fetch_var_name=fetch_var_name ) @@ -1177,10 +1246,8 @@ def _add_micro_scopes_cache(self, program_cache_key, micro_scopes: list): def _get_micro_scopes_cache(self, program_cache_key): return self.micro_scope_cache.get(program_cache_key, None) - # just for testing, will be removed later - @lru_cache() def _log_force_set_program_cache(self, use_program_cache): - logging.warning( + _warning_once( f"use_program_cache is force set to {use_program_cache} by FLAGS_FORCE_USE_PROGRAM_CACHE" ) @@ -1199,7 +1266,7 @@ def _feed_data(self, program, feed, feed_var_name, scope): ) check_feed_shape_type(var, cur_feed) idx = op.desc.attr('col') - pir_flag_name = 'FLAGS_enable_new_ir_in_executor' + pir_flag_name = 'FLAGS_enable_pir_in_executor' if get_flags(pir_flag_name)[pir_flag_name]: core.set_feed_variable( scope, cur_feed, feed_target_name, idx @@ -1695,7 +1762,7 @@ def _run_impl( if isinstance(program, Program) and program._heter_pipeline_opt: # print("program._heter_pipeline_opt: {}".format( # program._heter_pipeline_opt)) - ## change default executor + # change default executor heter_place = program._heter_pipeline_opt["heter_place"] heter_place = framework._get_paddle_place(heter_place) p = core.Place() @@ -1852,12 +1919,12 @@ def _run_impl( varobj = global_block.vars[varname] if ( - vardesc.persistable() == False + vardesc.persistable() is False and vardesc.type() == core.VarDesc.VarType.LOD_TENSOR - and vardesc.need_check_feed() == True - and varobj.stop_gradient == True - and varobj.is_data == True - and varobj.belong_to_optimizer == False + and vardesc.need_check_feed() is True + and varobj.stop_gradient is True + and varobj.is_data is True + and varobj.belong_to_optimizer is False and varname not in feed ): raise ValueError('Need feed data for variable %s' % varname) @@ -1940,7 +2007,9 @@ def _run_inference(self, exe, feed): return exe.run(feed) def _check_fetch_list(self, fetch_list): - is_fetch_var = lambda var: isinstance(var, (Variable, str, OpResult)) + is_fetch_var = lambda var: isinstance( + var, (Variable, str, OpResult, Value) + ) is_tuple_list = lambda var: isinstance(var, (tuple, list)) if fetch_list is None: @@ -1966,7 +2035,7 @@ def _check_fetch_list(self, fetch_list): res.append(var) else: raise TypeError( - "Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}.".format( + "Require fetch_list[{}] 's type shall be one of (OpResult, str), but received {}.".format( i, type(var).__name__ ) ) @@ -2143,7 +2212,7 @@ def _prepare_trainer( ): is_heter = 0 use_ps_gpu = 0 - if not program._fleet_opt is None: + if program._fleet_opt is not None: if program._fleet_opt.get("worker_class", "") == "HeterCpuWorker": is_heter = 1 if program._fleet_opt.get("trainer", "") == "HeterXpuTrainer": @@ -2269,7 +2338,7 @@ def _run_from_dataset( raise RuntimeError( "dataset is need and should be initialized" ) - ## change default executor + # change default executor heter_place = framework._get_paddle_place(heter_place) p = core.Place() p.set_place(heter_place) @@ -2722,7 +2791,7 @@ def _run_using_fleet_executor( if return_numpy: tensor = as_numpy(tensor) else: - tensor = [t for t in tensor] + tensor = list(tensor) if tensor: scope_result_list.append(tensor) diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index ca9bcf5fd8db5..3a8e741d91c44 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -626,12 +626,10 @@ def _set_pipeline_stage(stage): def _fake_interface_only_(func): def __impl__(*args, **kwargs): raise AssertionError( - "'{}' only can be called by `paddle.Tensor` in dynamic graph mode. Suggestions:\n" + f"'{func.__name__}' only can be called by `paddle.Tensor` in dynamic graph mode. Suggestions:\n" " 1. If you are in static graph mode, you can switch to dynamic graph mode by turning off `paddle.enable_static()` or calling `paddle.disable_static()`.\n" " 2. If you are using `@paddle.jit.to_static`, you can call `paddle.jit.enable_to_static(False)`. " - "If you have to translate dynamic graph to static graph, please use other API to replace '{}'.".format( - func.__name__, func.__name__ - ) + f"If you have to translate dynamic graph to static graph, please use other API to replace '{func.__name__}'." ) return __impl__ @@ -2995,7 +2993,7 @@ def __init__( if ( type == 'less_than' and op_attrs['force_cpu'] is not None - ) or op_attrs['force_cpu'] != False: + ) or op_attrs['force_cpu'] is not False: warnings.warn( "The Attr(force_cpu) of Op(%s) will be deprecated in the future, " "please use 'device_guard' instead. 'device_guard' has higher priority when they are " @@ -4266,7 +4264,7 @@ def _rename_var(self, name, new_name): return var def _remove_var(self, name, sync=True): - if sync == True: + if sync is True: self._sync_with_cpp() self.desc._remove_var(name.encode()) del self.vars[name] @@ -4455,7 +4453,7 @@ def _remove_op(self, index, sync=True): Returns: None """ - if sync == True: + if sync is True: self._sync_with_cpp() self.desc._remove_op(index, index + 1) del self.ops[index] @@ -5678,6 +5676,7 @@ def __init__(self): # assigned if this program has been parsed by a pipeline optimizer self._pipeline_opt = None + self._pass_opt = None # assigned if this program has been parsed by a heter pipeline parameter server optimizer self._heter_pipeline_opt = None @@ -6315,7 +6314,8 @@ def clone(self, for_test=False): p.lr_scheduler = self.lr_scheduler if hasattr(self, '_pipeline_opt'): p._pipeline_opt = self._pipeline_opt - + if hasattr(self, '_pass_opt'): + p._pass_opt = self._pass_opt # NOTE(zhiqiu): we sync the cloned program, to update its program by # its desc. p._sync_with_cpp() @@ -7114,9 +7114,7 @@ def condition(var): return is_parameter(var) or is_belong_to_optimizer(var) else: raise ValueError( - "`mode` string should be 'param', 'opt' or 'all', but received {}.".format( - mode - ) + f"`mode` string should be 'param', 'opt' or 'all', but received {mode}." ) var_list = filter(condition, self.list_vars()) diff --git a/python/paddle/base/layer_helper_base.py b/python/paddle/base/layer_helper_base.py index 6b506d6f192b9..74ba6408ef8d8 100644 --- a/python/paddle/base/layer_helper_base.py +++ b/python/paddle/base/layer_helper_base.py @@ -97,7 +97,9 @@ def to_variable(self, value, name=None): name if name else None, True, ) - elif isinstance(value, (Variable, core.eager.Tensor)): + elif isinstance( + value, (Variable, core.eager.Tensor, paddle.pir.OpResult) + ): return value else: raise TypeError( @@ -420,10 +422,10 @@ def create_parameter( is_used = unique_name.dygraph_parameter_name_checker(attr.name) if is_used: raise ValueError( - "parameter name [{}] have be been used. " + f"parameter name [{attr.name}] have be been used. " "In dygraph mode, the name of parameter can't be same." "Please check the parameter attr value passed to self.create_parameter or " - "constructor of dygraph Layers".format(attr.name) + "constructor of dygraph Layers" ) return self.main_program.global_block().create_parameter( dtype=dtype, diff --git a/python/paddle/base/layers/layer_function_generator.py b/python/paddle/base/layers/layer_function_generator.py index f77d26ac50a5f..2cec3b7e58fa1 100644 --- a/python/paddle/base/layers/layer_function_generator.py +++ b/python/paddle/base/layers/layer_function_generator.py @@ -193,7 +193,7 @@ def infer_and_check_dtype(op_proto, *args, **kwargs): dtype = each.dtype elif dtype != each.dtype: raise ValueError( - "operator {0} must input same dtype. {1} vs {2}".format( + "operator {} must input same dtype. {} vs {}".format( op_type, dtype, each.dtype ) ) @@ -337,8 +337,8 @@ def func(x, name=None): func.__name__ = inplace_op_type func.__doc__ = """ -Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``. -Please refer to :ref:`api_base_layers_{1}`. +Inplace version of ``{}`` API, the output Tensor will be inplaced with input ``x``. +Please refer to :ref:`api_base_layers_{}`. """.format( origin_op_type, origin_op_type ) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index f2b1ac7c6d04d..1f070882758b9 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -355,7 +355,7 @@ def pop(self, *args): if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: raise TypeError( - "Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}".format( + "Only Variable with VarType.LOD_TENSOR_ARRAY support `pop` method, but received type: {}".format( self.type ) ) @@ -376,7 +376,7 @@ def _neg_(var): return _scalar_op_(var, -1.0, 0.0) @property - def _ndim_(self): + def _ndim(self): """ Returns the dimension of current Variable @@ -393,7 +393,7 @@ def _ndim_(self): >>> # create a static Variable >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) >>> # print the dimension of the Variable - >>> print(x.ndim()) + >>> print(x.ndim) 3 """ return len(self.shape) @@ -627,7 +627,7 @@ def to_dense(var): ('pop', pop), ('dim', dim), ('ndimension', ndimension), - ('ndim', _ndim_), + ('ndim', _ndim), ( '__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_), diff --git a/python/paddle/base/trainer_desc.py b/python/paddle/base/trainer_desc.py index 255ddf05a580a..3d6c947db484e 100644 --- a/python/paddle/base/trainer_desc.py +++ b/python/paddle/base/trainer_desc.py @@ -112,7 +112,7 @@ def _set_infer(self, infer): def _set_fleet_desc(self, fleet_desc): self._fleet_desc = fleet_desc - ## serialize fleet_desc + # serialize fleet_desc from google.protobuf import text_format fleet_desc_str = text_format.MessageToString(fleet_desc) diff --git a/python/paddle/base/trainer_factory.py b/python/paddle/base/trainer_factory.py index c5743ca22a29e..c8b61fdf7c112 100644 --- a/python/paddle/base/trainer_factory.py +++ b/python/paddle/base/trainer_factory.py @@ -186,7 +186,7 @@ def handler_launch_func(self, scope, handler): elapsed_secs = 0 while True: self.running_lock.acquire() - if self.running == False: + if self.running is False: break if elapsed_secs < period_secs: # TODO(guru4elephant): needs customized condition diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index e485714f2276a..db43a53b5327f 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -14,6 +14,7 @@ import itertools import warnings +from functools import reduce import numpy as np @@ -344,14 +345,10 @@ def get_value_for_bool_tensor(var, item): empty_shape = [0] + list(var.shape[i:]) def idx_not_empty(var, item): - from ..tensor import gather_nd + bool_2_idx = paddle.nonzero(item) + return paddle.gather_nd(var, bool_2_idx) - bool_2_idx = paddle.nonzero(item == True) - return gather_nd(var, bool_2_idx) - - from paddle.static.nn import cond - - return cond( + return paddle.static.nn.cond( item.any(), lambda: idx_not_empty(var, item), lambda: paddle.empty(empty_shape, var.dtype), @@ -484,7 +481,7 @@ def _setitem_impl_(var, item, value): else: raise IndexError( "Valid index accept int, slice, ellipsis, None, list of bool, Variable, " - "but received {}.".format(slice_item) + f"but received {slice_item}." ) axes.append(dim) @@ -496,9 +493,7 @@ def _setitem_impl_(var, item, value): if slice_info.indexes: if len(slice_info.indexes) != len(item): raise IndexError( - "Valid index accept int or slice or ellipsis or list, but received {}.".format( - item - ) + f"Valid index accept int or slice or ellipsis or list, but received {item}." ) return slice_info.set_item(var, value) attrs = { @@ -767,9 +762,7 @@ def parse_index(x, indices): has_advanced_index = True estimated_dim += 1 - elif isinstance( - slice_item, (paddle.base.Variable, paddle.pir.OpResult) - ): + elif isinstance(slice_item, paddle.base.Variable): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. if ( slice_item.dtype == paddle.bool @@ -789,6 +782,23 @@ def parse_index(x, indices): has_advanced_index = True estimated_dim += 1 + elif isinstance(slice_item, paddle.pir.OpResult): + # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. + if slice_item.dtype == paddle.pir.core.DataType.BOOL: + if slice_item.ndim == 0: + # 0-D bool Tensor, same as single PY-bool. + none_axes.append(dim) + + elif slice_item.shape[0] != x.shape[dim]: + raise IndexError( + "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( + slice_item.shape[0], x.shape[dim], dim + ) + ) + advanced_index[estimated_dim] = (estimated_dim, slice_item) + has_advanced_index = True + estimated_dim += 1 + else: raise IndexError( "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format( @@ -1051,7 +1061,7 @@ def get_tensor_with_basic_indexing( ) attrs['infer_flags'] = infer_flags - from . import in_dynamic_or_pir_mode + from . import in_dynamic_or_pir_mode, in_pir_mode if in_dynamic_or_pir_mode(): if "StartsTensorList" in inputs.keys(): @@ -1071,6 +1081,13 @@ def get_tensor_with_basic_indexing( if len(decrease_axes) > 0: out = paddle._C_ops.squeeze(out, decrease_axes) else: + if in_pir_mode(): + if isinstance(st, (list, tuple)): + if paddle.utils._contain_var(st): + st = paddle.utils.get_int_tensor_list(st) + if isinstance(end, (list, tuple)): + if paddle.utils._contain_var(end): + end = paddle.utils.get_int_tensor_list(end) out = paddle._C_ops.slice( x, axes, diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index ca885072918d9..2091f07c437d5 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -16,6 +16,8 @@ import typing from paddle import pir +from paddle.autograd import ir_backward +from paddle.base.core import call_decomp, has_decomp from paddle.base.libpaddle.pir import Block, Operation, Program from paddle.framework import core @@ -30,6 +32,26 @@ def _build_tensor_tuple(xs): return TypeError(f"Type {type(xs)} is not supported.") +def _analyse_decomp_results(orig_outs, decomp_outs, op): + intermediate_status = op.get_output_intermediate_status() + assert len(orig_outs) == len(decomp_outs) == len(intermediate_status) + res = [] + for org_item, new_item, value in zip( + orig_outs, decomp_outs, intermediate_status + ): + if isinstance(org_item, pir.OpResult): + if value: + assert new_item[0] is None + else: + assert len(new_item) == 1 and isinstance( + new_item[0], pir.OpResult + ) + res.append(new_item[0]) + else: + res.append(new_item) + return res + + def _prepare_python_api_arguments(op): """ For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs. @@ -37,28 +59,35 @@ def _prepare_python_api_arguments(op): Args: op (Operator): The target operator. """ - op_inputs = [] + combine_op_name = "builtin.combine" + inputs = [] for x in op.operands(): - op_input = x.source() - upper_op = op_input.get_defining_op() - if ( - isinstance(upper_op, Operation) - and upper_op.name() == 'builtin.combine' - ): - op_input = [item.source() for item in upper_op.operands()] - op_inputs.append(op_input) - # The inputs of PIR op builtin.combine will be restored as list of tensor. - if op.name() in ["builtin.combine"]: - return (op_inputs,) - - op_attrs_dict = op.attrs() - op_attrs_name = op.get_attr_names() - op_attrs = [op_attrs_dict[x] for x in op_attrs_name] - api_arguments = op_inputs + op_attrs + input = x.source() + if input and input.initialized(): + prev_op = input.get_defining_op() + if ( + isinstance(prev_op, Operation) + and prev_op.name() == combine_op_name + ): + input = [item.source() for item in prev_op.operands()] + inputs.append(input) + else: + # for optional input, such as scale for layer_norm op, + # if it is not set, there will be an empty OpResult which is not initialized in ops.operands + # therefore append None for it. + inputs.append(None) + + # The inputs of Pir op builtin.combine will be restored as list of tensor. + if op.name() == combine_op_name: + return (inputs,) + + api_arguments = inputs + [op.attrs()[x] for x in op.get_attr_names()] return tuple(api_arguments) -def _check_op_results(op_name, orig_outs, new_outs, orig_vars, dst_vars): +def _check_op_results( + op_name, orig_outs, new_outs, orig_vars=None, dst_vars=None +): """ Check whether the replaced outputs are consistent with origin outputs. @@ -88,8 +117,9 @@ def _check_op_results(op_name, orig_outs, new_outs, orig_vars, dst_vars): # to keep same as phi op definition, orig_out may receive None continue elif new_out is not None: - if orig_out in orig_vars.keys(): - dst_vars[orig_vars[orig_out]] = new_out + if orig_vars is not None and dst_vars is not None: + if orig_out in orig_vars.keys(): + dst_vars[orig_vars[orig_out]] = new_out orig_dtype = orig_out.dtype new_dtype = new_out.dtype orig_shape = orig_out.shape @@ -215,7 +245,8 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): for idx, op in enumerate(ops_list): op_name = op.name() decom_rule = register.get_decomp_rule(op_name) - lower = decom_rule and op_filter(op) + has_sink_decomp_rule = has_decomp(op) + lower = (decom_rule or has_sink_decomp_rule) and op_filter(op) if op.name() == "builtin.combine": temp_op = op @@ -231,7 +262,13 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): pir.set_insertion_point(op) input_args = _prepare_python_api_arguments(op) orig_outs = op.results() - new_outs = _build_tensor_tuple(decom_rule(*input_args)) + if has_sink_decomp_rule: + decomp_outs = call_decomp(op) + new_outs = _analyse_decomp_results( + orig_outs, decomp_outs, op + ) + else: + new_outs = _build_tensor_tuple(decom_rule(*input_args)) # Todo: To cover such case: some outputs are no longer needed after decomposition. _check_op_results( @@ -261,3 +298,308 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): raise TypeError( f"Expect type Block or Sequence of Block, but got type {type(block)}" ) + + +def get_leaf_ops(block, global_outputs): + ''' + This API checks which op contributes to the outputs of the entire computation graph, + as well as determining the corresponding output index. + + Args: + block (Block): the block of program to be processed. + global_outputs (tuple(Value)): the outputs of the entire computation graph. + + Returns: + related_ops (tuple(pir.Operation)): a tuple of op that contributes to the outputs of the entire graph. + related_ops_output_indexes (tuple(tuple())) : a tuple records the mapping of tuple(the output index of the op, the output index of the entire graph) + ''' + if not isinstance(block, Block): + raise TypeError(f"block should be Block, but got type {type(block)}") + if not isinstance(global_outputs, list): + raise TypeError("The type of global_outputs should be list") + + related_ops = [] + related_ops_output_indexes = [] + + op_to_op_valid_result = {} + for op in block.ops: + op_valid_result = [] + for x in op.results(): + if x.initialized(): + op_valid_result.append(x) + op_to_op_valid_result[op] = op_valid_result + + for global_output in global_outputs: + for op in op_to_op_valid_result.keys(): + if global_output in op_to_op_valid_result[op]: + if op not in related_ops: + related_ops.append(op) + related_ops_output_indexes.append( + [ + [ + op.results().index(global_output), + global_outputs.index(global_output), + ] + ] + ) + else: + related_ops_output_indexes[related_ops.index(op)].append( + [ + op.results().index(global_output), + global_outputs.index(global_output), + ] + ) + + return tuple(related_ops), tuple(related_ops_output_indexes) + + +def replace_graph_outputs( + global_outputs, + op_outputs, + op_index, + related_ops_output_indexes, +): + ''' + This API replace the outputs of the entire computation graph with the new outputs of the op, + when the op contributes to the outputs of the entire computation graph. + ''' + for index in related_ops_output_indexes[op_index]: + global_outputs[index[1]] = op_outputs[index[0]] + + +def decompose_fwd_op( + block: Block, fwd_op: pir.Operation, grad_var_to_var_map: dict +) -> tuple: + ''' + Decompose the fwd_op into a list of primitive ops. + + Args: + block (Block): the block to which the fwd_op belongs. + fwd_op (pir.Operation): the forward op to be decomposed. + grad_var_to_var_map (dict): a dict obtained from distributed processing, + which maps the backward grad variable to its corresponding forward variable. + Returns: + new_outputs (tuple(Value)): the new outputs after decomposing. + has_decomposed: whether the forward op has been successfully decomposed. + ''' + + if not core._is_fwd_prim_enabled(): + raise RuntimeError( + "To decompose forward op, please set `core._set_prim_forward_enabled(True)` firstly" + ) + + with pir.core.program_guard(block.program): + op_name = fwd_op.name() + orig_outs = fwd_op.results() + decom_rule = register.get_decomp_rule(op_name) + has_sink_decomp_rule = has_decomp(fwd_op) + lower = decom_rule or has_sink_decomp_rule + + if lower: + input_args = _prepare_python_api_arguments(fwd_op) + pir.set_insertion_point(fwd_op) + if has_sink_decomp_rule: + decomp_outs = call_decomp(fwd_op) + new_outs = _analyse_decomp_results( + orig_outs, decomp_outs, fwd_op + ) + else: + new_outs = _build_tensor_tuple(decom_rule(*input_args)) + + _check_op_results(op_name, orig_outs, new_outs) + + # update_grad_var_to_var_map + for grad_var, var in grad_var_to_var_map.items(): + if var in orig_outs: + grad_var_to_var_map[grad_var] = new_outs[ + orig_outs.index(var) + ] + + fwd_op.replace_all_uses_with(new_outs) + block.remove_op(fwd_op) + return new_outs, True + else: + return tuple(orig_outs), False + + +def decompose_bwd_op_directly( + block: Block, + fwd_op: pir.Operation, + bwd_op: pir.Operation, + grad_var_to_var_map: dict, +) -> tuple: + ''' + Decompose the bwd_op into a list of primitive ops. + If fwd_op has composite vjp rules (including custom vjp), call call_vjp() to get a list of primitive operators in backward graph, then replace bwd_op. + + Args: + block (Block): the block to which the bwd_op belongs. + fwd_op (pir.Operation): the forward op. + bwd_op (pir.Operation): the backward op to be decomposed. + grad_var_to_var_map (dict): a dict obtained from distributed processing, + which maps the backward grad variable to its corresponding forward variable. + Return: + new_input_grads (tuple(Value)): new results of backward op after decomposing. + has_decomposed: whether the backward op has been successfully decomposed. If a fwd op does not have composite vjp rules and can not be decomposed directly, this function will return False. + ''' + + if not core._is_bwd_prim_enabled(): + raise RuntimeError( + "To decompose backward op, please set `core._set_prim_backward_enabled(True)` firstly" + ) + + # prepare forward and backward op's input and outputs infos + fwd_inputs = [x.source() for x in fwd_op.operands()] + fwd_outputs = fwd_op.results() + bwd_inputs = [x.source() for x in bwd_op.operands()] + grad_inputs = bwd_op.results() + res = [] + + # prepare the input args of call_vjp(fwd_op, inputs, outputs, out_grads, stop_gradients) + grad_outputs = [] + for bwd_input in bwd_inputs: + if not (bwd_input in fwd_inputs or bwd_input in fwd_outputs): + grad_outputs.append([bwd_input]) + fwd_outputs_ = [[fwd_output] for fwd_output in fwd_outputs] + fwd_inputs_ = [ + [fwd_op.operand_source(i)] for i in range(0, fwd_op.num_operands()) + ] + stop_gradients = [] + for grad_input in grad_inputs: + if grad_input.initialized(): + stop_gradients.append([False]) + else: + stop_gradients.append([True]) + + # record the backward op's position for subsequent replacement + bwd_op_idx = block.ops.index(bwd_op) + before_num_ops = len(block.ops) + # generate primitive operators corresponding to the backward op + new_grad_inputs = core.call_vjp( + fwd_op, fwd_inputs_, fwd_outputs_, grad_outputs, stop_gradients + ) + after_num_ops = len(block.ops) + num_appended_ops = after_num_ops - before_num_ops + + # if forward op has no composite vjp rules, call_vjp() appends the same op as original backward op, + # which means the backward op can not be decomposed directly, return False + if num_appended_ops == 1 and block.ops[-1].name() == bwd_op.name(): + block.remove_op(block.ops[-1]) + return None, False + else: + # record new outputs of the decomposed backward op + for grad_input in new_grad_inputs: + if grad_input[0] is not None and grad_input[0].initialized(): + res.append(grad_input[0]) + else: + res.append(pir.fake_op_result()) + + # update_grad_var_to_var_map + for idx, grad_input in enumerate(grad_inputs): + if grad_input in grad_var_to_var_map.keys(): + grad_var_to_var_map[res[idx]] = grad_var_to_var_map.pop( + grad_input + ) + + # move the list of primitive operators to the position of backward op + insert_idx = bwd_op_idx + for i in range(before_num_ops, after_num_ops): + block.move_op(block.ops[i], insert_idx) + insert_idx += 1 + + # replace the following use of original backward op's outputs with new outputs, and then remove original backward op + bwd_op.replace_all_uses_with(res) + block.remove_op(bwd_op) + + return tuple(res), True + + +def decompose_bwd_op_after_fwd_op( + block: Block, + fwd_op: pir.Operation, + bwd_op: pir.Operation, + grad_var_to_var_map: dict, + fwd_inputs: dict, + fwd_outputs_after_decompose: tuple, +) -> tuple: + ''' + Decompose the bwd_op into a list of primitive ops. + If fwd_op has no composite vjp rules, and fwd_op has been decomposed to a list of primitive operators in forward graph previously, + call grad() for the decomposed forward subgraph to get a list of primitive operators in backward graph, then replace bwd_op. + + Args: + block (Block): the block to which the bwd_op belongs. + fwd_op (pir.Operation): the forward op. + bwd_op (pir.Operation): the backward op to be decomposed. + grad_var_to_var_map (dict): a dict obtained from distributed processing, + which maps the backward grad variable to its corresponding forward variable. + fwd_inputs: (tuple(Value)): the original input of the forward op, + fwd_outputs_after_decompose (tuple(Value)): the output of the decomposed forward op, if forward op has no vjp rules, forward op shoule be decomposed firstly, + fwd_outputs_after_decompose means the new output of the decomposed forward op. If forward op has vjp rules, fwd_outputs_after_decompose is None. + Return: + new_input_grads (tuple(Value)): results of backward op after decomposing. + ''' + + if not core._is_bwd_prim_enabled(): + raise RuntimeError( + "To decompose backward op, please set `core._set_prim_backward_enabled(True)` firstly" + ) + if fwd_outputs_after_decompose is None: + raise RuntimeError( + "To decompose backward op, please decompose forward op firstly" + ) + + # prepare forward and backward op's input and outputs infos + bwd_inputs = [x.source() for x in bwd_op.operands()] + grad_inputs = bwd_op.results() + res = [] + + # prepare the input args of grad(outputs, inputs, out_grads) + grad_outputs = tuple( + bwd_input + for bwd_input in bwd_inputs + if not ( + bwd_input in fwd_inputs or bwd_input in fwd_outputs_after_decompose + ) + ) + fwd_outputs_ = tuple( + grad_var_to_var_map[grad_output] for grad_output in grad_outputs + ) + fwd_inputs_ = tuple( + grad_var_to_var_map[grad_input] + for grad_input in grad_inputs + if grad_input.initialized() + ) + + # record the backward op's position for subsequent replacement + bwd_op_idx = block.ops.index(bwd_op) + before_num_ops = len(block.ops) + # generate primitive operators corresponding to the backward op + new_grad_inputs = ir_backward.grad(fwd_outputs_, fwd_inputs_, grad_outputs) + after_num_ops = len(block.ops) + + # record new outputs of the decomposed backward op + input_grads_idx = 0 + for idx, grad_input in enumerate(grad_inputs): + if grad_input.initialized(): + res.append(new_grad_inputs[input_grads_idx]) + input_grads_idx += 1 + else: + res.append(pir.fake_op_result()) + + # update_grad_var_to_var_map + for idx, grad_input in enumerate(grad_inputs): + if grad_input in grad_var_to_var_map.keys(): + grad_var_to_var_map[res[idx]] = grad_var_to_var_map.pop(grad_input) + + # move the list of primitive operators to the position of backward op + insert_idx = bwd_op_idx + for i in range(before_num_ops, after_num_ops): + block.move_op(block.ops[i], insert_idx) + insert_idx += 1 + + # replace the following use of original backward op's outputs with new outputs, and then remove original backward op + bwd_op.replace_all_uses_with(res) + block.remove_op(bwd_op) + + return tuple(res) diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index 6e59f9858e74a..d64cba8d657ba 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -18,7 +18,6 @@ from .register import register_decomp -@register_decomp('pd_op.mean') def mean(x, axis, keepdim): """define composite rule of op mean""" x_shape = x.shape diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index 7ee16ffcf5464..f6c3bfc78a9a6 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -58,10 +58,12 @@ def is_compiled_with_custom_device(device_type): """ + Whether paddle was built with Paddle_CUSTOM_DEVICE . Args: - std::string, the registered device type, like "npu". + device_type (str): the registered device type, like "npu". + Return: bool, ``True`` if CustomDevice is supported, otherwise ``False``. @@ -70,12 +72,14 @@ def is_compiled_with_custom_device(device_type): >>> import paddle >>> support_npu = paddle.device.is_compiled_with_custom_device("npu") + """ return core.is_compiled_with_custom_device(device_type) def is_compiled_with_ipu(): """ + Whether paddle was built with WITH_IPU=ON to support Graphcore IPU. Returns (bool): `True` if IPU is supported, otherwise `False`. @@ -85,12 +89,14 @@ def is_compiled_with_ipu(): >>> import paddle >>> support_ipu = paddle.is_compiled_with_ipu() + """ return core.is_compiled_with_ipu() def IPUPlace(): """ + Return a Graphcore IPU Place Examples: @@ -101,12 +107,14 @@ def IPUPlace(): >>> import paddle >>> paddle.device.set_device('ipu') >>> place = paddle.device.IPUPlace() + """ return core.IPUPlace() def is_compiled_with_xpu(): """ + Whether paddle was built with WITH_XPU=ON to support Baidu Kunlun Returns (bool): whether paddle was built with WITH_XPU=ON @@ -116,15 +124,17 @@ def is_compiled_with_xpu(): >>> import paddle >>> support_xpu = paddle.device.is_compiled_with_xpu() + """ return core.is_compiled_with_xpu() def XPUPlace(dev_id): """ + Return a Baidu Kunlun Place - Parameters: + Args: dev_id(int): Baidu Kunlun device id Examples: @@ -135,12 +145,14 @@ def XPUPlace(dev_id): >>> import paddle >>> paddle.device.set_device('xpu') >>> place = paddle.device.XPUPlace(0) + """ return core.XPUPlace(dev_id) def get_cudnn_version(): """ + This function return the version of cudnn. the retuen value is int which represents the cudnn version. For example, if it return 7600, it represents the version of cudnn is 7.6. @@ -249,11 +261,12 @@ def _convert_to_place(device): def set_device(device): """ + Paddle supports running calculations on various types of devices, including CPU, GPU, XPU, NPU and IPU. They are represented by string identifiers. This function can specify the global device which the OP will run. - Parameters: + Args: device(str): This parameter determines the specific running device. It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``gpu:x``, ``xpu:x``, ``npu:x`` and ``ipu``, where ``x`` is the index of the GPUs, XPUs or NPUs. @@ -271,6 +284,7 @@ def set_device(device): >>> x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32') >>> x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32') >>> data = paddle.stack([x1,x2], axis=1) + """ place = _convert_to_place(device) framework._set_expected_place(place) @@ -279,6 +293,7 @@ def set_device(device): def get_device(): """ + This function can get the current global device of the program is running. It's a string which is like 'cpu', 'gpu:x', 'xpu:x' and 'npu:x'. if the global device is not set, it will return a string which is 'gpu:x' when cuda is avaliable or it @@ -318,6 +333,7 @@ def get_device(): def get_all_device_type(): """ + Get all available device types. Returns: @@ -340,12 +356,14 @@ def get_all_device_type(): >>> # Case 4: paddlepaddle-gpu package installed, and custom deivce 'CustomCPU' and 'CustomGPU' is registerd. >>> # Output: ['cpu', 'gpu', 'CustomCPU', 'CustomGPU'] + """ return core.get_all_device_type() def get_all_custom_device_type(): """ + Get all available custom device types. Returns: @@ -362,12 +380,14 @@ def get_all_custom_device_type(): >>> # Case 2: paddlepaddle-gpu package installed, and custom deivce 'CustomCPU' and 'CustomGPU' is registerd. >>> # Output: ['CustomCPU', 'CustomGPU'] + """ return core.get_all_custom_device_type() def get_available_device(): """ + Get all available devices. Returns: @@ -390,12 +410,14 @@ def get_available_device(): >>> # Case 4: paddlepaddle-gpu package installed, and custom deivce 'CustomCPU' and 'CustomGPU' is registerd. >>> # Output: ['cpu', 'gpu:0', 'gpu:1', 'CustomCPU', 'CustomGPU:0', 'CustomGPU:1'] + """ return core.get_available_device() def get_available_custom_device(): """ + Get all available custom devices. Returns: @@ -412,22 +434,27 @@ def get_available_custom_device(): >>> # Case 2: paddlepaddle-gpu package installed, and custom deivce 'CustomCPU' and 'CustomGPU' is registerd. >>> # Output: ['CustomCPU', 'CustomGPU:0', 'CustomGPU:1'] + """ return core.get_available_custom_device() class Event: ''' + A device event wrapper around StreamBase. - Parameters: + + Args: device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)): Which device the stream runn on. If device is None, the device is the current device. Default: None. - It can be ``gpu``, ``gpu:x``,``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, + It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, where ``x`` is the index of the GPUs, XPUs. And it can be paddle.CUDAPlace(n) or paddle.CustomPlace(n). enable_timing (bool, optional): indicates if the event should measure time, default is False blocking (bool, optional): if True, ``wait`` will be blocking, default is False interprocess (bool): if True, the event can be shared between processes, default is False + Returns: Event: The event. + Examples: .. code-block:: python @@ -439,6 +466,7 @@ class Event: >>> e2 = paddle.device.Event('custom_cpu') >>> e3 = paddle.device.Event('custom_cpu:0') >>> e4 = paddle.device.Event(paddle.CustomPlace('custom_cpu', 0)) + ''' def __init__( @@ -478,12 +506,16 @@ def __init__( def record(self, stream=None): ''' + Records the event in a given stream. - Parameters: + + Args: stream(Stream, optional): The given stream. By default, stream is None, event will be recorded in current_stream. + Returns: None. + Examples: .. code-block:: python @@ -496,6 +528,7 @@ def record(self, stream=None): >>> s = paddle.device.Stream() >>> e.record(s) + ''' if stream is None: stream = current_stream(self.device) @@ -504,9 +537,12 @@ def record(self, stream=None): def query(self): ''' + Checks if all work currently captured by event has completed. + Returns: bool: Whether all work currently captured by event has completed. + Examples: .. code-block:: python @@ -517,15 +553,19 @@ def query(self): >>> e = paddle.device.Event() >>> e.record() >>> e.query() + ''' return self.event_base.query() def elapsed_time(self, end_event): ''' + Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded. + Returns: int: The time. + Examples: .. code-block:: python @@ -539,16 +579,20 @@ def elapsed_time(self, end_event): >>> e2 = paddle.device.Event() >>> e2.record() >>> e1.elapsed_time(e2) + ''' return 0 def synchronize(self): ''' + Waits for the event to complete. Waits until the completion of all work currently captured in this event. This prevents the CPU thread from proceeding until the event completes. + Returns: None. + Examples: .. code-block:: python @@ -559,6 +603,7 @@ def synchronize(self): >>> e = paddle.device.Event() >>> e.record() >>> e.synchronize() + ''' self.event_base.synchronize() @@ -568,16 +613,20 @@ def __repr__(self): class Stream: ''' + A device stream wrapper around StreamBase. - Parameters: + + Args: device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)): Which device the stream runn on. If device is None, the device is the current device. Default: None. - It can be ``gpu``, ``gpu:x``,``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, + It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, where ``x`` is the index of the GPUs, XPUs. And it can be paddle.CUDAPlace(n) or paddle.CustomPlace(n). priority(int, optional): priority of the CUDA stream. Can be either 1 (high priority) or 2 (low priority). By default, streams have priority 2. + Returns: Stream: The stream. + Examples: .. code-block:: python @@ -589,6 +638,7 @@ class Stream: >>> s2 = paddle.device.Stream('custom_cpu') >>> s3 = paddle.device.Stream('custom_cpu:0') >>> s4 = paddle.device.Stream(paddle.CustomPlace('custom_cpu', 0)) + ''' def __init__(self, device=None, priority=2, stream_base=None): @@ -633,11 +683,15 @@ def __init__(self, device=None, priority=2, stream_base=None): def wait_event(self, event): ''' + Makes all future work submitted to the stream wait for an event. - Parameters: + + Args: event (Event): an event to wait for. + Returns: None. + Examples: .. code-block:: python @@ -650,18 +704,23 @@ def wait_event(self, event): >>> e = paddle.device.Event() >>> e.record(s1) >>> s2.wait_event(e) + ''' self.stream_base.wait_event(event.event_base) def wait_stream(self, stream): ''' + Synchronizes with another stream. All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call complete. - Parameters: + + Args: stream (Stream): a stream to synchronize. + Returns: None. + Examples: .. code-block:: python @@ -672,17 +731,22 @@ def wait_stream(self, stream): >>> s1 = paddle.device.Stream() >>> s2 = paddle.device.Stream() >>> s1.wait_stream(s2) + ''' self.stream_base.wait_stream(stream.stream_base) def record_event(self, event=None): ''' + Records an event. - Parameters: + + Args: event (Event, optional): event to record. If not given, a new one - will be allocated. + will be allocated. + Returns: Event: Recorded event. + Examples: .. code-block:: python @@ -695,6 +759,7 @@ def record_event(self, event=None): >>> e2 = paddle.device.Event() >>> s.record_event(e2) + ''' if event is None: event = Event(self.device) @@ -703,9 +768,12 @@ def record_event(self, event=None): def query(self): ''' + Checks if all the work submitted has been completed. + Returns: bool: Whether all kernels in this stream are completed. + Examples: .. code-block:: python @@ -715,14 +783,18 @@ def query(self): >>> paddle.set_device('custom_cpu') >>> s = paddle.device.Stream() >>> s.query() + ''' return self.stream_base.query() def synchronize(self): ''' + Wait for all the kernels in this stream to complete. + Returns: None. + Examples: .. code-block:: python @@ -732,6 +804,7 @@ def synchronize(self): >>> paddle.set_device('custom_cpu') >>> s = paddle.device.Stream() >>> s.synchronize() + ''' self.stream_base.synchronize() @@ -758,13 +831,17 @@ def __repr__(self): def current_stream(device=None): ''' + Return the current stream by the device. - Parameters: + + Args: device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)): The device which want to get stream from. If device is None, the device is the current device. Default: None. - It can be ``gpu``, ``gpu:x``,``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, + It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, where ``x`` is the index of the GPUs, CustomDevicecs. And it can be paddle.CUDAPlace(n) or paddle.CustomPlace(n). + Returns: Stream: The stream to the device. + Examples: .. code-block:: python @@ -776,6 +853,7 @@ def current_stream(device=None): >>> s2 = paddle.device.current_stream("custom_cpu:0") >>> place = paddle.CustomPlace('custom_cpu', 0) >>> s3 = paddle.device.current_stream(place) + ''' if device is None: place = paddle.framework._current_expected_place() @@ -804,11 +882,15 @@ def current_stream(device=None): def set_stream(stream): ''' + Set the current stream. - Parameters: + + Args: stream(Stream): The selected stream. + Returns: Stream: The previous stream. + Examples: .. code-block:: python @@ -818,6 +900,7 @@ def set_stream(stream): >>> paddle.set_device('custom_cpu') >>> s = paddle.device.Stream() >>> paddle.device.set_stream(s) + ''' prev_stream = current_stream(stream.stream_base.place) @@ -844,13 +927,17 @@ def set_stream(stream): class stream_guard: ''' + Notes: This API only supports dynamic graph mode currently. A context manager that specifies the current stream context by the given stream. - Parameters: + + Args: stream(Stream, optional): the selected stream. If stream is None, just yield. + Returns: None. + Examples: .. code-block:: python @@ -865,6 +952,7 @@ class stream_guard: >>> with paddle.device.stream_guard(s): ... s.wait_stream(paddle.device.default_stream()) ... data4 = data1 + data3 + ''' def __init__(self, stream=None): @@ -899,13 +987,15 @@ def __exit__(self, *args): def synchronize(device=None): """ + Wait for the compute on the given device to finish. - Parameters: + + Args: device(str|paddle.CUDAPlace(n)|paddle.XPUPlace(n)|paddle.CustomPlace(n)): The device which want to wait for. If device is None, the device is the current device. Default: None. It can be ``gpu``, ``gpu:x``, ``xpu``, ``xpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevicec, where ``x`` is the index of the GPUs, XPUs. And it can be paddle.CUDAPlace(n) or paddle.XPUPlace(n) or paddle.CustomPlace(n). - Examples: + Examples: .. code-block:: python >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) @@ -916,6 +1006,7 @@ def synchronize(device=None): >>> paddle.device.synchronize("custom_cpu:0") >>> place = paddle.CustomPlace('custom_cpu', 0) >>> paddle.device.synchronize(place) + """ if device is None: diff --git a/python/paddle/device/cuda/__init__.py b/python/paddle/device/cuda/__init__.py index 0a094319f893f..ea974e22c9f2a 100644 --- a/python/paddle/device/cuda/__init__.py +++ b/python/paddle/device/cuda/__init__.py @@ -469,9 +469,9 @@ def get_device_properties(device=None): device_id = int(device[4:]) else: raise ValueError( - "The current string {} is not expected. Because paddle.device." + f"The current string {device} is not expected. Because paddle.device." "cuda.get_device_properties only support string which is like 'gpu:x'. " - "Please input appropriate string again!".format(device) + "Please input appropriate string again!" ) else: raise ValueError( diff --git a/python/paddle/device/cuda/cuda_graphed_layer.py b/python/paddle/device/cuda/cuda_graphed_layer.py new file mode 100644 index 0000000000000..9765cc19690c5 --- /dev/null +++ b/python/paddle/device/cuda/cuda_graphed_layer.py @@ -0,0 +1,146 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +from .graphs import CUDAGraph + + +class CUDAGraphContext: + def __init__(self, layer, num_warmup_steps): + self.step = 0 + self.layer = layer + self.forward_graph = CUDAGraph() + self.backward_graph = CUDAGraph() + self.num_warmup_steps = num_warmup_steps + + +def detach(x): + if isinstance(x, paddle.Tensor): + x_detached = x.detach() + x_detached.stop_gradient = x.stop_gradient + return x_detached + else: + return x + + +def get_grad(x): + if isinstance(x, paddle.Tensor): + return x.grad + else: + return x + + +class _CUDAGraphedLayer(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, context, *args): + args = [detach(x) for x in args] + + if context.step < context.num_warmup_steps: + with paddle.enable_grad(): + y = context.layer(*args) + ctx.save_for_backward(context, args, y) + return y.detach() + + elif context.step == context.num_warmup_steps: + context.args_static = args + context.forward_graph.capture_begin() + with paddle.enable_grad(): + y = context.layer(*context.args_static) + context.forward_graph.capture_end() + + context.forward_graph.replay() + context.y_static = y + + ctx.save_for_backward(context, context.args_static, y) + return y.detach() + else: + for x_staic, x in zip(context.args_static, args): + if isinstance(x_staic, paddle.Tensor): + x_staic.copy_(x, True) + + context.forward_graph.replay() + y = context.y_static + + ctx.save_for_backward(context, context.args_static, y) + return y.detach() + + @staticmethod + def backward(ctx, dy): + context, args, y = ctx.saved_tensor() + + if context.step < context.num_warmup_steps: + y.backward(dy) + elif context.step == context.num_warmup_steps: + context.dy_static = dy + context.backward_graph.capture_begin() + context.y_static.backward(context.dy_static) + context.backward_graph.capture_end() + context.backward_graph.replay() + else: + context.dy_static.copy_(dy, True) + context.backward_graph.replay() + + def get_grad(x): + return x.grad if isinstance(x, paddle.Tensor) else x + + args_grad = tuple(get_grad(x) for x in args) + context.step += 1 + return args_grad + + +class CUDAGraphedLayer(paddle.nn.Layer): + """ + CUDAGraphedLayer: A PaddlePaddle Layer to convert an eager mode model to utilize CUDA Graphs. + + CUDA Graphs provide a way to capture kernel-level operations of a model and play + them back efficiently, allowing for potential speedups in repetitive computations, + such as those during training iterations. This layer is a wrapper that enables + the usage of CUDA Graphs with PaddlePaddle models. + + Overview: + - The layer encapsulates another layer (the model to be converted). + - During the first few (num_warmup_steps) iterations, the layer operates in + eager mode without any CUDA Graphs. + - After the warmup steps, the layer captures the forward and backward computations + and replays them using CUDA Graphs in subsequent iterations. + + Usage: + model = Model() + graphed_model = CUDAGraphedLayer(model) + + Parameters: + - layer (paddle.nn.Layer): The PaddlePaddle model/layer to be converted. + - num_warmup_steps (int): The number of iterations before the CUDA Graph + capture begins. Default is 3. + + Notes: + - Restrictions: + * CPU-GPU Synchronization: Operations that synchronize the CPU with the GPU, like device to host transfers, are not allowed. + * CPU Work: Any operations on the CPU within the captured graph are not recorded. + * Memory Address (Pointer) Consistency: Replays consistently read from and write to identical virtual memory addresses. + * Dynamic Operations: + - Control Flow: Dynamic control flows, especially those based on CPU data like if/else statements, are prohibited. + - Tensor Shapes: Dynamic tensor shapes are not supported. + + - Allowed Operations: + * CUDA RNG Operations: CUDA-based Random Number Generation operations are allowed. + """ + + def __init__(self, layer: paddle.nn.Layer, num_warmup_steps=3): + super().__init__() + self.context = CUDAGraphContext(layer, num_warmup_steps) + + def forward(self, *args): + return _CUDAGraphedLayer.apply(self.context, *args) diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 4486b3220fa4d..099ef04e5b048 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -18,6 +18,7 @@ from .interface import shard_tensor from .interface import shard_op from .interface import recompute +from .interface import exclude_ops_in_recompute from .interface import fetch from .random import parallel_manual_seed diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 0c4c83ef403ce..114f852815183 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -136,6 +136,10 @@ def shard_tensor( >>> print(d_tensor) """ + if place is None: + place = paddle.framework._current_expected_place() + place = paddle.framework._get_paddle_place(place) + # 1. create dense tensor # `paddle.to_tensor` supports both dynamic and static mode tensor = paddle.to_tensor( @@ -154,7 +158,7 @@ def shard_tensor( tensor, dist_attr=dist_attr, **tensor.__dict__ ) else: - return paddle.Tensor(tensor, dist_attr=dist_attr) + return paddle.Tensor(tensor, dist_attr=dist_attr, place=place) else: # TODO(zhiqiu): we need to refine the static shard_tensor return shard_tensor_static( diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index e4ab9d0f88662..e80e68281c09a 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -56,6 +56,8 @@ def set_field_default_config(category, field, default_value): set_field_default_config(RECOMPUTE, "enable", False) set_field_default_config(RECOMPUTE, "checkpoints", []) set_field_default_config(RECOMPUTE, "no_recompute_segments", []) +set_field_default_config(RECOMPUTE, "sr", 0) +set_field_default_config(RECOMPUTE, "refined_ops_patterns", []) # List[Dict] set_field_default_config(RECOMPUTE, "enable_tuning", False) ######################################### @@ -157,3 +159,11 @@ def set_field_default_config(category, field, default_value): set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True) set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32) set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True) + +######################################### +# model parallel configuration +######################################### +MP_OPTIMIZATION = "mp_optimization" +set_field_default_config( + MP_OPTIMIZATION, "allreduce_matmul_grad_overlapping", False +) diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index c8ab91a7346f0..43472340d7df7 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -219,15 +219,57 @@ def __call__(self, *args, **kwargs): for idx in range(op_size, new_op_size): op = cur_block.ops[idx] - op._set_attr( - 'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx) - ) + if op.has_attr( + "op_namescope" + ) and 'auto_parallel/exclude_rc' in op.attr("op_namescope"): + op._set_attr( + 'op_namescope', + "/auto_parallel/rc_" + + str(_g_recompute_idx) + + "_exclude_rc", + ) + else: + op._set_attr( + 'op_namescope', + '/auto_parallel/rc_' + str(_g_recompute_idx), + ) return output return RecomputeOperator(op) +def exclude_ops_in_recompute(run_function): + """ + Exclude some operators in recompute segements. + Args: + run_function (callabe): The callabe function to be excluded. + + Returns: + ExcludeOperator: The callable object. + + """ + + class ExcludeOperator: + def __init__(self, run_function): + self._run_function = run_function + + def __call__(self, *args, **kwargs): + default_prog = paddle.static.default_main_program() + cur_block = default_prog.current_block() + op_size = len(cur_block.ops) + output = self._run_function(*args, **kwargs) + new_op_size = len(cur_block.ops) + + for idx in range(op_size, new_op_size): + op = cur_block.ops[idx] + op._set_attr('op_namescope', "/auto_parallel/exclude_rc") + + return output + + return ExcludeOperator(run_function) + + _g_collections = {} diff --git a/python/paddle/distributed/auto_parallel/static/completion.py b/python/paddle/distributed/auto_parallel/static/completion.py index d1024a226c64e..729ae7c055969 100644 --- a/python/paddle/distributed/auto_parallel/static/completion.py +++ b/python/paddle/distributed/auto_parallel/static/completion.py @@ -35,6 +35,7 @@ from .process_group import get_world_process_group from .utils import ( __no_shape_var_type__, + _g_gradient_clip_ops, is_gradient_clip_op, is_naive_data_parallel, ) @@ -1689,13 +1690,7 @@ def complete_update_annotation(self, serial_main_program): op = ops[idx] if int(op.attr('op_role')) == int(OpRole.Optimize): if is_gradient_clip_op(op): - if op.type in [ - "sum", - "sqrt", - "fill_constant", - "elementwise_max", - "elementwise_div", - ]: + if op.type in _g_gradient_clip_ops: # complete op dist_attr with global world ranks op_dist_attr = OperatorDistAttr() op_dist_attr.process_mesh = ProcessMesh(world_ranks) diff --git a/python/paddle/distributed/auto_parallel/static/dist_tensor.py b/python/paddle/distributed/auto_parallel/static/dist_tensor.py index 32a4f43434118..3a3d3ba7a7b95 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/static/dist_tensor.py @@ -42,9 +42,7 @@ def _validate_sizes_and_dist_attr( and all(isinstance(x, int) and x >= 0 for x in sizes) ): raise ValueError( - "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}".format( - sizes - ) + f"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {sizes}" ) if not ( isinstance(dims_mapping, (list, tuple)) diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 9043b43ee9d4a..fa48521939d72 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -1132,43 +1132,69 @@ def evaluate( else: self._switch_mode(self._mode) - micro_batch_size = self._validate_batch_size(batch_size) - valid_dataloader = self._prepare_dataloader_from_generator( - dataset=valid_data, - capacity=70, - iterable=False, - batch_size=micro_batch_size, - steps_per_epoch=steps, - collate_fn=collate_fn, - ) + if auto_utils.use_new_executor(): + local_batch_size = self._validate_batch_size(batch_size) + valid_dataloader = self._prepare_dataloader( + valid_data, + return_list=False, + batch_size=local_batch_size, + collate_fn=collate_fn, + ) + steps_per_epoch = len(valid_dataloader) if steps is None else steps + else: + micro_batch_size = self._validate_batch_size(batch_size) + valid_dataloader = self._prepare_dataloader_from_generator( + dataset=valid_data, + capacity=70, + iterable=False, + batch_size=micro_batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn, + ) + steps_per_epoch = valid_dataloader._steps + local_batch_size = micro_batch_size + if self._strategy.pipeline.enable: + local_batch_size = micro_batch_size * self._acc_steps fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) cbks = config_callbacks( callbacks, engine=self, - batch_size=micro_batch_size, + batch_size=local_batch_size, log_freq=log_freq, verbose=verbose, metrics=self._metrics_name(), ) - eval_steps = valid_dataloader._steps + eval_steps = steps_per_epoch cbks.on_begin( 'eval', {'steps': eval_steps, 'metrics': self._metrics_name()} ) logs = {} - for step, _ in enumerate(valid_dataloader): - cbks.on_batch_begin('eval', step, logs) + for step, batch in enumerate(valid_dataloader): + if auto_utils.use_new_executor(): + batches = self._validate_batch(batch) + else: + batches = [{}] + try: - outs = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy, - ) + for micro_batch in batches: + cbks.on_batch_begin('eval', step, logs) + outs = self._executor.run( + self.main_program, + feed=micro_batch, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break + + if steps_per_epoch and step >= steps_per_epoch: + if not auto_utils.use_new_executor(): + valid_dataloader._reset() + break logs = self._prepare_logger( outs, None, step, None, fetch_names, fetch_indices, self._mode ) @@ -1240,34 +1266,57 @@ def predict( else: self._switch_mode(self._mode) - micro_batch_size = self._validate_batch_size(batch_size) - test_dataloader = self._prepare_dataloader_from_generator( - dataset=test_data, - capacity=70, - iterable=False, - batch_size=micro_batch_size, - steps_per_epoch=steps, - collate_fn=collate_fn, - ) + if auto_utils.use_new_executor(): + local_batch_size = self._validate_batch_size(batch_size) + test_dataloader = self._prepare_dataloader( + test_data, + return_list=False, + batch_size=local_batch_size, + collate_fn=collate_fn, + ) + steps_per_epoch = len(test_dataloader) if steps is None else steps + else: + micro_batch_size = self._validate_batch_size(batch_size) + test_dataloader = self._prepare_dataloader_from_generator( + dataset=test_data, + capacity=70, + iterable=False, + batch_size=micro_batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn, + ) + steps_per_epoch = test_dataloader._steps fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) outputs = [] cbks = config_callbacks(callbacks, engine=self, verbose=verbose) - test_steps = test_dataloader._steps + test_steps = steps_per_epoch cbks.on_begin('predict', {'steps': test_steps}) logs = {} - for step, _ in enumerate(test_dataloader): - cbks.on_batch_begin('predict', step, logs) + for step, batch in enumerate(test_dataloader): + if auto_utils.use_new_executor(): + batches = self._validate_batch(batch) + else: + batches = [{}] + try: - outs = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy, - ) + for micro_batch in batches: + cbks.on_batch_begin('predict', step, logs) + outs = self._executor.run( + self.main_program, + feed=micro_batch, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break + + if steps_per_epoch and step >= steps_per_epoch: + if not auto_utils.use_new_executor(): + test_dataloader._reset() + break logs = self._prepare_logger( outs, None, step, None, fetch_names, fetch_indices, self._mode ) @@ -1281,7 +1330,7 @@ def dataloader( dataset, batch_size=1, shuffle=False, - drop_last=False, + drop_last=True, collate_fn=None, num_workers=0, use_buffer_reader=True, @@ -1451,7 +1500,7 @@ def _prepare_dataloader( return_list=True, batch_size=1, shuffle=False, - drop_last=False, + drop_last=True, collate_fn=None, num_workers=0, use_buffer_reader=True, @@ -1630,9 +1679,7 @@ def _validate_spec(self, specs): ) if spec.name is None: raise ValueError( - "Requires Input[{}].name != None, but receive `None` with {}.".format( - i, spec - ) + f"Requires Input[{i}].name != None, but receive `None` with {spec}." ) if self._acc_steps > 1: shape = list(spec.shape) diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py index f705ee4968848..6fe6700b996ad 100644 --- a/python/paddle/distributed/auto_parallel/static/helper.py +++ b/python/paddle/distributed/auto_parallel/static/helper.py @@ -235,7 +235,9 @@ def build_program(self, mode): self._logger.info("start to build program for mode = %s." % mode) input_spec = [self.inputs_spec, self.labels_spec] - static_func = to_static(self.static_func(), input_spec=input_spec) + static_func = to_static( + self.static_func(), input_spec=input_spec, full_graph=True + ) func_name = '_' + mode setattr(self.proxy_layer, func_name, static_func) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py index 71f72defcd462..913e1340b32f0 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py @@ -130,9 +130,7 @@ def forward(ctx, *args, **kwargs): and src_op.attr("seed") ): _logger.info( - "Auto Parallel Random Control Skipped Since manul seed is set by user: {}".format( - src_op - ) + f"Auto Parallel Random Control Skipped Since manul seed is set by user: {src_op}" ) elif rank_id not in op_dist_attr.process_mesh.process_ids: pass @@ -163,9 +161,7 @@ def forward(ctx, *args, **kwargs): pre_op._set_attr("force_cpu", True) else: _logger.info( - "Auto Parallel Random Control Skipped Since manul seed is set by user: {}".format( - src_op - ) + f"Auto Parallel Random Control Skipped Since manul seed is set by user: {src_op}" ) else: # determinate rng diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py index d83beb82cd12a..841dc0a587044 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py @@ -18,7 +18,7 @@ register_distributed_operator_impl, register_distributed_operator_impl_container, ) -from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0 +from .dist_eltwise import DistributedElementwiseImpl0 class DistributedFlashAttn(DistributedOperatorImplContainer): @@ -30,6 +30,7 @@ def __init__(self, op_type): # Dist FlashAttn with Random Control +# NOTE(zhiqiu): trick implementation, copy dist_attr of q,k,v to out class DistributedFlashAttnImpl0(DistributedElementwiseImpl0): def __init__(self, name): super().__init__(name) @@ -83,12 +84,12 @@ def forward(ctx, *args, **kwargs): src_op._set_attr('rng_name', rng_name) - DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + DistributedElementwiseImpl0.forward(ctx, *args, **kwargs) @staticmethod def backward(ctx, *args, **kwargs): # dropout backward is deterministic by mask, and not need for random state control - DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + DistributedElementwiseImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py index 5f2186575c24e..44a99efcebfe1 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py @@ -87,9 +87,7 @@ def forward(ctx, *args, **kwargs): and src_op.attr("seed") ): _logger.info( - "Auto Parallel Random Control Skipped Since manul seed is set by user: {}".format( - src_op - ) + f"Auto Parallel Random Control Skipped Since manul seed is set by user: {src_op}" ) elif rank_id not in op_dist_attr.process_mesh.process_ids: pass @@ -120,9 +118,7 @@ def forward(ctx, *args, **kwargs): pre_op._set_attr("force_cpu", True) else: _logger.info( - "Auto Parallel Random Control Skipped Since manul seed is set by user: {}".format( - src_op - ) + f"Auto Parallel Random Control Skipped Since manul seed is set by user: {src_op}" ) else: # determinate rng diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 6f0a1db1a3bff..b59cfea194551 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -18,6 +18,7 @@ import time from paddle.distributed.passes import PassManager, new_pass +from paddle.framework import get_flags from paddle.static import append_backward, program_guard from paddle.utils import unique_name @@ -28,6 +29,12 @@ from .reshard import Resharder from .utils import get_pp_stage, is_sequential_run, use_new_executor +NEW_IR_PASS = [ + 'fused_gemm_epilogue_pass', + 'fused_linear_param_grad_add_pass', + 'fused_dropout_add_pass', +] + class Parallelizer: def __init__(self, mode, completer, dist_context): @@ -355,14 +362,18 @@ def _apply_post_optimization( ) params_grads = self._pass_context.get_attr("params_grads") - mp_async_allreduce_in_backward = os.getenv( - "FLAGS_mp_async_allreduce_in_backward" - ) in [1, "1", True, "True"] - if mp_async_allreduce_in_backward: - column_parallel_linear_backward_overlapping_pass = new_pass( - "column_parallel_linear_backward_overlapping", {} + if self._strategy.mp_optimization.allreduce_matmul_grad_overlapping: + if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1: + self._logger.warning( + "You set mp_optimization.allreduce_matmul_grad_overlapping=True, but you did not set environment " + "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance " + "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance." + ) + + allreduce_matmul_grad_overlapping_pass = new_pass( + "allreduce_matmul_grad_overlapping", {} ) - column_parallel_linear_backward_overlapping_pass.apply( + allreduce_matmul_grad_overlapping_pass.apply( [main_program], [startup_program], self._pass_context ) @@ -419,21 +430,44 @@ def _apply_post_optimization( [main_program], [startup_program], self._pass_context ) + enable_ir = get_flags("FLAGS_enable_pir_in_executor")[ + 'FLAGS_enable_pir_in_executor' + ] + ir_pass_list = [] if self.is_train and self._strategy.fused_passes.enable: if len(self._strategy.fused_passes.fused_passes_list) > 0: new_pass_list = [] - for op in self._strategy.fused_passes.fused_passes_list: - new_pass_list.append(new_pass(op)) + for p in self._strategy.fused_passes.fused_passes_list: + if p in NEW_IR_PASS and enable_ir: + ir_pass_list.append(p) + else: + new_pass_list.append(new_pass(p)) pass_manager = PassManager(new_pass_list) pass_manager.apply([main_program], [startup_program]) + main_program._pass_opt = {} + main_program._pass_opt['pass_list'] = ir_pass_list + if ( self.is_train and self._strategy.pipeline.enable and use_new_executor() ): + enable_send_recv_overlap = ( + self._strategy.pipeline.enable_send_recv_overlap + ) + if ( + enable_send_recv_overlap + and int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1 + ): + self._logger.warning( + "You set pipeline.enable_send_recv_overlap=True, but you did not set environment " + "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance " + "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance." + ) main_program._pipeline_opt = {} main_program._pipeline_opt["standalone_opt"] = { + "enable_send_recv_overlap": enable_send_recv_overlap, "schedule_mode": self._strategy.pipeline.schedule_mode, "num_micro_batches": self._strategy.pipeline.accumulate_steps, "pp_degree": len(self._dist_context.process_meshes), diff --git a/python/paddle/distributed/auto_parallel/static/partitioner.py b/python/paddle/distributed/auto_parallel/static/partitioner.py index b00baf32ec0fe..4f976ca661280 100644 --- a/python/paddle/distributed/auto_parallel/static/partitioner.py +++ b/python/paddle/distributed/auto_parallel/static/partitioner.py @@ -345,9 +345,7 @@ def partition_block(self, ref_block, target_block): ) else: raise NotImplementedError( - "partitioner only support forward and backward, optimize ops, but got {}".format( - str(op) - ) + f"partitioner only support forward and backward, optimize ops, but got {str(op)}" ) def _is_valid_annotated_program(self, program): diff --git a/python/paddle/distributed/auto_parallel/static/process_group.py b/python/paddle/distributed/auto_parallel/static/process_group.py index df881be1a31e3..015e5c719caba 100644 --- a/python/paddle/distributed/auto_parallel/static/process_group.py +++ b/python/paddle/distributed/auto_parallel/static/process_group.py @@ -13,7 +13,6 @@ # limitations under the License import hashlib -import os from collections import OrderedDict import paddle @@ -63,9 +62,9 @@ def new_process_group( global _g_process_group_map if not force_new_group: # A key constructed from ranks is used for avoiding duplication - new_key = ''.join(map(str, ranks)) + new_key = '_'.join(map(str, ranks)) for pg_id, pg in _g_process_group_map.items(): - cur_key = ''.join(map(str, pg.ranks)) + cur_key = '_'.join(map(str, pg.ranks)) if pg_id != 0 and new_key == cur_key: return pg # If not matching the existing one, construct a new process group @@ -158,10 +157,10 @@ def instantiate(self): strategy.nrings = 1 if core.is_compiled_with_cuda(): place = core.CUDAPlace(genv.device_id) - use_new_comm = os.getenv( - "FLAGS_dynamic_static_unified_comm", "0" - ) - if use_new_comm in ["1", "True", "true"]: + use_new_comm = paddle.get_flags( + "FLAGS_dynamic_static_unified_comm" + )["FLAGS_dynamic_static_unified_comm"] + if use_new_comm: store = core.create_or_get_global_tcp_store() endpoints_str = "" for endpoint in strategy.trainer_endpoints: diff --git a/python/paddle/distributed/auto_parallel/static/reshard.py b/python/paddle/distributed/auto_parallel/static/reshard.py index cf1ed597536e3..efa2e9663119f 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard.py +++ b/python/paddle/distributed/auto_parallel/static/reshard.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License + import copy from collections import OrderedDict from functools import reduce @@ -33,17 +34,10 @@ from .dist_attribute import TensorDistAttr from .dist_context import DistributedContext from .process_group import new_process_group -from .utils import is_gradient_clip_op +from .utils import _g_gradient_clip_ops, is_gradient_clip_op, is_optimize_op # NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded. _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] -_g_gradient_clip_ops = [ - "sum", - "sqrt", - "fill_constant", - "elementwise_max", - "elementwise_div", -] _g_subblock_ops = ["while", "conditional_block"] @@ -62,6 +56,26 @@ def get_var_with_recursion(var_name, block, program): return var +class EndOpDesc: + """ + Describe to end reshard parse process. + It is supposed to contain a list of variables which are the outputs of one reshard process. + + Args: + vars (list): a list of variables. + """ + + def __init__(self, vars): + self._vars = vars + + @property + def vars(self): + return self._vars + + def __repr__(self): + return f"End vars : {self._vars}." + + class AllGatherOpDesc: """ Describe the allgather op in the reshard phase. @@ -72,11 +86,12 @@ class AllGatherOpDesc: is_bool (bool): Whether allgather bool data. Default: False. """ - def __init__(self, group, shape, is_bool=False): + def __init__(self, group, shape, is_bool=False, need_split=True): self._group = group self._desc = "all_gather" self._shape = shape self._is_bool = is_bool + self._need_split = need_split @property def is_bool(self): @@ -94,8 +109,12 @@ def desc(self): def shape(self): return self._shape + @property + def need_split(self): + return self._need_split + def __repr__(self): - return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}." + return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}, need_split: {self._need_split}." class AllGatherConcatOpDesc: @@ -605,7 +624,7 @@ def insert_fill_constant_op(block, idx, op_role, shape): return out @staticmethod - def insert_allgather_op(block, idx, tensor, ranks, op_role): + def insert_allgather_op(block, idx, tensor, ranks, op_role, need_split): """Insert allgather op into block at the given index.""" tensor_list = [] group = new_process_group(ranks) @@ -643,11 +662,14 @@ def insert_allgather_op(block, idx, tensor, ranks, op_role): idx_offset += 1 # insert split op - split_out = Inserter.insert_split_op( - block, idx + idx_offset, allgather_out, group.nranks, op_role - ) - idx_offset += 1 - tensor_list.extend(split_out) + if need_split: + split_out = Inserter.insert_split_op( + block, idx + idx_offset, allgather_out, group.nranks, op_role + ) + idx_offset += 1 + tensor_list.extend(split_out) + else: + tensor_list.extend([allgather_out]) return tensor_list, idx_offset @staticmethod @@ -1269,7 +1291,7 @@ def is_unshard(self, dims_mapping): return True def is_special_op(self, op): - global _g_special_ops, _g_gradient_clip_ops + global _g_special_ops if op.type in _g_special_ops: return True if is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops: @@ -1721,6 +1743,21 @@ def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False): group=group, shape=allgather_shape ) ] + # optimization: [sharded, any x n] -> [unsharded, any x n], only need one allgather and no split or concat anymore. + elif ( + target_dims_mapping[1:] == source_dims_mapping[1:] + and target_dims_mapping[0] == -1 + and source_dims_mapping[0] != -1 + ): + op_desc_seq[process] = [ + AllGatherOpDesc( + group=min_comm_group, + shape=allgather_shape, + is_bool=(source_tensor.dtype == paddle.bool), + need_split=False, + ), + EndOpDesc(None), + ] else: op_desc_seq[process] = ( [ @@ -1786,6 +1823,19 @@ def parse_op_desc( source_tensor = get_var_with_recursion( var_name, block, self.auto_parallel_main_prog ) + + def is_grad(name): + return name.endswith('GRAD') + + # all op that generate grad is marked as OpRole.Backward + op_role = ( + OpRole.Backward + if is_optimize_op(reshard_op) and is_grad(var_name) + else reshard_op.attr('op_role') + ) + + # a Hack to send output vars from allgather_op to end_op + end_vars = None for op_desc in op_desc_list: if isinstance(op_desc, AllGatherOpDesc): if var_name not in self.has_allgather.keys(): @@ -1799,7 +1849,7 @@ def parse_op_desc( block, idx, source_tensor, - reshard_op.attr('op_role'), + op_role, paddle.int64, ) tensor_list, idx_offset = Inserter.insert_allgather_op( @@ -1807,7 +1857,8 @@ def parse_op_desc( idx + 1, out_cast, op_desc.group, - reshard_op.attr('op_role'), + op_role, + need_split=op_desc.need_split, ) idx += idx_offset tensor_name_list = [] @@ -1816,7 +1867,7 @@ def parse_op_desc( block, idx, var, - reshard_op.attr('op_role'), + op_role, paddle.bool, ) tensor_name_list.append(out_cast.name) @@ -1830,8 +1881,11 @@ def parse_op_desc( idx, source_tensor, op_desc.group, - reshard_op.attr('op_role'), + op_role, + need_split=op_desc.need_split, ) + if idx_offset == 1: + end_vars = tensor_list idx += idx_offset tensor_name_list = [var.name for var in tensor_list] self.has_allgather[var_name].append( @@ -1862,7 +1916,7 @@ def parse_op_desc( block, idx, source_tensor, - reshard_op.attr('op_role'), + op_role, paddle.int64, ) Inserter.insert_send_op( @@ -1871,7 +1925,7 @@ def parse_op_desc( out_cast, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) idx += 2 else: @@ -1881,7 +1935,7 @@ def parse_op_desc( source_tensor, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) idx += 1 self.has_sent[var_name].append(op_desc.dst) @@ -1909,13 +1963,13 @@ def parse_op_desc( recv_tensor, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) out_cast = Inserter.insert_cast_op( block, idx + 1, recv_tensor, - reshard_op.attr('op_role'), + op_role, paddle.bool, ) tensor_list.append(out_cast) @@ -1935,7 +1989,7 @@ def parse_op_desc( recv_tensor, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) # for lod tensor, need reset lod after received @@ -1958,7 +2012,7 @@ def parse_op_desc( idx + 1, recv_tensor, tmp_var, - reshard_op.attr('op_role'), + op_role, ) ) tensor_list.append(reset_lod_out) @@ -1988,11 +2042,13 @@ def parse_op_desc( partition_index_list[index], block, idx_list, - reshard_op.attr('op_role'), + op_role, ) idx = idx_list[0] - elif isinstance(op_desc, (SliceOpDesc, AllGatherConcatOpDesc)): + elif isinstance( + op_desc, (SliceOpDesc, AllGatherConcatOpDesc, EndOpDesc) + ): target_tensor = None if isinstance(op_desc, SliceOpDesc): assert ( @@ -2013,16 +2069,20 @@ def parse_op_desc( ends=op_desc.ends, axes=op_desc.axes, new_var_name=new_name, - op_role=reshard_op.attr('op_role'), + op_role=op_role, ) - else: + elif isinstance(op_desc, AllGatherConcatOpDesc): target_tensor = Inserter.insert_c_concat_op( block, idx, source_tensor, op_desc.group, - reshard_op.attr('op_role'), + op_role, ) + else: + assert isinstance(op_desc, EndOpDesc) + assert len(end_vars) == 1 + target_tensor = end_vars[0] assert target_tensor is not None process_mesh = dist_attr[0] diff --git a/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py index 07d98d67226d7..95c258a41c6d5 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py @@ -54,6 +54,7 @@ from ....utils.log_utils import get_logger from ..graph import Graph +from ..utils import _g_gradient_clip_ops _PATTERNS = {} @@ -1644,13 +1645,7 @@ def _complete_sub_update_program(self, sub_program_dist_context): op = ops[idx] if int(op.attr('op_role')) == int(OpRole.Optimize): if is_gradient_clip_op(op): - if op.type in [ - "sum", - "sqrt", - "fill_constant", - "elementwise_max", - "elementwise_div", - ]: + if op.type in _g_gradient_clip_ops: op_dist_attr = OperatorDistAttr() op_dist_attr.process_mesh = world_ranks for in_name in op.input_arg_names: diff --git a/python/paddle/distributed/auto_parallel/static/tuner/tunable_space.py b/python/paddle/distributed/auto_parallel/static/tuner/tunable_space.py index 84f1e8924b60a..02fd3b54d0e0b 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/tunable_space.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/tunable_space.py @@ -143,9 +143,7 @@ def _deserialize_tunable_variable(state): or "state" not in state ): raise ValueError( - "Expect state to be a python dict containing class_name and state as keys, but found {}".format( - state - ) + f"Expect state to be a python dict containing class_name and state as keys, but found {state}" ) cls_name = state["class_name"] diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index da57e126058a5..a4a76963dbed7 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -42,6 +42,15 @@ ] __not_naive_data_parallel_op__ = ["expand_v2"] +_g_gradient_clip_ops = [ + "sum", + "sqrt", + "fill_constant", + "elementwise_max", + "elementwise_div", + "stack", + "reduce_sum", +] def get_logger(log_level, name="auto_parallel"): @@ -1823,7 +1832,15 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): def is_recompute_op(op): - return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr( + return ( + op.has_attr('op_namescope') + and "/auto_parallel/rc" in op.attr('op_namescope') + and 'exclude_rc' not in op.attr('op_namescope') + ) + + +def is_recompute_exclude_op(op): + return op.has_attr('op_namescope') and 'exclude_rc' in op.attr( 'op_namescope' ) @@ -1931,6 +1948,10 @@ def validate_opt(optimizer): if optimizer is not None: optimizer._parameter_list = None optimizer._param_groups = None + if optimizer._grad_clip and isinstance( + optimizer._grad_clip, paddle.nn.ClipGradByGlobalNorm + ): + optimizer._grad_clip._async_add_n = True return optimizer diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 1df4663b4fed5..958d7dc565304 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -136,6 +136,12 @@ def __init__(self, config_dict=None): super().__init__(category, config_dict) +class MPOptimizationConfig(BaseConfig): + def __init__(self, config_dict=None): + category = constants.MP_OPTIMIZATION + super().__init__(category, config_dict) + + class Strategy(BaseConfig): """ The `Strategy` object is used to configure the parallelization and optimization behaviors. @@ -214,3 +220,6 @@ def __init__(self, config=None): config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) self.dp_optimization = DPOptimizationConfig(config_dict) + + config_dict = self._config_dict.get(constants.MP_OPTIMIZATION, None) + self.mp_optimization = MPOptimizationConfig(config_dict) diff --git a/python/paddle/distributed/auto_tuner/cost_model.py b/python/paddle/distributed/auto_tuner/cost_model.py new file mode 100644 index 0000000000000..53a4fdae793fa --- /dev/null +++ b/python/paddle/distributed/auto_tuner/cost_model.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def all_params(mp, pp, sharding, h, l, V): + # TODO: TBD - add some fixed structure models. + return 1 + + +def full_recompute_acts(mp, pp, s, b, h, l): + # TODO: TBD - add some fixed structure models. + return 1 + + +def all_acts(mp, pp, s, b, h, l, a): + # TODO: TBD - add some fixed structure models. + return 1 + + +def to_gb(p): + return p / (2**30) + + +def get_mem(total_cards, parallel_cfg, l, h, a, V, s, gbs): + """Estimate the memory of model unser parallel strategy.""" + sharding = parallel_cfg["sharding_degree"] + mp = parallel_cfg["mp_degree"] + b = parallel_cfg["micro_batch_size"] + pp = parallel_cfg["pp_degree"] + vpp = parallel_cfg["vpp_degree"] + use_recompute = parallel_cfg["use_recompute"] + + sep = 1 + + lbs = int(gbs / sharding / s) + lbs = int(lbs / pp) * pp + assert s % sep == 0 + s_sep = s // sep + assert a % (sep * mp) == 0, f'{a} vs {sep * mp}' + + vpp_ratio = 1 + if vpp > 1: + assert l % (pp * vpp) == 0 + vpp_ratio = 1 + (pp - 1) / (pp * vpp) + + params = to_gb(all_params(mp, pp, sharding, h, l, V)) + + acts = 0 + assert l % pp == 0 + + if use_recompute: + acts = to_gb(full_recompute_acts(mp, pp, s_sep, b, h, l)) * vpp_ratio + else: + acts = to_gb(all_acts(mp, pp, s, b, h, l, a)) * vpp_ratio + assert acts > 0 + + peak_mem = params + acts + return peak_mem + + +def divisor(num, reverse=False): + """Get the divisor of a given number.""" + results = set() + i = 1 + mid = num // 2 + 1 + while i < mid: + if num % i == 0: + results.add(i) + results.add(num // i) + i += 1 + results = list(results) + return sorted(results, reverse=reverse) + + +def get_not_oom_cfgs(cfgs, tuner_cfg): + """Get not OOM parallel strategies.""" + total_cards, l, h, a, V, s, gbs, per_card_memory = ( + tuner_cfg["estimated_num_gpus"], + tuner_cfg["model_cfg"]["num_layers"], + tuner_cfg["model_cfg"]["hidden_size"], + tuner_cfg["model_cfg"]["num_attention_heads"], + tuner_cfg["model_cfg"]["vocab_size"], + tuner_cfg["model_cfg"]["seq_length"], + tuner_cfg["model_cfg"]["global_batch_size"], + tuner_cfg.get("per_card_memory", 80), + ) + pruned_cfgs = [] + for cfg in cfgs: + mp = cfg["mp_degree"] + sharding = cfg["sharding_degree"] + mbs = cfg["micro_batch_size"] + pp = cfg["pp_degree"] + vpp = cfg["vpp_degree"] + dp = cfg["dp_degree"] + use_recompute = cfg["use_recompute"] + + if mp * sharding * pp * dp != total_cards: + continue + if gbs % sharding != 0: + continue + if gbs // sharding % dp != 0: + continue + if gbs // sharding // dp % mbs != 0: + continue + if l % pp != 0: + continue + if l // pp % vpp != 0: + continue + if vpp != 1 and pp <= 2: + continue + if a % mp != 0 or V % mp != 0 or h % mp != 0: + continue + + pruned_cfgs.append(cfg) + valid_cfgs = [] + for cfg in pruned_cfgs: + mem = get_mem(total_cards, cfg, l, h, a, V, s, gbs) + # TODO: Uncomment when it is actually implemented. + # if ( + # mem < per_card_memory + # and mem + # > tuner_cfg.get( + # "search_algo", {"name": "dp_estimation", "threshold": 0.7} + # ).get("threshold", 0.7) + # * per_card_memory + # ): + # cfg["memory_cost"] = mem + # valid_cfgs.append(cfg) + cfg["memory_cost"] = mem + valid_cfgs.append(cfg) + assert valid_cfgs + return valid_cfgs diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py index abae3f606fee1..976089f9d05f2 100644 --- a/python/paddle/distributed/auto_tuner/prune.py +++ b/python/paddle/distributed/auto_tuner/prune.py @@ -85,10 +85,6 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None): if mp_degree not in mp_degree_candidates: return True - # prune default candidates - if mp_degree > 8: - return True - return False diff --git a/python/paddle/distributed/auto_tuner/recorder.py b/python/paddle/distributed/auto_tuner/recorder.py index 71c1b08ff3ecd..11517da529f4f 100644 --- a/python/paddle/distributed/auto_tuner/recorder.py +++ b/python/paddle/distributed/auto_tuner/recorder.py @@ -70,9 +70,8 @@ def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]: if first_few >= 5: break return (best_cfg, False) - if ( - isinstance(self.history[0]["max_mem_usage"], str) - or self.history[0]["time"] == -1 + if isinstance(self.history[0]["max_mem_usage"], str) or ( + "time" in self.history[0] and self.history[0]["time"] == -1 ): return (self.history[0], True) return (self.history[0], False) diff --git a/python/paddle/distributed/auto_tuner/search.py b/python/paddle/distributed/auto_tuner/search.py index 0e0114a5249f0..b788e538581f1 100644 --- a/python/paddle/distributed/auto_tuner/search.py +++ b/python/paddle/distributed/auto_tuner/search.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from .prune import _PRUNE_FUNC -from .utils import gbs_search_all, search_all +from .utils import gbs_search_all, search_all, search_by_dp_estimation class SearchAlgo(ABC): @@ -54,6 +54,34 @@ def search_once(self, history_cfgs): return new_cfg +class DpEstimationSearch(SearchAlgo): + def __init__(self, tuner_cfg): + super().__init__(tuner_cfg) + self.idx = 0 + self.all_tasks = search_by_dp_estimation(tuner_cfg) + assert len(self.all_tasks) > 0, "Unable to perform this search." + # change global_batch_size and dp_degree + tuner_cfg["model_cfg"]["global_batch_size"] = ( + tuner_cfg["model_cfg"]["global_batch_size"] + // self.all_tasks[0]["dp_degree"] + ) + for task in self.all_tasks: + task["estimated_dp_degree"] = task["dp_degree"] + task["dp_degree"] = 1 + + def search_once(self, history_cfgs): + new_cfg = None + stop = False + while not stop: + if self.idx < len(self.all_tasks): + new_cfg = self.all_tasks[self.idx] + self.idx += 1 + stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs) + else: + return None + return new_cfg + + class GBSSearch(SearchAlgo): def __init__(self, tuner_cfg): super().__init__(tuner_cfg) diff --git a/python/paddle/distributed/auto_tuner/tuner.py b/python/paddle/distributed/auto_tuner/tuner.py index bdc6bed5c6a08..9e693fcc3874f 100644 --- a/python/paddle/distributed/auto_tuner/tuner.py +++ b/python/paddle/distributed/auto_tuner/tuner.py @@ -29,13 +29,18 @@ def __init__(self, tuner_cfg): self.cur_task_id = 1 self.task_limit = tuner_cfg.get("task_limit", 100) - search_algo = tuner_cfg.get("search_algo", "grid") + search_algo = tuner_cfg.get("search_algo", {"name": "grid"})["name"] if search_algo == "grid": from .search import GridSearch tuner_cfg["candidates"] = default_candidates(tuner_cfg) self.algo = GridSearch(tuner_cfg) + elif search_algo == "dp_estimation": + from .search import DpEstimationSearch + + tuner_cfg["candidates"] = default_candidates(tuner_cfg) + self.algo = DpEstimationSearch(tuner_cfg) elif search_algo == "gbs": from .search import GBSSearch diff --git a/python/paddle/distributed/auto_tuner/utils.py b/python/paddle/distributed/auto_tuner/utils.py index 3f2dcf45fcd85..2928908750d92 100644 --- a/python/paddle/distributed/auto_tuner/utils.py +++ b/python/paddle/distributed/auto_tuner/utils.py @@ -112,8 +112,16 @@ def dist_degree(mode, num_gpus, num_nodes, tuner_cfg=None): def default_candidates(tuner_cfg): """Return the default candidates of every hyper param which user defined auto""" candidates = {} - num_gpus = tuner_cfg["num_gpus"] - num_nodes = tuner_cfg["nodes"] + num_gpus = ( + tuner_cfg["num_gpus"] + if "estimated_num_gpus" not in tuner_cfg + else tuner_cfg["estimated_num_gpus"] + ) + num_nodes = ( + tuner_cfg["nodes"] + if "estimated_num_gpus" not in tuner_cfg + else tuner_cfg["estimated_num_gpus"] // 8 + ) assert num_gpus > 0 if tuner_cfg.get("dp_degree", None) == "auto": @@ -210,7 +218,11 @@ def search_all(tuner_cfg): use_recompute_candidates = candidates["use_recompute"] recompute_granularity_candidates = candidates["recompute_granularity"] - num_gpus = tuner_cfg["num_gpus"] + num_gpus = ( + tuner_cfg["num_gpus"] + if "estimated_num_gpus" not in tuner_cfg + else tuner_cfg["estimated_num_gpus"] + ) valid_degrees = [] for mp_degree in mp_degree_candidates: @@ -294,6 +306,22 @@ def search_all(tuner_cfg): return new_all_cfgs +def search_by_dp_estimation(tuner_cfg): + from .cost_model import get_not_oom_cfgs + + all_cfgs = search_all(tuner_cfg) + not_oom_cfgs = get_not_oom_cfgs(all_cfgs, tuner_cfg) + num_gpus_per_dp_degree = tuner_cfg["num_gpus"] + estimated_dp_degree = ( + tuner_cfg["estimated_num_gpus"] // num_gpus_per_dp_degree + ) + result_cfgs = [] + for cfg in not_oom_cfgs: + if cfg["dp_degree"] == estimated_dp_degree: + result_cfgs.append(cfg) + return result_cfgs + + def gen_new_args(raw_args, cfg, tuner_cfg, run_best=False): """Generate new script args.""" @@ -309,6 +337,9 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): import json file_path = cmd[arg][0] + prefix = "" + if len(cmd[arg]) >= 3: + prefix = cmd[arg][2] try: with open(file_path, "r") as f: cmd_cfg = json.load(f) @@ -317,14 +348,28 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = cfg[arg] + if not value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = ( + prefix + str(cfg[arg]) if prefix else cfg[arg] + ) + else: + cmd_cfg[keys[-1]] = ( + prefix + str(cfg[arg]) if prefix else cfg[arg] + ) json.dump(cmd_cfg, open(cmd[arg][0], "w")) elif ".yaml" in cmd[arg][0]: import yaml file_path = cmd[arg][0] + prefix = "" + if len(cmd[arg]) >= 3: + prefix = cmd[arg][2] try: with open(file_path, "r") as f: cmd_cfg = yaml.safe_load(f) @@ -333,9 +378,20 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = cfg[arg] + if not value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = ( + prefix + str(cfg[arg]) if prefix else cfg[arg] + ) + else: + cmd_cfg[keys[-1]] = ( + prefix + str(cfg[arg]) if prefix else cfg[arg] + ) yaml.dump(cmd_cfg, open(cmd[arg][0], "w")) elif arg == "local_batch_size" and arg in cmd: local_batch_size = ( @@ -357,6 +413,9 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): import json file_path = cmd[arg][0] + prefix = "" + if len(cmd[arg]) >= 3: + prefix = cmd[arg][2] try: with open(file_path, "r") as f: cmd_cfg = json.load(f) @@ -365,14 +424,32 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = local_batch_size + if not value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = ( + prefix + str(local_batch_size) + if prefix + else local_batch_size + ) + else: + cmd_cfg[keys[-1]] = ( + prefix + str(local_batch_size) + if prefix + else local_batch_size + ) json.dump(cmd_cfg, open(cmd[arg][0], "w")) elif ".yaml" in cmd[arg][0]: import yaml file_path = cmd[arg][0] + prefix = "" + if len(cmd[arg]) >= 3: + prefix = cmd[arg][2] try: with open(file_path, "r") as f: cmd_cfg = yaml.safe_load(f) @@ -381,9 +458,24 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = local_batch_size + if not value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = ( + prefix + str(local_batch_size) + if prefix + else local_batch_size + ) + else: + cmd_cfg[keys[-1]] = ( + prefix + str(local_batch_size) + if prefix + else local_batch_size + ) yaml.dump(cmd_cfg, open(cmd[arg][0], "w")) elif arg == "gradient_accumulation_steps" and arg in cmd: @@ -413,6 +505,9 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): import json file_path = cmd[arg][0] + prefix = "" + if len(cmd[arg]) >= 3: + prefix = cmd[arg][2] try: with open(file_path, "r") as f: cmd_cfg = json.load(f) @@ -421,14 +516,32 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = gradient_accumulation_steps + if not value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = ( + prefix + str(gradient_accumulation_steps) + if prefix + else gradient_accumulation_steps + ) + else: + cmd_cfg[keys[-1]] = ( + prefix + str(gradient_accumulation_steps) + if prefix + else gradient_accumulation_steps + ) json.dump(cmd_cfg, open(cmd[arg][0], "w")) elif ".yaml" in cmd[arg][0]: import yaml file_path = cmd[arg][0] + prefix = "" + if len(cmd[arg]) >= 3: + prefix = cmd[arg][2] try: with open(file_path, "r") as f: cmd_cfg = yaml.safe_load(f) @@ -437,9 +550,24 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = gradient_accumulation_steps + if not value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = ( + prefix + str(gradient_accumulation_steps) + if prefix + else gradient_accumulation_steps + ) + else: + cmd_cfg[keys[-1]] = ( + prefix + str(gradient_accumulation_steps) + if prefix + else gradient_accumulation_steps + ) yaml.dump(cmd_cfg, open(cmd[arg][0], "w")) assert "run_cmd" in tuner_cfg @@ -477,9 +605,16 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = cmd[arg][2] + if value: + value = value[key] + else: + value = cmd_cfg[key] + if value: + value[keys[-1]] = cmd[arg][2] + else: + cmd_cfg[keys[-1]] = cmd[arg][2] json.dump(cmd_cfg, open(cmd[arg][0], "w")) elif ".yaml" in cmd[arg][0]: import yaml @@ -493,9 +628,16 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = cmd[arg][2] + if value: + value = cmd_cfg[key] + else: + value = value[key] + if value: + value[keys[-1]] = cmd[arg][2] + else: + cmd_cfg[keys[-1]] = cmd[arg][2] yaml.dump(cmd_cfg, open(cmd[arg][0], "w")) if tuner_cfg["run_cmd"].get("run_best_stage", None) and run_best: @@ -517,9 +659,16 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = cmd[arg][2] + if value: + value = value[key] + else: + value = cmd_cfg[key] + if value: + value[keys[-1]] = cmd[arg][2] + else: + cmd_cfg[keys[-1]] = cmd[arg][2] json.dump(cmd_cfg, open(cmd[arg][0], "w")) elif ".yaml" in cmd[arg][0]: import yaml @@ -533,9 +682,16 @@ def _gen_new_arg(arg, cmd, cfg, res_args, tuner_cfg): "Please check your auto tuner json whether valid." ) keys = cmd[arg][1].split(".") + value = None for key in keys[: len(keys) - 1]: - cmd_cfg = cmd_cfg[key] - cmd_cfg[keys[-1]] = cmd[arg][2] + if value: + value = value[key] + else: + value = cmd_cfg[key] + if value: + value[keys[-1]] = cmd[arg][2] + else: + cmd_cfg[keys[-1]] = cmd[arg][2] yaml.dump(cmd_cfg, open(cmd[arg][0], "w")) return res_args diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index a2bac699bb542..66d82ae1a2914 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -153,8 +153,9 @@ def _new_process_group_impl( if backend == "gloo": pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id) elif backend == "nccl": - pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id) - + pg = core.ProcessGroupNCCL.create( + store, rank, world_size, group_id, genv.pg_timeout + ) elif backend == "xccl": pg = core.ProcessGroupCustom.create( store, genv.device_type, rank, world_size, group_id @@ -240,12 +241,6 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout): # TODO: The method below is a new method for group management, will replace the previous # three in the future. _add_new_group(group) - - # TODO(shenliang03): This is a temporary solution to solve the problem of - # hang caused by tcp - paddle.distributed.barrier(group=group) - if paddle.distributed.get_world_size() > 1: - paddle.distributed.barrier() return group if not backend: diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 4750c6bca66fc..32a8a242fb570 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -41,6 +41,23 @@ def __impl__(*args, **kwargs): is_strict_auto = wrap_decorator(__non_auto_func_called__) +def get_repeated_msg_dict(msg): + res_list = [] + for item in msg: + fields = item.DESCRIPTOR.fields + res_dict = {} + for f in fields: + v = getattr(item, f.name) + if ( + f.label + == google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED + ): + v = list(v) + res_dict[f.name] = v + res_list.append(res_dict) + return res_list + + def get_msg_dict(msg): res_dict = {} fields = msg.DESCRIPTOR.fields @@ -52,11 +69,40 @@ def get_msg_dict(msg): # I guess the type or value of protobuf item is NULL when # dealloc. if f.label == google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED: - v = list(v) + if ( + f.type + != google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE + ): + v = list(v) + else: + v = get_repeated_msg_dict(v) res_dict[f.name] = v return res_dict +def assign_repeated_msg(msg, config): + for key in config: + new_item = msg.add() + fields = new_item.DESCRIPTOR.fields + for f in fields: + if key == f.name: + # LABEL_OPTIONAL = 1 + # LABEL_REPEATED = 3 + # LABEL_REQUIRED = 2 + if f.label == 3: + if config[f.name] is not None: + new_item = getattr(msg, f.name) + if ( + f.type + != google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE + ): + new_item.extend(config[f.name]) + else: + assign_configs_value(new_item, config[f.name]) + elif f.label == 1 or f.label == 2: + setattr(msg, f.name, config[f.name]) + + def assign_configs_value(msg, config): fields = msg.DESCRIPTOR.fields for key in config: @@ -67,7 +113,15 @@ def assign_configs_value(msg, config): # LABEL_REQUIRED = 2 if f.label == 3: if config[f.name] is not None: - getattr(msg, f.name).extend(config[f.name]) + new_item = getattr(msg, f.name) + # deal with repeated message + if ( + f.type + != google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE + ): + new_item.extend(config[f.name]) + else: + assign_repeated_msg(new_item, config[f.name]) elif f.label == 1 or f.label == 2: setattr(msg, f.name, config[f.name]) @@ -930,6 +984,8 @@ def amp_configs(self): use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. + use_pure_bf16(bool): Whether to use the pure bf16 training. Default False. + use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. Default True. Only takes effect when `use_pure_fp16` is turned on. diff --git a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py index 2577df9380e38..84383000707b8 100755 --- a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py +++ b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..meta_optimizers import * # noqa: F401, F403 +from ..meta_optimizers import * # noqa: F403 __all__ = [] diff --git a/python/paddle/distributed/fleet/base/private_helper_function.py b/python/paddle/distributed/fleet/base/private_helper_function.py index c5199eb46a747..0da733c0f24c6 100644 --- a/python/paddle/distributed/fleet/base/private_helper_function.py +++ b/python/paddle/distributed/fleet/base/private_helper_function.py @@ -16,6 +16,8 @@ import time from contextlib import closing +import paddle + __all__ = [] @@ -33,6 +35,15 @@ def wait_server_ready(endpoints): >>> wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"]) """ + try: + use_new_comm = paddle.get_flags("FLAGS_dynamic_static_unified_comm")[ + "FLAGS_dynamic_static_unified_comm" + ] + except: + use_new_comm = False + + if use_new_comm: + return assert not isinstance(endpoints, str) while True: all_ok = True diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index b6130b55bf673..bced953eff139 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -189,16 +189,27 @@ def __init__(self, topology): self._sep_parallel_id = self._get_sep_parallel_id() self.stage_id = self._get_pipe_parallel_id() - assert self._check_vaild_topo(), ( - "Here is an unreasonable topogy setting. world_size: {}, but" - "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}, sep_num: {}".format( - self.nranks, - self._mp_degree, - self._sharding_degree, - self._pp_degree, - self._dp_degree, - self._sep_degree, - ) + assert ( + self._check_vaild_topo() + ), "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}, sep_num: {}".format( + self.nranks, + self._mp_degree, + self._sharding_degree, + self._pp_degree, + self._dp_degree, + ) + + # create comm group for pipe parallel + self._pp_group, self._pp_comm_group = self._set_comm_group("pipe") + # NOTE(shenliang03): In pipeline parallel, we use batch_isend_irecv. + # if batch_isend_irecv is the first collective operation, all ranks of + # the pipeline group must participate in this call. In order to avoid + # this situation, we perform a collective communication in advance and + # create a communicator. + paddle.distributed.all_reduce( + paddle.zeros([1], dtype="int32"), + op=paddle.distributed.ReduceOp.SUM, + group=self._pp_comm_group, ) # create comm group for data parallel @@ -207,9 +218,6 @@ def __init__(self, topology): # create comm group for model parallel self._mp_group, self._mp_comm_group = self._set_comm_group("model") - # create comm group for pipe parallel - self._pp_group, self._pp_comm_group = self._set_comm_group("pipe") - # create comm group for sharding parallel self._sharding_group, self._sharding_comm_group = self._set_comm_group( "sharding" @@ -240,6 +248,11 @@ def __init__(self, topology): ["pipe", "model"] ) + ( + self.sharding_check_group, + self.sharding_check_comm_group, + ) = self._set_check_group("sharding") + # create p2p group self.is_first_stage = self.stage_id == 0 self.is_last_stage = self.stage_id == (self._pp_degree - 1) diff --git a/python/paddle/distributed/fleet/data_generator/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py index cddba2ddee382..7963128f2c6d9 100644 --- a/python/paddle/distributed/fleet/data_generator/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -156,13 +156,13 @@ def generate_sample(self, line): Returns: Returns the data processed by the user. - The data format is list or tuple: + The data format is list or tuple: [(name, [feasign, ...]), ...] - or ((name, [feasign, ...]), ...) + or ((name, [feasign, ...]), ...) For example: [("words", [1926, 08, 17]), ("label", [1])] - or (("words", [1926, 08, 17]), ("label", [1])) + or (("words", [1926, 08, 17]), ("label", [1])) Note: The type of feasigns must be in int or float. Once the float diff --git a/python/paddle/distributed/fleet/dataset/dataset.py b/python/paddle/distributed/fleet/dataset/dataset.py index d0c7ca3b7b644..b0ca282eafdcb 100755 --- a/python/paddle/distributed/fleet/dataset/dataset.py +++ b/python/paddle/distributed/fleet/dataset/dataset.py @@ -1240,10 +1240,10 @@ def _set_fea_eval(self, record_candidate_size, fea_eval=True): Examples: .. code-block:: python - import paddle - paddle.enable_static() - dataset = paddle.distributed.InMemoryDataset() - dataset._set_fea_eval(1000000, True) + >>> import paddle + >>> paddle.enable_static() + >>> dataset = paddle.distributed.InMemoryDataset() + >>> dataset._set_fea_eval(1000000, True) """ if fea_eval: @@ -1299,11 +1299,10 @@ class QueueDataset(DatasetBase): QueueDataset, it will process data streamly. Examples: + .. code-block:: python - .. code-block:: python - - import paddle - dataset = paddle.distributed.QueueDataset() + >>> import paddle + >>> dataset = paddle.distributed.QueueDataset() """ @@ -1510,6 +1509,7 @@ def end_pass(self, need_save_delta): """ End Pass Notify BoxPS that current pass ended + Examples: .. code-block:: python @@ -1523,6 +1523,7 @@ def wait_preload_done(self): """ Wait async preload done Wait Until Feed Pass Done + Examples: .. code-block:: python @@ -1539,6 +1540,7 @@ def wait_preload_done(self): def load_into_memory(self): """ Load next pass into memory and notify boxps to fetch its emb from SSD + Examples: .. code-block:: python @@ -1555,6 +1557,7 @@ def load_into_memory(self): def preload_into_memory(self): """ Begin async preload next pass while current pass may be training + Examples: .. code-block:: python @@ -1588,11 +1591,13 @@ def slots_shuffle(self, slots): slots(list[string]): the set of slots(string) to do slots shuffle. Examples: - import paddle - dataset = paddle.distributed.fleet.BoxPSDataset() - dataset.set_merge_by_lineid() - #suppose there is a slot 0 - dataset.slots_shuffle(['0']) + .. code-block:: python + + >>> import paddle + >>> dataset = paddle.distributed.fleet.BoxPSDataset() + >>> dataset._set_merge_by_lineid() + >>> #suppose there is a slot 0 + >>> dataset.slots_shuffle(['0']) """ slots_set = set(slots) self.boxps.slots_shuffle(slots_set) diff --git a/python/paddle/distributed/fleet/elastic/manager.py b/python/paddle/distributed/fleet/elastic/manager.py index 6c3810f7aae74..153d2447abe1d 100644 --- a/python/paddle/distributed/fleet/elastic/manager.py +++ b/python/paddle/distributed/fleet/elastic/manager.py @@ -494,9 +494,9 @@ def _update_elastic_scale_out(self): if curr_host_port not in host_endpoints: host_endpoints.append(curr_host_port) - os.environ[ - 'PADDLE_TRAINER_ID' - ] = f'{host_endpoints.index(self.curr_host)}' + os.environ['PADDLE_TRAINER_ID'] = str( + host_endpoints.index(self.curr_host) + ) hosts = ','.join( [host_port.split(":")[0] for host_port in host_endpoints] ) @@ -547,9 +547,9 @@ def _update_elastic_scale_in(self): ) self.args.ips = hosts - os.environ[ - 'PADDLE_TRAINER_ID' - ] = f'{sorted_endpoints.index(self.curr_host)}' + os.environ['PADDLE_TRAINER_ID'] = str( + sorted_endpoints.index(self.curr_host) + ) os.environ['PADDLE_TRAINERS'] = hosts self.np = len(sorted_endpoints) os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(sorted_endpoints) diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py index 5e90584b25b5e..f18f7aeb06876 100755 --- a/python/paddle/distributed/fleet/fleet.py +++ b/python/paddle/distributed/fleet/fleet.py @@ -105,54 +105,55 @@ class Fleet: Returns: Fleet: A Fleet instance - + Examples: .. code-block:: python :name: code-example1 - # Example1: for collective training - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet + >>> # Example1: for collective training + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet - fleet.init(is_collective=True) + >>> fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + >>> strategy = fleet.DistributedStrategy() + >>> linear = paddle.nn.Linear(10, 10) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters()) + >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - # do distributed training + >>> # do distributed training .. code-block:: python :name: code-example2 - # Example2: for parameter server training - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - fleet.init(strategy=strategy) + >>> # Example2: for parameter server training + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet + >>> strategy = fleet.DistributedStrategy() + >>> fleet.init(strategy=strategy) - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001) + >>> optimizer = fleet.distributed_optimizer(optimizer) - if fleet.is_first_worker(): - print("this is first worker") + >>> if fleet.is_first_worker(): + ... print("this is first worker") - print("current node index: {}".format(fleet.worker_index())) - print("total number of worker num: {}".format(fleet.worker_num())) + >>> print("current node index: {}".format(fleet.worker_index())) + >>> print("total number of worker num: {}".format(fleet.worker_num())) - if fleet.is_worker(): - print("this is worker") - print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True))) + >>> if fleet.is_worker(): + ... print("this is worker") + >>> print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True))) - print("server num: {}".format(fleet.server_num())) - print("server endpoints: {}".format(fleet.server_endpoints(to_string=True))) + >>> print("server num: {}".format(fleet.server_num())) + >>> print("server endpoints: {}".format(fleet.server_endpoints(to_string=True))) - if fleet.is_server(): - print("this is server") - fleet.stop_worker() + >>> if fleet.is_server(): + ... print("this is server") + >>> fleet.stop_worker() """ @@ -202,37 +203,37 @@ def init( .. code-block:: python :name: code-example1 - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() .. code-block:: python :name: code-example2 - import paddle.distributed.fleet as fleet - fleet.init(is_collective=True) + >>> import paddle.distributed.fleet as fleet + >>> fleet.init(is_collective=True) .. code-block:: python :name: code-example3 - import paddle.distributed.fleet as fleet - role = fleet.PaddleCloudRoleMaker() - fleet.init(role) + >>> import paddle.distributed.fleet as fleet + >>> role = fleet.PaddleCloudRoleMaker() + >>> fleet.init(role) .. code-block:: python :name: code-example4 - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - fleet.init(strategy=strategy) + >>> import paddle.distributed.fleet as fleet + >>> strategy = fleet.DistributedStrategy() + >>> fleet.init(strategy=strategy) .. code-block:: python :name: code-example5 - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - fleet.init(log_level = "DEBUG") + >>> import paddle.distributed.fleet as fleet + >>> strategy = fleet.DistributedStrategy() + >>> fleet.init(log_level = "DEBUG") """ from paddle.distributed import parallel_helper @@ -454,9 +455,9 @@ def is_first_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.is_first_worker() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.is_first_worker() """ return self._role_maker._is_first_worker() @@ -472,9 +473,9 @@ def worker_index(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.worker_index() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.worker_index() """ return self._role_maker._worker_index() @@ -490,9 +491,9 @@ def worker_num(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.worker_num() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.worker_num() """ return self._role_maker._worker_num() @@ -521,9 +522,9 @@ def is_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.is_worker() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.is_worker() """ return self._role_maker._is_worker() @@ -542,9 +543,9 @@ def worker_endpoints(self, to_string=False): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.worker_endpoints() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.worker_endpoints() """ if to_string: @@ -563,9 +564,9 @@ def server_num(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.server_num() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.server_num() """ return len(self._role_maker._get_pserver_endpoints()) @@ -580,9 +581,9 @@ def server_index(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.server_index() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.server_index() """ return self._role_maker._server_index() @@ -598,9 +599,9 @@ def server_endpoints(self, to_string=False): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.server_endpoints() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.server_endpoints() """ @@ -621,9 +622,9 @@ def is_server(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.is_server() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.is_server() """ return self._role_maker._is_server() @@ -639,9 +640,9 @@ def barrier_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.barrier_worker() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.barrier_worker() """ self._role_maker._barrier("worker") @@ -659,13 +660,13 @@ def init_worker(self, scopes=None): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_worker() + >>> fleet.init_worker() """ self._runtime_handle._init_worker(scopes) @@ -704,13 +705,13 @@ def init_server(self, *args, **kwargs): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_server() + >>> fleet.init_server() """ self._runtime_handle._init_server(*args, **kwargs) @@ -729,13 +730,13 @@ def load_model(self, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.load_model("path", mode=0) + >>> fleet.load_model("path", mode=0) """ self._runtime_handle._load_persistables(path, mode) @@ -754,13 +755,13 @@ def load_one_table(self, table_id, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.load_one_table(0, "path", mode=0) + >>> fleet.load_one_table(0, "path", mode=0) """ self._runtime_handle._load_one_table(table_id, path, mode) @@ -779,13 +780,13 @@ def load_inference_model(self, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.load_inference_model("path", mode=1) + >>> fleet.load_inference_model("path", mode=1) """ self._runtime_handle._load_inference_model(path, mode) @@ -803,14 +804,14 @@ def run_server(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - if fleet.is_server(): - fleet.init_server() + >>> if fleet.is_server(): + ... fleet.init_server() """ self._runtime_handle._run_server() @@ -828,13 +829,13 @@ def stop_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_server() + >>> fleet.init_server() """ self._runtime_handle._stop_worker() @@ -908,13 +909,13 @@ def save_inference_model( .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_server() + >>> fleet.init_server() """ @@ -958,17 +959,17 @@ def save_persistables(self, executor, dirname, main_program=None, mode=0): .. code-block:: text - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet - fleet.init() + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - exe = paddle.static.Executor(paddle.CPUPlace()) - fleet.save_persistables(exe, "dirname", paddle.static.default_main_program()) + >>> exe = paddle.static.Executor(paddle.CPUPlace()) + >>> fleet.save_persistables(exe, "dirname", paddle.static.default_main_program()) """ self._runtime_handle._save_persistables( @@ -1008,13 +1009,13 @@ def save_one_table(self, table_id, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.save_one_table(0, "path", mode=0) + >>> fleet.save_one_table(0, "path", mode=0) """ self._runtime_handle._save_one_table(table_id, path, mode) @@ -1035,16 +1036,16 @@ def save_dense_params( .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - import paddle - place = paddle.CPUPlace() - exe = paddle.static.Executor(place) + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> import paddle + >>> place = paddle.CPUPlace() + >>> exe = paddle.static.Executor(place) - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program()) + >>> fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program()) """ self._runtime_handle._save_dense_params( @@ -1078,12 +1079,13 @@ def distributed_optimizer(self, optimizer, strategy=None): .. code-block:: python - import paddle - import paddle.distributed.fleet as fleet - fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + >>> import paddle + >>> import paddle.distributed.fleet as fleet + >>> fleet.init(is_collective=True) + >>> linear = paddle.nn.Linear(10, 10) + >>> strategy = fleet.DistributedStrategy() + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters()) + >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) """ self.user_defined_optimizer = optimizer @@ -1141,46 +1143,46 @@ def amp_init( Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - paddle.enable_static() - - def run_example_code(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') - conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) - # 1) Use fp16_guard to control the range of fp16 kernels used. - with paddle.static.amp.fp16_guard(): - bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") - pool = F.max_pool2d(bn, kernel_size=2, stride=2) - hidden = paddle.static.nn.fc(pool, size=10) - loss = paddle.mean(hidden) - # 2) Create the optimizer and set `multi_precision` to True. - # Setting `multi_precision` to True can avoid the poor accuracy - # or the slow convergence in a way. - optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) - # 3) These ops in `custom_black_list` will keep in the float32 computation type. - amp_list = paddle.static.amp.CustomOpLists( - custom_black_list=['pool2d']) - # 4) The entry of Paddle AMP. - # Enable pure fp16 training by setting `use_pure_fp16` to True. - optimizer = paddle.static.amp.decorate( - optimizer, - amp_list, - init_loss_scaling=128.0, - use_dynamic_loss_scaling=True, - use_pure_fp16=True) - # If you don't use the default_startup_program(), you sholud pass - # your defined `startup_program` into `minimize`. - optimizer.minimize(loss) - exe.run(paddle.static.default_startup_program()) - # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). - # If you want to perform the testing process, you should pass `test_program` into `amp_init`. - optimizer.amp_init(place, scope=paddle.static.global_scope()) - - if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: - run_example_code() + >>> import paddle + >>> import paddle.nn.functional as F + >>> paddle.enable_static() + + >>> def run_example_code(): + ... place = paddle.CUDAPlace(0) + ... exe = paddle.static.Executor(place) + ... data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') + ... conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) + ... # 1) Use fp16_guard to control the range of fp16 kernels used. + ... with paddle.static.amp.fp16_guard(): + ... bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") + ... pool = F.max_pool2d(bn, kernel_size=2, stride=2) + ... hidden = paddle.static.nn.fc(pool, size=10) + ... loss = paddle.mean(hidden) + ... # 2) Create the optimizer and set `multi_precision` to True. + ... # Setting `multi_precision` to True can avoid the poor accuracy + ... # or the slow convergence in a way. + ... optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) + ... # 3) These ops in `custom_black_list` will keep in the float32 computation type. + ... amp_list = paddle.static.amp.CustomOpLists( + ... custom_black_list=['pool2d']) + ... # 4) The entry of Paddle AMP. + ... # Enable pure fp16 training by setting `use_pure_fp16` to True. + ... optimizer = paddle.static.amp.decorate( + ... optimizer, + ... amp_list, + ... init_loss_scaling=128.0, + ... use_dynamic_loss_scaling=True, + ... use_pure_fp16=True) + ... # If you don't use the default_startup_program(), you sholud pass + ... # your defined `startup_program` into `minimize`. + ... optimizer.minimize(loss) + ... exe.run(paddle.static.default_startup_program()) + ... # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). + ... # If you want to perform the testing process, you should pass `test_program` into `amp_init`. + ... optimizer.amp_init(place, scope=paddle.static.global_scope()) + + >>> if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: + ... run_example_code() """ amp_optimizer = self._get_amp_optimizer() return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test) @@ -1273,28 +1275,29 @@ def minimize( .. code-block:: python - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet - import paddle.nn.functional as F - - hid_dim = 10 - label_dim = 2 - input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32') - input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64') - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh') - fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh') - prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax') - cost = F.cross_entropy(input=prediction, label=input_y) - avg_cost = paddle.mean(x=cost) - - fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) - - # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet + >>> import paddle.nn.functional as F + + >>> hid_dim = 10 + >>> label_dim = 2 + >>> input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32') + >>> input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64') + >>> fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh') + >>> fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh') + >>> prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax') + >>> cost = F.cross_entropy(input=prediction, label=input_y) + >>> avg_cost = paddle.mean(x=cost) + + >>> fleet.init(is_collective=True) + >>> strategy = fleet.DistributedStrategy() + >>> linear = paddle.nn.Linear(10, 10) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters()) + >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + >>> optimizer.minimize(avg_cost) + + >>> # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX """ if not isinstance(loss, list): diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index ac51a9b8a08bb..2b6b6eec7748c 100755 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -1287,9 +1287,7 @@ def get_role_endpoints(self, args): assert ( len(heter_worker_endpoints) == self.stage_heter_trainer_num[i] - ), "The heter trainer num in stage {} is not equal in args.heter_worker_num and args.heter_workers".format( - i - ) + ), f"The heter trainer num in stage {i} is not equal in args.heter_worker_num and args.heter_workers" heter_worker_endpoints_ips = [ x.strip().split(":")[0] diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py index 5b43ef951cfff..fad40dc03409b 100644 --- a/python/paddle/distributed/fleet/layers/mpu/random.py +++ b/python/paddle/distributed/fleet/layers/mpu/random.py @@ -237,9 +237,7 @@ def dropout( if isinstance(p, Variable) and not p.shape != [1]: raise TypeError( - "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}".format( - p.shape - ) + f"Required p.shape == [1] if type(p) is Variable, but received p.shape = {p.shape}" ) helper = LayerHelper('dropout', **locals()) diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 9625e2481d400..75be5f621d412 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import paddle from paddle.framework import core @@ -99,8 +98,10 @@ def _init_communicator( other_endpoints.remove(current_endpoint) if rank == 0 and wait_port: - use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0") - if use_new_comm not in [1, "1", "True", "true"]: + use_new_comm = paddle.get_flags( + "FLAGS_dynamic_static_unified_comm" + )["FLAGS_dynamic_static_unified_comm"] + if not use_new_comm: wait_server_ready(other_endpoints) def _add_sync_by_allreduce(block): diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 071e1a07ce027..119d2839fe92d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -18,10 +18,16 @@ import paddle from paddle import framework +from paddle.base.framework import EagerParamBase from paddle.distributed import fleet from ...utils.log_util import logger -from ...utils.tensor_fusion_helper import fused_parameters +from ...utils.tensor_fusion_helper import ( + HOOK_ACTION, + FusedCommBuffer, + assign_group_by_size, + fused_parameters, +) g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1)) g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0)) @@ -29,7 +35,7 @@ if g_shard_norm_align_dp: assert ( not g_shard_use_reduce - ), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true" + ), "g_shard_norm_align_dp is not supported if g_shard_use_reduce is true" def _is_trainable(param): @@ -54,6 +60,7 @@ class DygraphShardingOptimizer: # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm def __init__(self, optimizer, hcg): + logger.info("init DygraphShardingOptimizer") # TODO(pangengzheng): support param_groups if isinstance(optimizer._parameter_list[0], dict): raise TypeError( @@ -76,6 +83,7 @@ def __init__(self, optimizer, hcg): self.tensor_fusion = strategy.hybrid_configs[ 'sharding_configs' ].tensor_fusion + self.accumulate_steps = strategy.hybrid_configs[ 'sharding_configs' ].accumulate_steps @@ -416,3 +424,281 @@ def _set_inner_opt_attr(self, attr_name, value): def __getattr__(self, item): return getattr(self._inner_opt, item) + + +class DygraphShardingOptimizerV2: + """ + A wrapper for Sharding Optimizer in Dygraph, which split params + + .. warning: DygraphShardingOptimizer is experimental and subject to change. + + .. ZeRO: https://arxiv.org/abs/1910.02054 + + """ + + # TODO (JZ-LIANG) + # TO support following featrues in future: + # 1. fused update parameter sync + # 2. parameters_groups + # 3. dynamic trainable params, which is the case bewteen pretraining and finetuning + # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm + # 5. do not shard small params + + def __init__(self, optimizer, hcg): + logger.info("init DygraphShardingOptimizerV2") + assert ( + g_shard_use_reduce + ), "g_shard_use_reduce must be true if DygraphShardingOptimizerV2 is used" + + # TODO(pangengzheng): support param_groups + if isinstance(optimizer._parameter_list[0], dict): + raise TypeError( + "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter" + ) + if not hasattr(optimizer, '_apply_optimize') or not callable( + optimizer._apply_optimize + ): + raise ValueError( + "the optimzier object should have _apply_optimize function" + ) + + self._inner_opt = optimizer + self._hcg = hcg + self._sharding_world_size = self._hcg.get_sharding_parallel_world_size() + self._sharding_rank = self._hcg.get_sharding_parallel_rank() + + self._parameter_list = optimizer._parameter_list + + # param name -> slice_param + self._slice_params = {} + # comm_buffer_list = [] + self._comm_buffer_list = [] + + # slice parameter list + self._local_parameter_list = [ + self._create_slice_param(p) for p in optimizer._parameter_list + ] + + strategy = fleet.fleet._user_defined_strategy + self.tensor_fusion = strategy.hybrid_configs[ + 'sharding_configs' + ].tensor_fusion + + assert not self.tensor_fusion, "not supported yet" + + self.accumulate_steps = strategy.hybrid_configs[ + 'sharding_configs' + ].accumulate_steps + self.comm_overlap = strategy.hybrid_configs[ + 'sharding_configs' + ].comm_overlap + + self.pp_overlap = strategy.hybrid_configs[ + 'pp_configs' + ].sharding_comm_overlap + + # TODO(liuzhenhai):support it latter + assert not self.comm_overlap, "not supported yet" + + self._build_comm_buffers() + self._set_inner_opt_attr('_parameter_list', self._local_parameter_list) + self._set_inner_opt_attr('_param_groups', self._local_parameter_list) + + def _build_comm_buffers(self, group_size=256 * 1024 * 1024): + if self.pp_overlap: + return + + comm_group = self._hcg.get_sharding_parallel_group() + var_groups = assign_group_by_size(self._parameter_list, group_size) + for group_idx, parameters in var_groups.items(): + buffer = FusedCommBuffer( + group_idx, + parameters, + comm_group, + act=HOOK_ACTION.REDUCE_SCATTER, + ) + self._comm_buffer_list.append(buffer) + + def clear_grad(self, set_to_zero=True): + """ + should clear grad for all parameters in model + """ + assert set_to_zero, "should not erase grad buffer" + + def clear_grad_func(p): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p._grad_ivar() is None + if set_to_zero: + p.main_grad.zero_() + else: + p.main_grad._clear() + p.main_grad = None + elif not hasattr(p, "main_grad"): + if self.tensor_fusion: + if set_to_zero: + p.grad.zero_() + else: + p.grad._clear() + p.grad = None + else: + p.clear_gradient(set_to_zero) + + for p in self._parameter_list: + clear_grad_func(p) + + def filter_parameters(self, parameter_list, hcg): + parameter_list = [ + self._slice_params[param.name] for param in parameter_list + ] + parameter_list = [ + param for param in parameter_list if param._is_initialized() + ] + return parameter_list + + def reduce_gradients(self, parameter_list, hcg): + # TODO merge grad / nrank with dp + logger.debug("sharding start gradients sync") + with framework.no_grad(): + for comm_buffer in self._comm_buffer_list: + comm_buffer._comm_grads() + comm_buffer.scale_grads() + + def _sharding_sync_parameters(self): + """ + sync parameter across sharding group + """ + + logger.debug("sharding start sync parameters") + with framework.no_grad(): + for comm_buffer in self._comm_buffer_list: + comm_buffer.sync_params() + + def _update_trainable(self): + """ + allow user to update trainable parameters list during training + """ + raise NotImplementedError + + def minimize( + self, loss, startup_program=None, parameters=None, no_grad_set=None + ): + # NOTE in dygraph mode, the only different between step and minimize is that minimize + # allow user to customize the parameters for updating on each step + raise AssertionError("not supported yet") + + def _create_slice_param(self, param): + # not initialized yet + slice_param = EagerParamBase(shape=[1], dtype=param.dtype) + slice_param.name = param.name + + def copy_attr(attr_name): + if hasattr(param, attr_name): + setattr(slice_param, attr_name, getattr(param, attr_name)) + + copy_attr("is_distributed") + copy_attr("optimize_attr") + copy_attr("do_model_average") + copy_attr("need_clip") + + self._slice_params[param.name] = slice_param + return slice_param + + def _collect_comm_buffers(self): + if self._comm_buffer_list: + return + for param in self._parameter_list: + if not hasattr(param, "comm_buffer_ref"): + continue + comm_buffer_ref = param.comm_buffer_ref + del param.comm_buffer_ref + comm_buffer = comm_buffer_ref() + self._comm_buffer_list.append(comm_buffer) + + assert self._comm_buffer_list + + def _assign_slice_grad(self): + param_num = 0 + for comm_buffer in self._comm_buffer_list: + param_num = param_num + len(comm_buffer.params) + for param in comm_buffer.params: + assert param.name in self._slice_params + slice_param = self._slice_params[param.name] + comm_buffer.assign_slice_grad(param, slice_param) + + assert param_num == len(self._parameter_list) + + def step(self): + # TODO Check whether the model trainable param changed and update state accordingly + # hack for pp comm overlap + self._collect_comm_buffers() + self._assign_slice_grad() + + if not isinstance(self._parameter_list[0], dict): + params_grads = [] + for param in self._parameter_list: + if ( + hasattr(param, "regularizer") + and param.regularizer is not None + ): + raise ValueError( + f"param {param.name} should not has the regularizer attribute" + ) + if param.stop_gradient: + continue + # update on slice + assert param.name in self._slice_params + param = self._slice_params[param.name] + grad_var = param._grad_ivar() + if hasattr(param, "main_grad") and param.main_grad is not None: + grad_var = param.main_grad + if grad_var is not None: + params_grads.append((param, grad_var)) + + self._apply_optimize( + loss=None, + startup_program=None, + params_grads=params_grads, + ) + + # sync parameters across sharding ranks + self._sharding_sync_parameters() + + @framework.dygraph_only + def set_state_dict(self, state_dict): + inner_state = {} + parameters = self._parameter_list + + if "LR_Scheduler" in state_dict: + inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler") + + if "master_weights" in state_dict: + master = state_dict.pop("master_weights") + inner_state["master_weights"] = {} + for p in parameters: + for k, v in master.items(): + if p.name == k: + v.name = self._inner_opt._gen_master_weight_var_name(p) + inner_state["master_weights"][k] = v + + for p in parameters: + for k, v in state_dict.items(): + if p.name in k: + inner_state[k] = v + + self._inner_opt.set_state_dict(inner_state) + + def _set_inner_opt_attr(self, attr_name, value): + inner_opt = self._inner_opt + inner_opt_name = '_inner_opt' + if not isinstance(attr_name, str): + raise TypeError( + f"attr_name should be str type, but is {type(attr_name)}" + ) + while hasattr(inner_opt, attr_name): + setattr(inner_opt, attr_name, value) + inner_opt = getattr(inner_opt, inner_opt_name, None) + if inner_opt is None: + break + + def __getattr__(self, item): + return getattr(self._inner_opt, item) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 4415f70df37f6..c4f546cc1ea9c 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -20,6 +20,7 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( DygraphShardingOptimizer, + DygraphShardingOptimizerV2, ) from paddle.distributed.fleet.utils.hybrid_parallel_util import ( obtain_optimizer_parameters_list, @@ -257,7 +258,15 @@ def __init__(self, optimizer, hcg, strategy): # Note: Only sharding stage 1 is considered in HybridParallelOptimizer. # The sharding stage2 and stage3 optimizers are invoked in other api. if hcg.get_sharding_parallel_world_size() > 1: - optimizer = DygraphShardingOptimizer(optimizer, hcg) + split_param = strategy.hybrid_configs[ + 'sharding_configs' + ].split_param + ShardingOptimizer = ( + DygraphShardingOptimizerV2 + if split_param + else DygraphShardingOptimizer + ) + optimizer = ShardingOptimizer(optimizer, hcg) self._inner_opt = optimizer self._strategy = strategy self._hcg = hcg @@ -287,7 +296,11 @@ def __init__(self, optimizer, hcg, strategy): inner_opt = unwrap_optimizer( self._inner_opt, - (MixPrecisionOptimizer, DygraphShardingOptimizer), + ( + MixPrecisionOptimizer, + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, + ), ) if ( @@ -369,77 +382,91 @@ def _step(self, parameters_list): key=lambda p: p.name, ) + def syc_grad(p): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + self._insert_sync( + p.main_grad, src_rank, mp_group, mp_configs.sync_mode + ) + elif p.grad is not None: + self._insert_sync( + p.grad, src_rank, mp_group, mp_configs.sync_mode + ) + # Grad sync before opt if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad: for p in params: - if hasattr(p, "main_grad") and p.main_grad is not None: - assert p.grad is None - self._insert_sync( - p.main_grad, src_rank, mp_group, mp_configs.sync_mode - ) - elif p.grad is not None: - self._insert_sync( - p.grad, src_rank, mp_group, mp_configs.sync_mode - ) + syc_grad(p) self._inner_opt.step() + def syc_param(p): + # Param sync after opt + self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode) + + def syc_master_weight(p): + # Master param sync after opt + if ( + hasattr(self._inner_opt, "_multi_precision") + and self._inner_opt._multi_precision + and p.name in self._inner_opt._master_weights + ): + self._insert_sync( + self._inner_opt._master_weights[p.name], + src_rank, + mp_group, + mp_configs.sync_mode, + ) + + # syc param and master weight after opt if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param: for p in params: - # Param sync after opt - self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode) + syc_param(p) + syc_master_weight(p) - # Master param sync after opt + def syc_moment(p): + if isinstance( + self._inner_opt, + (paddle.optimizer.Adam, paddle.optimizer.AdamW), + ): if ( - hasattr(self._inner_opt, "_multi_precision") - and self._inner_opt._multi_precision - and p.name in self._inner_opt._master_weights + p.name + in self._inner_opt._accumulators[ + self._inner_opt._moment1_acc_str + ] ): + moment1 = self._inner_opt._get_accumulator( + self._inner_opt._moment1_acc_str, p + ) self._insert_sync( - self._inner_opt._master_weights[p.name], - src_rank, - mp_group, - mp_configs.sync_mode, + moment1, src_rank, mp_group, mp_configs.sync_mode + ) + + if ( + p.name + in self._inner_opt._accumulators[ + self._inner_opt._moment2_acc_str + ] + ): + moment2 = self._inner_opt._get_accumulator( + self._inner_opt._moment2_acc_str, p + ) + self._insert_sync( + moment2, src_rank, mp_group, mp_configs.sync_mode ) # Moment sync after opt if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment: for p in params: - # support opt state of adam and adamw to broadcast now. - if isinstance( - self._inner_opt, - (paddle.optimizer.Adam, paddle.optimizer.AdamW), - ): - if ( - p.name - in self._inner_opt._accumulators[ - self._inner_opt._moment1_acc_str - ] - ): - moment1 = self._inner_opt._get_accumulator( - self._inner_opt._moment1_acc_str, p - ) - self._insert_sync( - moment1, src_rank, mp_group, mp_configs.sync_mode - ) - - if ( - p.name - in self._inner_opt._accumulators[ - self._inner_opt._moment2_acc_str - ] - ): - moment2 = self._inner_opt._get_accumulator( - self._inner_opt._moment2_acc_str, p - ) - self._insert_sync( - moment2, src_rank, mp_group, mp_configs.sync_mode - ) + syc_moment(p) def _hybrid_sync_grad(self, parameter_list): dp_parameter_list = parameter_list if self._sharding_enable: - assert isinstance(self._inner_opt, DygraphShardingOptimizer) + assert isinstance( + self._inner_opt, + (DygraphShardingOptimizer, DygraphShardingOptimizerV2), + ) self._inner_opt.reduce_gradients(parameter_list, self._hcg) # dp later do not need to use global parameter list if not g_shard_norm_align_dp: diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 1ee99b10854b9..ab8ec3a67b145 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -14,6 +14,7 @@ import os +import paddle from paddle.base import core from paddle.incubate.optimizer import PipelineOptimizer from paddle.static import ( @@ -714,8 +715,10 @@ def minimize_impl( self._recreate_not_persist_param_as_var() self._dump_program_for_debug() - use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0") - if use_new_comm not in ["1", "True", "true"]: + use_new_comm = paddle.get_flags("FLAGS_dynamic_static_unified_comm")[ + "FLAGS_dynamic_static_unified_comm" + ] + if not use_new_comm: self._wait() return optimize_ops, params_grads diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index 7b1f668f421da..bdf76262157d4 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -27,6 +27,7 @@ from .pipeline_parallel import ( # noqa: F401 PipelineParallel, PipelineParallelWithInterleave, + PipelineParallelWithInterleaveFthenB, ) from .segment_parallel import SegmentParallel # noqa: F401 from .sharding_parallel import ShardingParallel # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index a3e5b406be79e..7c4cdbb0b69d0 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and import os +import queue import sys import time import warnings @@ -50,6 +51,14 @@ g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1)) +def get_action(is_dp, shard_split_param=False): + if is_dp or not g_shard_use_reduce: + return HOOK_ACTION.ALL_REDUCE + if shard_split_param: + return HOOK_ACTION.REDUCE_SCATTER + return HOOK_ACTION.REDUCE + + # assume only the first stage and last stage need data, and data consumption is ordred # to be replaced by real micro dataset from reader class FakeMicroDataset: @@ -192,6 +201,16 @@ def __init__(self, layers, hcg, strategy): "pp_configs" ].enable_timer + self._sharding_split_param = self._strategy.hybrid_configs[ + "sharding_configs" + ].split_param + + logger.info( + f"dp_comm_overlap {self._dp_comm_overlap}; \ + sharding_comm_overlap {self._sharding_comm_overlap}; \ + sharding_split_param {self._sharding_split_param};" + ) + self._profiling = self._strategy.hybrid_configs["pp_configs"].profiling self._records = [] self._record_format = ( @@ -239,11 +258,13 @@ def __init__(self, layers, hcg, strategy): p2p.initialize_p2p_groups( hcg, - self._using_cache, self._enable_partial_send_recv, self._enable_timer, ) + # construct pipeline meta info + self._p2p_helper = p2p.P2pHelper(self._using_cache) + self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 @@ -311,16 +332,14 @@ def register_allreduce_overlap_hook( else: models = [model] - if not dp: + act = get_action(dp, self._sharding_split_param) + + if act == HOOK_ACTION.REDUCE: assert hasattr(self, "optimizer") assert hasattr(self.optimizer, "_param2rank") _param2rank = self.optimizer._param2rank + # Note: after sharding change to reduce operation, here need to be cleared - act = ( - HOOK_ACTION.ALL_REDUCE - if (dp or not g_shard_use_reduce) - else HOOK_ACTION.REDUCE - ) for chunk_idx, model in enumerate(models): # For virtual pipeline. Will separate parameters in different chunk into @@ -333,9 +352,7 @@ def register_allreduce_overlap_hook( if len(parameter_list) < 1: return - if dp: - fused_parameter_group[-1] = parameter_list - else: + if act == HOOK_ACTION.REDUCE: # Sort parameters for sharding, since they have different dst rank for p in parameter_list: assert p.name in _param2rank @@ -344,10 +361,12 @@ def register_allreduce_overlap_hook( fused_parameter_group[dst_rank].append(p) else: fused_parameter_group[dst_rank] = [p] + else: + fused_parameter_group[-1] = parameter_list for dst in fused_parameter_group: parameter_list = fused_parameter_group[dst] - if act != HOOK_ACTION.ALL_REDUCE: + if act == HOOK_ACTION.REDUCE: # parse the relative dst rank to absolute dst rank for sharding dst = comm_group.ranks[dst] else: @@ -436,12 +455,16 @@ def forward_backward_pipeline( schedule += f"f{step_id};" logger.info(f"forward step for micro step {step_id}") continue - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) self._record_stamp("F", step_id, '"B"', self._forward_color) output_tensor = self._forward_step(input_tensor, micro_dataset) self._record_stamp("F", step_id, '"E"', self._forward_color) - p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) + self._p2p_helper.send_forward( + output_tensor, self.is_pipeline_last_stage() + ) input_buffers.append(input_tensor) output_buffers.append(output_tensor) @@ -450,7 +473,9 @@ def forward_backward_pipeline( self._release_output(output_tensor) if steady_steps > 0 and not static_scheduler: - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) for i in range(steady_steps): if static_scheduler: @@ -469,7 +494,7 @@ def forward_backward_pipeline( "F", startup_steps + i, '"E"', self._forward_color ) - output_tensor_grad = p2p.send_forward_recv_backward( + output_tensor_grad = self._p2p_helper.send_forward_recv_backward( output_tensor, self.is_pipeline_last_stage() ) @@ -491,11 +516,11 @@ def forward_backward_pipeline( if last_iter: input_tensor = None - p2p.send_backward( + self._p2p_helper.send_backward( input_tensor_grad, self.is_pipeline_first_stage() ) else: - input_tensor = p2p.send_backward_recv_forward( + input_tensor = self._p2p_helper.send_backward_recv_forward( input_tensor_grad, self.is_pipeline_first_stage() ) @@ -507,7 +532,7 @@ def forward_backward_pipeline( input_tensor = input_buffers.pop(0) output_tensor = output_buffers.pop(0) - output_tensor_grad = p2p.recv_backward( + output_tensor_grad = self._p2p_helper.recv_backward( self.is_pipeline_last_stage() ) @@ -520,7 +545,9 @@ def forward_backward_pipeline( self._record_stamp( "B", steady_steps + i, '"E"', self._backward_color ) - p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage()) + self._p2p_helper.send_backward( + input_tensor_grad, self.is_pipeline_first_stage() + ) if static_scheduler: return schedule @@ -548,6 +575,17 @@ def forward_backward_pipeline( self.timer_printer() return train_loss + def register_sharding_comm_overlap_hook(self, optimizer): + """for delayed hook register until we get optimizer""" + assert isinstance( + optimizer, HybridParallelOptimizer + ), 'optimizer should be HybridParallelOptimizer subclass.' + self.optimizer = optimizer + if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0: + self.register_allreduce_overlap_hook( + self._layers, self.sharding_group, self.accumulate_steps, False + ) + def _prepare_training(self, data, optimizer, lr_scheduler): # reset the virtual pp rank for each run self.set_virtual_pipeline_rank(0) @@ -571,13 +609,8 @@ def _prepare_training(self, data, optimizer, lr_scheduler): self.optimizer = optimizer self.lr_scheduler = lr_scheduler - self._layers.train() - - if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0: - self.register_allreduce_overlap_hook( - self._layers, self.sharding_group, self.accumulate_steps, False - ) + self.register_sharding_comm_overlap_hook(optimizer) return data @@ -632,28 +665,38 @@ def eval_batch(self, data, compute_loss=False): micro_dataset = self._wrap_data(data) for step_id in range(startup_steps): - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) output_tensor = self._forward_step(input_tensor, micro_dataset) - p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) + self._p2p_helper.send_forward( + output_tensor, self.is_pipeline_last_stage() + ) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if steady_steps > 0: - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) for i in range(steady_steps): last_iter = i == (steady_steps - 1) output_tensor = self._forward_step(input_tensor, micro_dataset) - p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) + self._p2p_helper.send_forward( + output_tensor, self.is_pipeline_last_stage() + ) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if not last_iter: - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) if self._compute_loss: self.train_loss = self._broadcast_final_loss() @@ -852,15 +895,7 @@ def __init__(self, layers, hcg, strategy): self._backward_micro_step_counter = {} assert layers.get_num_virtual_stages() > 1 - assert ( - self.num_stages > 2 - ), "virtual pipeline must run under pp degree > 2" - assert ( - framework.in_dynamic_mode() - ), "virtual pipeline stage with interleave only support eager dygraph mode" - assert ( - self.accumulate_steps % self.num_stages == 0 - ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave" + # setup for interleave scheduler self.num_model_chunks = layers.get_num_virtual_stages() self.model_chunks = layers.get_model_chunks() @@ -870,6 +905,21 @@ def __init__(self, layers, hcg, strategy): self._virtual_pp_rank = 0 self._reset_counter() + self._check_sanity() + + def _check_sanity(self): + assert ( + framework.in_dygraph_mode() + ), "virtual pipeline stage with interleave only support eager dygraph mode" + + assert ( + self.num_stages > 2 + ), "virtual pipeline must run under pp degree > 2" + + assert ( + self.accumulate_steps % self.num_stages == 0 + ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave" + def _reset_counter(self): for i in range(self.num_model_chunks): self._forward_micro_step_counter[i] = 0 @@ -1012,6 +1062,7 @@ def _backward_step_helper(self, micro_step): def bw_hook_func(self, buffer, param): # For pipeline with interleave, we need to add grad to buffer without communication. # Use communication where appropriate to avoid dp communication and pp scheduling conflicts. + # all reduce hook @paddle.autograd.no_grad() def fused_allreduce(*_): buffer.add_grad(param, use_comm=False) @@ -1054,6 +1105,10 @@ def forward_backward_pipeline( "enable static_scheduler will return the pp schedule instead of the loss" ) schedule = "" + # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled + assert ( + self._using_cache + ), "cache should be enabled for pipeline with interleave" # init some attributes for this batch run self.scaler = scaler @@ -1099,7 +1154,7 @@ def forward_backward_pipeline( self.set_virtual_pipeline_rank(0) if not static_scheduler: self.input_tensors[0].append( - p2p.recv_forward( + self._p2p_helper.recv_forward( self.is_pipeline_first_stage(), sync_recv=False ) ) @@ -1141,7 +1196,12 @@ def forward_backward_pipeline( if self.is_pipeline_last_stage(): output_tensor = None - if micro_step == (startup_steps - 1) and not forward_only: + # prepare for the first steady step + if ( + micro_step == (startup_steps - 1) + and (not forward_only) + and steady_steps + ): input_tensor_grad = None recv_next = True if self.is_pipeline_last_stage(ignore_virtual=True): @@ -1151,7 +1211,7 @@ def forward_backward_pipeline( ( input_tensor, output_tensor_grad, - ) = p2p.send_forward_backward_recv_forward_backward( + ) = self._p2p_helper.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, @@ -1163,7 +1223,7 @@ def forward_backward_pipeline( output_tensor_grad ) else: - input_tensor = p2p.send_forward_recv_forward( + input_tensor = self._p2p_helper.send_forward_recv_forward( output_tensor, recv_prev=recv_prev ) # append input_tensor no matter none or not @@ -1274,7 +1334,7 @@ def forward_backward_pipeline( ( input_tensor, output_tensor_grad, - ) = p2p.send_forward_backward_recv_forward_backward( + ) = self._p2p_helper.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, @@ -1295,6 +1355,14 @@ def forward_backward_pipeline( # remaining backward steps if not forward_only: + # no steady steps, which only occurs when accumulate_step == num_stage + if not steady_steps: + output_tensor_grad = p2p.recv_backward( + self.is_pipeline_last_stage() + ) + self.output_tensor_grads[self.num_model_chunks - 1].append( + output_tensor_grad + ) for micro_step in range(steady_steps, num_steps): if static_scheduler: virtual_pp_rank = self._get_virtual_pp_rank( @@ -1328,7 +1396,7 @@ def forward_backward_pipeline( recv_next = False # append output_tensor_grad no matter none or not self.output_tensor_grads[next_backward_virtual_pp_rank].append( - p2p.send_backward_recv_backward( + self._p2p_helper.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next ) ) @@ -1386,3 +1454,225 @@ def get_static_scheduler(self): return self.forward_backward_pipeline( data=None, scaler=None, static_scheduler=True ) + + +class PipelineParallelWithInterleaveFthenB(PipelineParallelWithInterleave): + def __init__(self, layers, hcg, strategy): + super().__init__(layers=layers, hcg=hcg, strategy=strategy) + + def _check_sanity(self): + assert ( + framework.in_dygraph_mode() + ), "virtual pipeline stage with interleave only support eager dygraph mode" + + assert ( + self.num_stages > 2 + ), "virtual pipeline must run under pp degree > 2" + + def _get_virtual_pp_rank(self, micro_step, forward): + virtual_pp_stage = micro_step % ( + self.accumulate_steps * self.num_model_chunks + ) + virtual_pp_stage = virtual_pp_stage // self.accumulate_steps + if not forward: + virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1 + + return virtual_pp_stage + + def _overlap_comm_grads(self): + if not self._comm_overlap: + return + self._backward_step_count += 1 + sync_step = self._backward_step_count - self.stage_id + + if sync_step > 0 and sync_step % self.accumulate_steps == 0: + chunk_idx = self._virtual_pp_world_size - ( + sync_step // self.accumulate_steps + ) + for buffer in self._chunk_2_comm_buffers[chunk_idx]: + buffer.comm_grads() + + if self.stage_id == 0: + return + + if ( + self._backward_step_count + == self.accumulate_steps * self._virtual_pp_world_size + ): + for buffer in self._chunk_2_comm_buffers[0]: + buffer.comm_grads() + + def _sync_overlap_grads(self): + if not self._comm_overlap: + return + + expected_count = self.accumulate_steps * self._virtual_pp_world_size + assert self._backward_step_count == expected_count, ( + f"backward step count should be equal to accumulate steps * virtual pp world size, " + f"but got {self._backward_step_count}, expected result is {expected_count}" + ) + + for buffers in self._chunk_2_comm_buffers.values(): + for buffer in buffers: + buffer.scale_and_split_grads() + + def forward_backward_pipeline( + self, data, scaler, forward_only=False, compute_loss=True + ): + if not compute_loss: + assert ( + not forward_only + ), "compute_loss can only be set to False when forward_only is set to True" + + # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled + assert ( + self._using_cache + ), "cache should be enabled for pipeline with interleave" + + # init some attributes for this batch run + self.scaler = scaler + self.total_loss = None + self.micro_batch_id = 0 + self._forward_only = forward_only + + assert ( + self.accumulate_steps >= self.num_stages + ), "accumulate_steps({}) should be larger than num_stages({}) for pipeline with interleave".format( + self.accumulate_steps, self.num_stages + ) + assert ( + self.accumulate_steps < 2 * self.num_stages + ), "accumulate_steps({}) should be smaller than 2 * num_stages({}) for pipeline with interleave".format( + self.accumulate_steps, self.num_stages + ) + + self._backward_step_count = 0 + skip_steps = self.accumulate_steps - self.num_stages + send_recv_buffer_queue = queue.Queue() + + # init some data buffers for interleave scheduler + self.input_tensors = [[] for _ in range(self.num_model_chunks)] + self.output_tensors = [[] for _ in range(self.num_model_chunks)] + self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)] + + micro_dataset = self._wrap_data(data) + num_steps = self.accumulate_steps * self.num_model_chunks + + self.set_virtual_pipeline_rank(0) + self.input_tensors[0].append( + self._p2p_helper.recv_forward( + self.is_pipeline_first_stage(), sync_recv=False + ) + ) + + # run startup steps + for micro_step in range(num_steps): + output_tensor = self._forward_step_helper(micro_dataset, micro_step) + # determine whether recv forward tensor or not + next_virtual_pp_rank = self._get_virtual_pp_rank( + micro_step + 1, forward=True + ) + + recv_prev = True + if self.is_pipeline_first_stage(ignore_virtual=True): + if next_virtual_pp_rank == 0: + # next chunk is the first chunk, not need to pre recv an input tensor + recv_prev = False + + # last micro step, no next run + if micro_step == (num_steps - 1): + recv_prev = False + + if self.is_pipeline_last_stage(ignore_virtual=True): + # last stage skip send/recv + if not self.is_pipeline_last_stage(): + send_recv_buffer_queue.put(output_tensor) + + if micro_step < skip_steps or ( + self.is_pipeline_last_stage() + and micro_step % self.accumulate_steps >= skip_steps + ): + output_tensor = None + else: + output_tensor = send_recv_buffer_queue.get() + + input_tensor = self._p2p_helper.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev + ) + self.input_tensors[next_virtual_pp_rank].append(input_tensor) + + self._release_output(output_tensor) + + assert ( + send_recv_buffer_queue.empty() + ), "send_recv buffer should be empty" + + # remaining backward steps + if not forward_only: + self.output_tensor_grads[self.num_model_chunks - 1].append( + self._p2p_helper.recv_backward( + self.is_pipeline_last_stage(), sync_recv=False + ) + ) + + for micro_step in range(num_steps): + # cooldown loop + input_tensor_grad = self._backward_step_helper(micro_step) + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + micro_step + 1, forward=False + ) + + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_virtual_pp_rank == ( + self.num_model_chunks - 1 + ): + recv_next = False + + if micro_step == (num_steps - 1): + recv_next = False + + if self.is_pipeline_first_stage(ignore_virtual=True): + if not self.is_pipeline_first_stage(): + send_recv_buffer_queue.put(input_tensor_grad) + + if micro_step < skip_steps or ( + self.is_pipeline_first_stage() + and micro_step % self.accumulate_steps >= skip_steps + ): + input_tensor_grad = None + else: + input_tensor_grad = send_recv_buffer_queue.get() + + self.output_tensor_grads[next_backward_virtual_pp_rank].append( + self._p2p_helper.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next + ) + ) + + assert ( + send_recv_buffer_queue.empty() + ), "send_recv buffer should be empty" + + self._sync_overlap_grads() + + if self._enable_timer: + self.timers("allreduce_shared_weight_gradients").start() + self._layers.allreduce_shared_weight_gradients() + if self._enable_timer: + self.timers("allreduce_shared_weight_gradients").stop() + + if compute_loss: + # return loss if compute loss + if self._enable_timer: + self.timers("broadcast_final_loss").start() + with paddle.amp.auto_cast(enable=False): + train_loss = self._broadcast_final_loss() + if self._enable_timer: + self.timers("broadcast_final_loss").stop() + else: + # else just return all intermediate output tensor for all micro steps + train_loss = self.output_tensors + + self.timer_printer() + return train_loss diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index dd422635a8cc1..4484a156101f6 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -28,17 +28,16 @@ from .utils import number_2_dtype, paddle_2_number _hcg = None -_use_cache = False +# _use_cache = False _enable_partial_send_recv = True _timers = None def initialize_p2p_groups( - hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False + hcg, enable_partial_send_recv=True, enable_timer=False ): - global _hcg, _use_cache, _enable_partial_send_recv, _timers + global _hcg, _enable_partial_send_recv, _timers _hcg = hcg - _use_cache = use_cache _enable_partial_send_recv = enable_partial_send_recv if enable_timer: _timers = timer.get_timers() @@ -167,8 +166,14 @@ def set_send_message(self, tensor): ] ) - -_send_recv_meta = SendRecvMeta() + def __repr__(self): + return "send_shape_message: {}, send_dtype_message: {}, recv_shape_message: {}, recv_dtype_message: {}, recv_stop_gradient: {}".format( + self.send_shape_message, + self.send_dtype_message, + self.recv_shape_message, + self.recv_dtype_message, + self.recv_stop_gradient, + ) def _is_valid_send_recv_partial(tensor, mp_degree): @@ -300,7 +305,12 @@ def _process_p2p_tuple_or_tensor( def _p2p_helper( - tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True + tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next, + sync_recv=True, + send_recv_meta=None, ): global _hcg @@ -308,12 +318,13 @@ def _p2p_helper( tensor_recv_next = None # send / recv message - recv_shape_msg = _send_recv_meta.recv_shape_message - recv_dtype_msg = _send_recv_meta.recv_dtype_message - recv_stop_gradient = _send_recv_meta.recv_stop_gradient + assert send_recv_meta is not None, "send_recv_meta should not be None" + recv_shape_msg = send_recv_meta.recv_shape_message + recv_dtype_msg = send_recv_meta.recv_dtype_message + recv_stop_gradient = send_recv_meta.recv_stop_gradient - send_shape_msg = _send_recv_meta.send_shape_message - send_dtype_msg = _send_recv_meta.send_dtype_message + send_shape_msg = send_recv_meta.send_shape_message + send_dtype_msg = send_recv_meta.send_dtype_message # model parallel message mp_group = _hcg.get_model_parallel_group() @@ -433,183 +444,197 @@ def _p2p_helper( return tensor_recv_prev, tensor_recv_next -def recv_forward(pp_first_stage, sync_recv=True): - global _timers - if _timers is not None: - _timers("recv_forward").start() - if pp_first_stage: - input_tensor = None - else: - if not _send_recv_meta.has_recv_meta: - _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) - _send_recv_meta.has_recv_meta = _use_cache - - input_tensor, _ = _p2p_helper( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=True, - recv_next=False, - sync_recv=sync_recv, - ) - if _timers is not None: - _timers("recv_forward").stop() - return input_tensor - +class P2pHelper: + def __init__(self, use_cache=True): + self._send_recv_meta = SendRecvMeta() + self._use_cache = use_cache -def recv_backward(pp_last_stage, sync_recv=True): - global _timers - if _timers is not None: - _timers("recv_backward").start() - if pp_last_stage: - output_tensor_grad = None - else: - _, output_tensor_grad = _p2p_helper( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - sync_recv=sync_recv, - ) - if _timers is not None: - _timers("recv_backward").stop() - return output_tensor_grad - - -def send_forward(output_tensor, pp_last_stage): - global _timers - if _timers is not None: - _timers("send_forward").start() - if not pp_last_stage: - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta( + def _send_meta(self, output_tensor): + if not self._send_recv_meta.has_send_meta: + self._send_recv_meta.set_send_message(output_tensor) + self._send_recv_meta.send_meta( output_tensor, _hcg.get_pipe_parallel_group() ) - _send_recv_meta.has_send_meta = _use_cache - - _p2p_helper( + self._send_recv_meta.has_send_meta = self._use_cache + + def _recv_meta(self): + if not self._send_recv_meta.has_recv_meta: + self._send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) + self._send_recv_meta.has_recv_meta = self._use_cache + + def recv_forward(self, pp_first_stage, sync_recv=True): + global _timers + if _timers is not None: + _timers("recv_forward").start() + if pp_first_stage: + input_tensor = None + else: + self._recv_meta() + + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False, + sync_recv=sync_recv, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("recv_forward").stop() + return input_tensor + + def recv_backward(self, pp_last_stage, sync_recv=True): + global _timers + if _timers is not None: + _timers("recv_backward").start() + if pp_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + sync_recv=sync_recv, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("recv_backward").stop() + return output_tensor_grad + + def send_forward(self, output_tensor, pp_last_stage): + global _timers + if _timers is not None: + _timers("send_forward").start() + if not pp_last_stage: + self._send_meta(output_tensor) + + _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_forward").stop() + + def send_backward(self, input_tensor_grad, pp_first_stage): + global _timers + if _timers is not None: + _timers("send_backward").start() + if not pp_first_stage: + _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_backward").stop() + + def send_forward_recv_backward(self, output_tensor, pp_last_stage): + global _timers + if _timers is not None: + _timers("send_forward_recv_backward").start() + if pp_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_forward_recv_backward").stop() + return output_tensor_grad + + def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage): + global _timers + if _timers is not None: + _timers("send_backward_recv_forward").start() + if pp_first_stage: + input_tensor = None + else: + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_backward_recv_forward").stop() + return input_tensor + + def send_forward_backward_recv_forward_backward( + self, output_tensor, input_tensor_grad, recv_prev, recv_next + ): + # always have to send dytpe info to downstream + global _timers + if _timers is not None: + _timers("send_forward_backward_recv_forward_backward").start() + + self._send_meta(output_tensor) + if recv_prev: + self._recv_meta() + + input_tensor, output_tensor_grad = _p2p_helper( tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=False, + tensor_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + sync_recv=False, + send_recv_meta=self._send_recv_meta, ) - if _timers is not None: - _timers("send_forward").stop() + if _timers is not None: + _timers("send_forward_backward_recv_forward_backward").stop() + return input_tensor, output_tensor_grad + def send_forward_recv_forward(self, output_tensor, recv_prev): + # always have to send dytpe info to downstream + global _timers + if _timers is not None: + _timers("send_forward_recv_forward").start() -def send_backward(input_tensor_grad, pp_first_stage): - global _timers - if _timers is not None: - _timers("send_backward").start() - if not pp_first_stage: - _p2p_helper( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=False, - ) - if _timers is not None: - _timers("send_backward").stop() + if output_tensor is not None: + self._send_meta(output_tensor) + if recv_prev: + self._recv_meta() -def send_forward_recv_backward(output_tensor, pp_last_stage): - global _timers - if _timers is not None: - _timers("send_forward_recv_backward").start() - if pp_last_stage: - output_tensor_grad = None - else: - _, output_tensor_grad = _p2p_helper( + input_tensor, _ = _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=None, - recv_prev=False, - recv_next=True, + recv_prev=recv_prev, + recv_next=False, + sync_recv=False, + send_recv_meta=self._send_recv_meta, ) - if _timers is not None: - _timers("send_forward_recv_backward").stop() - return output_tensor_grad - - -def send_backward_recv_forward(input_tensor_grad, pp_first_stage): - global _timers - if _timers is not None: - _timers("send_backward_recv_forward").start() - if pp_first_stage: - input_tensor = None - else: - input_tensor, _ = _p2p_helper( + if _timers is not None: + _timers("send_forward_recv_forward").stop() + return input_tensor + + def send_backward_recv_backward(self, input_tensor_grad, recv_next): + global _timers + if _timers is not None: + _timers("send_backward_recv_backward").start() + _, output_tensor_grad = _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, - recv_prev=True, - recv_next=False, + recv_prev=False, + recv_next=recv_next, + sync_recv=False, + send_recv_meta=self._send_recv_meta, ) - if _timers is not None: - _timers("send_backward_recv_forward").stop() - return input_tensor - - -def send_forward_backward_recv_forward_backward( - output_tensor, input_tensor_grad, recv_prev, recv_next -): - # always have to send dytpe info to downstream - global _timers - if _timers is not None: - _timers("send_forward_backward_recv_forward_backward").start() - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group()) - _send_recv_meta.has_send_meta = _use_cache - if recv_prev and not _send_recv_meta.has_recv_meta: - _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) - _send_recv_meta.has_recv_meta = _use_cache - input_tensor, output_tensor_grad = _p2p_helper( - tensor_send_next=output_tensor, - tensor_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - sync_recv=False, - ) - if _timers is not None: - _timers("send_forward_backward_recv_forward_backward").stop() - return input_tensor, output_tensor_grad - - -def send_forward_recv_forward(output_tensor, recv_prev): - # always have to send dytpe info to downstream - global _timers - if _timers is not None: - _timers("send_forward_recv_forward").start() - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group()) - _send_recv_meta.has_send_meta = _use_cache - if recv_prev and not _send_recv_meta.has_recv_meta: - _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) - _send_recv_meta.has_recv_meta = _use_cache - - input_tensor, _ = _p2p_helper( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=recv_prev, - recv_next=False, - sync_recv=False, - ) - if _timers is not None: - _timers("send_forward_recv_forward").stop() - return input_tensor - - -def send_backward_recv_backward(input_tensor_grad, recv_next): - global _timers - if _timers is not None: - _timers("send_backward_recv_backward").start() - _, output_tensor_grad = _p2p_helper( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=recv_next, - sync_recv=False, - ) - if _timers is not None: - _timers("send_backward_recv_backward").stop() - return output_tensor_grad + if _timers is not None: + _timers("send_backward_recv_backward").stop() + return output_tensor_grad + + def __repr__(self): + debug_str = f"using cache: {self._use_cache} \n" + debug_str += repr(self._send_recv_meta) + return debug_str diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index f7fc29b8d27ab..c54b63ff17d9e 100755 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,6 +20,7 @@ PipelineLayer, PipelineParallel, PipelineParallelWithInterleave, + PipelineParallelWithInterleaveFthenB, SegmentParallel, ShardingParallel, TensorParallel, @@ -85,19 +86,27 @@ def distributed_model(model): if paddle.distributed.get_world_size() <= 1: return model - amp_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp: - amp_enable = True - amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1" - if amp_level.upper() == "O2": + level = ( + "O2" + if strategy.amp_configs['use_pure_fp16'] + or strategy.amp_configs['use_pure_bf16'] + else "O1" + ) + + if level == "O2": model = paddle.amp.decorate( models=model, optimizers=None, level="O2", master_weight=None, save_dtype=None, + dtype="float16" + if strategy.amp_configs['use_pure_fp16'] + else "bfloat16", ) + init_loss_scaling = strategy.amp_configs['init_loss_scaling'] incr_ratio = strategy.amp_configs['incr_ratio'] decr_ratio = strategy.amp_configs['decr_ratio'] @@ -150,9 +159,21 @@ def distributed_model(model): # 1f1b pipeline model = PipelineParallel(model, fleet_env._hcg, strategy=strategy) else: - # interleave pipeline - model = PipelineParallelWithInterleave( - model, fleet_env._hcg, strategy=strategy - ) + accumulate_steps = strategy.pipeline_configs['accumulate_steps'] + pp_degree = fleet_env._hcg.get_pipe_parallel_world_size() + if ( + accumulate_steps >= pp_degree + and accumulate_steps < pp_degree * 2 + ): + # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave + # Currently, we only support pp_degree <= accumulate_steps < 2 * pp_degree + model = PipelineParallelWithInterleaveFthenB( + model, fleet_env._hcg, strategy=strategy + ) + else: + # interleave pipeline + model = PipelineParallelWithInterleave( + model, fleet_env._hcg, strategy=strategy + ) return model diff --git a/python/paddle/distributed/fleet/scaler.py b/python/paddle/distributed/fleet/scaler.py index 463674c958741..e284563614745 100755 --- a/python/paddle/distributed/fleet/scaler.py +++ b/python/paddle/distributed/fleet/scaler.py @@ -31,6 +31,7 @@ def unscale_method(self, optimizer): return param_grads = [] + param_grads_bf16 = [] param_grads_fp16 = [] param_grads_fp32 = [] if getattr(optimizer, '_param_groups', None) and isinstance( @@ -53,6 +54,10 @@ def unscale_method(self, optimizer): paddle.float16, ]: param_grads_fp16.append(tgt_grad) + elif tgt_grad.dtype in [ + paddle.bfloat16, + ]: + param_grads_bf16.append(tgt_grad) else: param_grads_fp32.append(tgt_grad) else: @@ -90,10 +95,15 @@ def unscale_method(self, optimizer): paddle.float16, ]: param_grads_fp16.append(tgt_grad) + elif tgt_grad.dtype in [ + paddle.bfloat16, + ]: + param_grads_bf16.append(tgt_grad) else: param_grads_fp32.append(tgt_grad) temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_)) + temp_found_inf_bf16 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_)) self._found_inf = self._temp_found_inf_value_false if len(param_grads_fp16): @@ -106,6 +116,16 @@ def unscale_method(self, optimizer): self._found_inf = _C_ops.bitwise_or( self._found_inf, temp_found_inf_fp16 ) + if len(param_grads_bf16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_bf16, + self._scale, + param_grads_bf16, + temp_found_inf_bf16, + ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_bf16 + ) if len(param_grads_fp32): _legacy_C_ops.check_finite_and_unscale( param_grads_fp32, diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 3a751b5d0c3c8..2d7c44e77b662 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle # noqa: F401 from paddle.distributed import fleet -from paddle.utils import deprecated # noqa: F401 from . import ( # noqa: F401 hybrid_parallel_util, diff --git a/python/paddle/distributed/fleet/utils/ps_util.py b/python/paddle/distributed/fleet/utils/ps_util.py index de502bcf4482c..002720450636a 100644 --- a/python/paddle/distributed/fleet/utils/ps_util.py +++ b/python/paddle/distributed/fleet/utils/ps_util.py @@ -265,9 +265,7 @@ def dag_check_up_and_reorder(program, inputs, outputs): if w.name not in varname2tables.keys(): raise ValueError( - "can not find variable {}, please check your configuration".format( - w.name - ) + f"can not find variable {w.name}, please check your configuration" ) table_id = varname2tables[w.name] diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index 548eae655cce5..a828d5b4aae17 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -13,6 +13,7 @@ # limitations under the License. import itertools import os +import weakref from collections import OrderedDict import numpy as np @@ -25,6 +26,7 @@ class HOOK_ACTION: ALL_REDUCE = 0 REDUCE = 1 + REDUCE_SCATTER = 2 alignment = { @@ -117,6 +119,147 @@ def fused_comm(*_): return fused_comm +class ShardingGradView: + def __init__( + self, + param, + param_buffer, + grad_buffer, + index, + padded_size, + sharding_degree, + rank, + use_main_grad=False, + ): + self._param = param + self._param_buffer = param_buffer + self._grad_buffer = grad_buffer + self._index = index + self._padded_size = padded_size + self._sharding_degree = sharding_degree + self._rank = rank + shard_size = param_buffer._numel() // sharding_degree + rank_begin = rank * shard_size + rank_end = rank_begin + shard_size + + param_begin = max(self._index, rank_begin) + param_end = min(self._index + self._padded_size, rank_end) + self._param_begin = param_begin + self._param_end = param_end + + self._slice_grad = None + if param_begin < param_end: + self._slice_grad = grad_buffer._slice(param_begin, param_end) + + # share grad buffer + tmp_grad = grad_buffer._slice(self._index, self._index + param._numel()) + tmp_grad.get_tensor()._set_dims(param.shape) + if not use_main_grad: + self._param._copy_gradient_from(tmp_grad) + else: + self._param.main_grad = tmp_grad + + # share param buffer + self._share_param_buffer() + + def _share_param_buffer(self): + param_shape = self._param.shape + stop_gradient = self._param.stop_gradient + self._param.stop_gradient = True + self._param.flatten_() + self._param_buffer[ + self._index : self._index + self._param._numel() + ] = self._param + self._param.get_tensor()._set_dims(param_shape) + self._param.stop_gradient = stop_gradient + self._param_buffer._slice( + self._index, self._index + self._param._numel() + )._share_buffer_to(self._param) + + def fill_slice_param(self, slice_param): + slice_begin = self._param_begin + slice_end = self._param_end + if slice_param._is_initialized(): + assert self._param_buffer._is_shared_buffer_with(slice_param) + assert len(slice_param.shape) == 1 + assert slice_param.shape[0] == (slice_end - slice_begin) + slice_begin = self._param_begin + slice_end = self._param_end + slice_buffer = self._param_buffer._slice(slice_begin, slice_end) + slice_param.get_tensor()._set_dims([slice_end - slice_begin]) + slice_buffer._share_buffer_to(slice_param) + + def assign_slice_grad(self, slice_param): + assert self._param_buffer._is_shared_buffer_with(self._param) + slice_grad = self._slice_grad + if slice_grad is None: + return + self.fill_slice_param(slice_param) + if hasattr(self._param, "main_grad"): + if not hasattr(slice_param, "main_grad"): + slice_param.main_grad = slice_grad + else: + assert slice_param.main_grad is slice_grad + elif slice_grad is not None: + if slice_param.grad is None: + slice_param._copy_gradient_from(slice_grad) + else: + assert slice_param.grad._is_shared_buffer_with(slice_grad) + + +def build_reduce_scatter_buffer( + parameters, sharding_degree, rank, use_main_grad=False +): + total_buffer_size = 0 + param2index = {} + dtype = parameters[0].dtype + + def get_padded_size(param): + size = np.prod(param.shape) + align_size = alignment["gpu"] // align[dtype] + align_size = align_size * sharding_degree + padded_size = ((size + align_size - 1) // align_size) * align_size + return padded_size + + for param in parameters: + assert param.trainable, "param must be trainable..." + param2index[param.name] = total_buffer_size + total_buffer_size += get_padded_size(param) + + grad_dtype = paddle.float32 if use_main_grad else dtype + + param_buffer = paddle.zeros(shape=[total_buffer_size], dtype=dtype) + grad_buffer = paddle.zeros(shape=[total_buffer_size], dtype=grad_dtype) + + sharding_grad_view = {} + for param in parameters: + padded_size = get_padded_size(param) + grad_view = ShardingGradView( + param, + param_buffer, + grad_buffer, + param2index[param.name], + padded_size, + sharding_degree, + rank, + use_main_grad, + ) + # hack main_grad + sharding_grad_view[param.name] = grad_view + return sharding_grad_view, param_buffer, grad_buffer + + +def get_grad_address(param, use_main_grad): + addr = None + if use_main_grad: + if param.main_grad is not None: + addr = param.main_grad.data_ptr() + else: + if (param.grad is not None) and param.grad._is_initialized(): + addr = param.grad.data_ptr() + return addr + + class FusedCommBuffer: def __init__( self, @@ -151,6 +294,8 @@ def __init__( self._act = act if self._act == HOOK_ACTION.ALL_REDUCE: assert dst == -1 + elif self._act == HOOK_ACTION.REDUCE_SCATTER: + assert dst == -1 elif self._act == HOOK_ACTION.REDUCE: assert dst != -1 else: @@ -160,35 +305,45 @@ def __init__( self._dst = dst self._init_step_dict() - - if self._fuse_param: - self.param_storage, self.grad_storage = flatten_dense_tensors( - self._params, - use_main_grad=use_main_grad, - fuse_param=True, - warp_buffer=True, - ) - self.param_storage = self.param_storage.buffer - self.grad_storage = self.grad_storage.buffer + if self._act != HOOK_ACTION.REDUCE_SCATTER: + if self._fuse_param: + self.param_storage, self.grad_storage = flatten_dense_tensors( + self._params, + use_main_grad=use_main_grad, + fuse_param=True, + warp_buffer=True, + ) + self.param_storage = self.param_storage.buffer + self.grad_storage = self.grad_storage.buffer + else: + self.param_storage = None + self.grad_storage = flatten_dense_tensors( + self._params, + use_main_grad=self.use_main_grad, + fuse_param=False, + warp_buffer=False, + ).buffer else: - self.param_storage = None - self.grad_storage = flatten_dense_tensors( + assert not self._fuse_param, "not supported" + ( + self._sharding_param_grad_view, + self.param_storage, + self.grad_storage, + ) = build_reduce_scatter_buffer( self._params, + self._comm_group.nranks, + self._comm_group.rank, use_main_grad=self.use_main_grad, - fuse_param=False, - warp_buffer=False, - ).buffer - + ) + # hack, for parameter sync in dygraph sharding optimizer after step + self._params[0].comm_buffer_ref = weakref.ref(self) self._record_addr() def _record_addr(self): for param in self._params: - addr = ( - param.main_grad.data_ptr() - if self.use_main_grad - else param.grad.data_ptr() + self._grads_to_addr[param.name] = get_grad_address( + param, self.use_main_grad ) - self._grads_to_addr[param.name] = addr def _init_step_dict(self): for p in self._params: @@ -208,11 +363,9 @@ def _all_params_checked_in(self): def add_grad(self, param, use_comm=True): assert param.name in self._params_step_dict - current_ptr = ( - param.main_grad.data_ptr() - if self.use_main_grad - else param.grad.data_ptr() - ) + + current_ptr = get_grad_address(param, self.use_main_grad) + if self._grads_to_addr[param.name] != current_ptr: raise ValueError( "The address of the grad/main_grad of the param has been changed during training, " @@ -230,6 +383,28 @@ def add_grad(self, param, use_comm=True): if self._all_params_checked_in and use_comm: self.comm_grads() + @imperative_base.no_grad + def assign_slice_grad(self, param, slice_param): + assert self._act == HOOK_ACTION.REDUCE_SCATTER + assert param.name in self._sharding_param_grad_view + grad_view = self._sharding_param_grad_view[param.name] + grad_view.assign_slice_grad(slice_param) + + @imperative_base.no_grad + def sync_params(self): + assert self._act == HOOK_ACTION.REDUCE_SCATTER + full_buffer = self.param_storage + group = self._comm_group + shard_size = full_buffer._numel() // group.nranks + begin = shard_size * group.rank + end = begin + shard_size + slice_buffer = full_buffer._slice(begin, end) + group.process_group.all_gather(slice_buffer, full_buffer).wait() + + @property + def params(self): + return self._params + @imperative_base.no_grad def comm_grads(self): assert self._all_params_checked_in, ( @@ -238,7 +413,10 @@ def comm_grads(self): len(self._params), self._params_checked_in ) ) + self._comm_grads() + @imperative_base.no_grad + def _comm_grads(self): if not self._scale_after_comm: scale_factor = 1.0 / self._comm_group.nranks self.grad_storage.scale_(scale_factor) @@ -256,6 +434,17 @@ def comm_grads(self): sync_op=False, ) + elif self._act == HOOK_ACTION.REDUCE_SCATTER: + shard_size = self.grad_storage._numel() // self._comm_group.nranks + begin = shard_size * self._comm_group.rank + end = begin + shard_size + reduce_scattered = self.grad_storage._slice(begin, end) + task = paddle.distributed.reduce_scatter( + reduce_scattered, + self.grad_storage, + group=self._comm_group, + sync_op=False, + ) self._task = task @imperative_base.no_grad diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py index d625887b8167f..27e294907304b 100644 --- a/python/paddle/distributed/launch/controllers/master.py +++ b/python/paddle/distributed/launch/controllers/master.py @@ -197,8 +197,9 @@ def __init__(self, ctx): host, port = self.endpoint.split(':') if ctx.is_auto_tuner_mode(): - self.etcd_client = ETCDClient(host=host, port=port) - self.client = etcd3.client(host=host, port=port) + self.client = ETCDClient(host=host, port=port) + else: + self.client = etcd3.client(host=host, port=port) def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): ''' @@ -256,22 +257,13 @@ def register_heartbeat(self, job_id, pod_id, ttl=10): self.job_prefix = f'/paddle/{job_id}' self.heartbeat_prefix = f'{self.job_prefix}/heartbeat' - if self.ctx.is_auto_tuner_mode(): - self.etcd_client.delete_prefix(self.job_prefix) - lease = self.etcd_client.lease(ttl) - else: - self.client.delete_prefix(self.job_prefix) - lease = self.client.lease(ttl) + self.client.delete_prefix(self.job_prefix) + lease = self.client.lease(ttl) # self.client.delete_prefix(self.job_prefix) beat_path = f"{self.heartbeat_prefix}/{pod_id}" - if self.ctx.is_auto_tuner_mode(): - self.etcd_client.put( - beat_path, pod_id.encode('latin-1'), lease=lease - ) - else: - self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) + self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) def _beat_watch(event): self.ctx.status.restart() diff --git a/python/paddle/distributed/launch/controllers/watcher.py b/python/paddle/distributed/launch/controllers/watcher.py index 25855572620f8..fd5571c39d443 100644 --- a/python/paddle/distributed/launch/controllers/watcher.py +++ b/python/paddle/distributed/launch/controllers/watcher.py @@ -23,7 +23,7 @@ class Watcher: def __init__(self, ctx): self.ctx = ctx - self.interval = 30 + self.interval = 5 self.gpu_util = [] diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index 0a7dff06dc227..e24984e6f1479 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -527,6 +527,9 @@ def launch(): # build AutoTuner to get new config auto_tuner = AutoTuner(tuner_cfg) + logger.info( + f"Launch {len(auto_tuner.algo.all_tasks)} tasks by auto tuner: " + ) cur_cfg = auto_tuner.search_once() auto_tuner.add_cfg(cur_cfg) assert cur_cfg is not None, "No config can run." @@ -557,7 +560,9 @@ def launch(): cur_cfg["acc_steps"], ) - ctx.args.log_dir = log_dir + ctx.args.log_dir = os.path.join( + os.path.dirname(ctx.args.auto_tuner_json), log_dir + ) # every task has own job id job_id += 1 @@ -651,6 +656,7 @@ def launch(): elif "OK" not in status: timeout_flag = False + has_error = False if err & (1 << 0): ctx.logger.warning( f"Read metric failed for parameters: {log_dir}" @@ -660,6 +666,7 @@ def launch(): cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM" + has_error = True if err & (1 << 1): ctx.logger.warning(f"Out of memory for parameters: {log_dir}") @@ -668,6 +675,7 @@ def launch(): cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg["max_mem_usage"] = "OOM" + has_error = True # not err & (1 << 1): do not record memory usage when out of memory if err & (1 << 2) and not err & (1 << 1): @@ -679,20 +687,23 @@ def launch(): ) cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM" - if not err and timeout_flag: + if not has_error and timeout_flag: # for pruner use cur_cfg['time'] = metric cur_cfg[tuner_cfg['metric_cfg']['name']] = metric cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM" - if not err and not timeout_flag: + if not has_error and not timeout_flag: cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM" # record history + if tuner_cfg['metric_cfg']['name'] not in cur_cfg: + cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg['job_id'] = job_id recorder.add_cfg(**cur_cfg) + recorder.store_history(history_file_path) cur_best_cfgs, err = recorder.get_best( metric=tuner_cfg['metric_cfg']['name'], direction=tuner_cfg['metric_cfg']['OptimizationDirection'], @@ -700,7 +711,6 @@ def launch(): if not err: ctx.logger.info(f"Current best config: {cur_best_cfgs}") logger.info(f"Current best config: {cur_best_cfgs}") - recorder.store_history(history_file_path) else: ctx.logger.info( "Get best config failed. Currently there are no appropriate configs." @@ -789,13 +799,17 @@ def launch(): ctx.logger.info(f"AutoTuner ends in {end_time-start_time}s.") logger.info(f"AutoTuner ends in {end_time-start_time}s.") # launch best cfg + if not tuner_cfg.get("run_best", True): + sys.exit() new_args = gen_new_args(raw_args, best_cfg, tuner_cfg, run_best=True) ctx.run_best = True ctx.args.training_script_args = new_args ctx.args.job_id = "best_cfg" ctx.logger.info(f"Launch best cfg from auto tuner: {best_cfg}") logger.info(f"Launch best cfg from auto tuner: {best_cfg}") - ctx.args.log_dir = "best_cfg" + ctx.args.log_dir = ctx.args.log_dir = os.path.join( + os.path.dirname(ctx.args.auto_tuner_json), "best_cfg" + ) # run best cfg c = controllers.init(ctx) c.run() diff --git a/python/paddle/distributed/launch/utils/etcd_client.py b/python/paddle/distributed/launch/utils/etcd_client.py index e4bbf8e1409a4..a96c7a034fdb1 100644 --- a/python/paddle/distributed/launch/utils/etcd_client.py +++ b/python/paddle/distributed/launch/utils/etcd_client.py @@ -140,3 +140,41 @@ def lease(self, ttl, lease_id=None): if times >= self.retry_times: raise ValueError(f"Lease failed after {self.retry_times} times.") + + def add_watch_prefix_callback(self, key_prefix, callback, **kwargs): + times = 0 + while times < self.retry_times: + try: + return self.client.add_watch_prefix_callback( + key_prefix, callback, **kwargs + ) + break + except Exception as e: + times += 1 + logging.info( + f"Add watch prefix callback failed with exception {e}, retry after 1 second." + ) + time.sleep(1) + + if times >= self.retry_times: + raise ValueError( + f"Add watch prefix callback failed after {self.retry_times} times." + ) + + def cancel_watch(self, watch_id): + times = 0 + while times < self.retry_times: + try: + return self.client.cancel_watch(watch_id) + break + except Exception as e: + times += 1 + logging.info( + f"Cancel watch failed with exception {e}, retry after 1 second." + ) + time.sleep(1) + + if times >= self.retry_times: + raise ValueError( + f"Cancel watch failed after {self.retry_times} times." + ) diff --git a/python/paddle/distributed/launch/utils/nvsmi.py b/python/paddle/distributed/launch/utils/nvsmi.py index da44600615458..232ccce2209cc 100644 --- a/python/paddle/distributed/launch/utils/nvsmi.py +++ b/python/paddle/distributed/launch/utils/nvsmi.py @@ -16,6 +16,9 @@ import os import shutil import subprocess +import time + +import paddle class Info: @@ -73,6 +76,39 @@ def query_smi(query=None, query_type="gpu", index=None, dtype=None): return ret +def query_rocm_smi(query=None, index=None, dtype=None, mem=32150): + if not has_rocm_smi(): + return [] + + cmd = ["rocm-smi"] + + if not isinstance(dtype, list) or len(dtype) != len(query): + dtype = [str] * len(query) + + output = subprocess.check_output(cmd, timeout=3) + lines = output.decode("utf-8").split(os.linesep) + ret = [] + for line in lines: + if not line: + continue + if len(line.split()) != 8 or "DCU" in line.split(): + continue + info = Info() + line = line.split() + line = [ + line[0], + line[7][: len(line[7]) - 1], + mem, + mem * float(line[6][: len(line[6]) - 1]) / 100, + mem - mem * float(line[6][: len(line[6]) - 1]) / 100, + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + ] + for k, v, d in zip(query, line, dtype): + setattr(info, k.replace(".", "_"), d(v)) + ret.append(info) + return ret + + def get_gpu_info(index=None): q = "index,uuid,driver_version,name,gpu_serial,display_active,display_mode".split( "," @@ -97,7 +133,8 @@ def get_gpu_util(index=None): if index is None or isinstance(index, list) else str(index).split(",") ) - + if paddle.device.is_compiled_with_rocm(): + return query_rocm_smi(q, index=index, dtype=d) return query_smi(q, index=index, dtype=d) @@ -117,6 +154,10 @@ def has_nvidia_smi(): return shutil.which("nvidia-smi") +def has_rocm_smi(): + return shutil.which("rocm-smi") + + if __name__ == '__main__': print(get_gpu_info(0)) print(get_gpu_util(0)) diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 8890ab0bd179a..44102e60cda61 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -181,7 +181,6 @@ def sync_params_buffers( paddle.distributed.broadcast( coalesced_var, src=src_rank, group=comm_group, sync_op=True ) - for coalesced_var, origin_vars, var_shapes in coalesced_vars: var_len = [np.prod(v_shape) for v_shape in var_shapes] paddle.base.framework._dygraph_tracer().trace_op( @@ -685,6 +684,7 @@ def __init__(self): self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", "")) + self._pg_timeout = int(os.getenv("PADDLE_PG_TIMEOUT", "1800000")) # imperative only support one gpu or xpu if self._device_type != "": @@ -849,6 +849,24 @@ def nrings(self): """ return self._nrings + @property + def pg_timeout(self): + """ + timeout of process group. + + Its value is equal to the value of the environment variable ``PADDLE_PG_TIMEOUT`` . The default value is 30 minutes. + + Examples: + .. code-block:: python + + >>> # execute this command in terminal: export PADDLE_PG_TIMEOUT=1800000 + >>> import paddle.distributed as dist + + >>> env = dist.ParallelEnv() + >>> # the pg_timeout of process group 1800000 + """ + return self._pg_timeout + # [aliases] Compatible with old method names local_rank = rank nranks = world_size @@ -1098,7 +1116,6 @@ def init_parallel_env(): # TODO(mine): support XPU and other backends. if backend in ["nccl", 'xccl', 'bkcl']: core.CommContextManager.set_device_id(parallel_env.device_id) - paddle.distributed.barrier(group=group) return group node_num = {i.split(":")[0] for i in parallel_env.trainer_endpoints} diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index e2f54d47a4e08..8c1f4ab6e5350 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -24,7 +24,7 @@ from .auto_parallel_grad_clip import * # noqa: F403 from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403 from .auto_parallel_pipeline import * # noqa: F403 -from .column_parallel_linear_backward_overlapping import * # noqa: F403 +from .allreduce_matmul_grad_overlapping import * # noqa: F403 from .cpp_pass import * # noqa: F403 from .fuse_all_reduce import * # noqa: F403 from .pipeline_scheduler_pass import * # noqa: F403 diff --git a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py similarity index 98% rename from python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py rename to python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index aa5dbd7d267e1..c6457b612ff81 100644 --- a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -27,8 +27,8 @@ # dY = matmul(X^T, dOut) # # Then the c_allreduce_sum can overlap with the compute of dY. -@register_pass("column_parallel_linear_backward_overlapping") -class ColumnParallelLinearBackwardOverlappingPass(PassBase): +@register_pass("allreduce_matmul_grad_overlapping") +class AllreduceMatmulGradOverlappingPass(PassBase): def __init__(self): super().__init__() self.set_attr("allreduce_stream", None) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 53bdca47c48a5..b8cd7a6b8d5d7 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -215,7 +215,7 @@ def build_state(self): fwd_op_id = self.grad_op_to_op_map[ op.desc.original_id() ] - assert fwd_op_id in self._op_fp16_dict, f"{str(op)}" + assert fwd_op_id in self._op_fp16_dict, str(op) self._op_fp16_dict[ op.desc.original_id() ] = self._is_fp16_op(fwd_op_id) diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 3cda24f1a0f64..9b26a0980e55f 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -245,14 +245,10 @@ def _update_opt_rescale_grad(self): ): assert op.has_attr( 'rescale_grad' - ), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format( - str(op) - ) + ), f"Unexpected: op [{str(op)}] is supported to have [rescale_grad] attribute." assert ( len(op.input("Grad")) == 1 - ), "Unexpected: op [{}] is supported to have only one input grad var.".format( - str(op) - ) + ), f"Unexpected: op [{str(op)}] is supported to have only one input grad var." grad_name = op.input("Grad")[0] dp_degree = len( diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 82475251ee516..00d34b4c11fce 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -722,9 +722,7 @@ def is_initialization_op(op): if param_to_dtype.get(output_name, None) == __target_dtype__: assert op.has_attr( 'dtype' - ), "initialization op is supported to has dtype attribute but got {}.".format( - str(op) - ) + ), f"initialization op is supported to has dtype attribute but got {str(op)}." out_var = startup_program.global_block().var(output_name) if out_var.dtype == core.VarDesc.VarType.FP32: out_var.desc.set_dtype(__target_dtype__) diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index 327b208518ee8..cd9b558a31313 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -401,6 +401,33 @@ def _remove_no_need_ops_vars(self, block): else: op.desc.set_input("X", reserved_vars) + elif op.type == 'stack': + # 'stack' op is also used to calculate global_norm ('stack' + 'reduce_sum'), and need to filter inputs which is not in cur_rank + reserved_vars = [] + for input_name in op.input_arg_names: + if ( + input_name not in removed_tmp_var + and self.clip_helper.is_local_var_with_dist_attr( + input_name + ) + ): + reserved_vars.append(input_name) + if not reserved_vars: + removed_op_idx.add(idx) + removed_tmp_var.update(set(op.output_arg_names)) + if block.ops[idx + 1].type == 'reduce_sum': + removed_op_idx.add(idx + 1) + removed_tmp_var.update( + set(block.ops[idx + 1].output_arg_names) + ) + if block.ops[idx + 2].type == 'cast': + removed_op_idx.add(idx + 2) + removed_tmp_var.update( + set(block.ops[idx + 2].output_arg_names) + ) + else: + op.desc.set_input("X", reserved_vars) + for idx, op in reversed(list(enumerate(block.ops))): if not is_optimize_op(op): break diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index f804b59a2db2c..b9efd6fb332a0 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -20,11 +20,16 @@ get_world_process_group, ) from paddle.distributed.auto_parallel.static.utils import ( + is_forward_op, is_optimize_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) -from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole +from paddle.distributed.fleet.meta_optimizers.common import ( + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, + OpRole, +) from paddle.framework import core from paddle.static import device_guard @@ -48,8 +53,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context): removed_op_idx.append(idx) # del op from dist_context - if dist_context: - dist_context.del_dist_op_for_program(op) + # if dist_context: + # dist_context.del_dist_op_for_program(op) for idx in removed_op_idx[::-1]: main_block._remove_op(idx, sync=False) @@ -157,68 +162,104 @@ def _append_gradient_merge_backward_op( startup_block = startup_program.global_block() # step1: remove grad.op's op_role_var + grad_to_params_grads = {} for param, grad in params_grads: assert ( param.type != core.VarDesc.VarType.SELECTED_ROWS ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now" + grad_to_params_grads[grad.name] = (param, grad) # {grad.name: gradient_merge_var.name} to rename opt inputs grad_to_gradient_merge = {} # {param: gradient_merge_var} to insert scale op and fill_constant op - new_params_to_grads = [] - # step2: create gradient_merge var and init with 0 - for param, grad in params_grads: - param_name = param.name - param_var = main_block.var(param_name) - assert param_var is not None - ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var) - assert ref_dist_attr is not None - gradient_merge_var = main_block.create_var( - name=param_name + "@GRAD@GradientMerge", - shape=param_var.shape, - dtype=param_var.dtype, - persistable=True, - ) - ref_process_mesh = ref_dist_attr.process_mesh - ref_dims_mapping = ref_dist_attr.dims_mapping + new_params_grads = [] + + for index, op in reversed(list(enumerate(main_block.ops))): + if len(grad_to_params_grads) == 0: + break + if is_forward_op(op): + break + + for out_name in op.desc.output_arg_names(): + if out_name in grad_to_params_grads: + param = grad_to_params_grads[out_name][0] + assert param is not None + ref_dist_attr = dist_context.get_tensor_dist_attr_for_program( + param + ) + assert ref_dist_attr is not None + + # step2: create gradient_merge var and init with 0 + # Add persistable gradient variables in main_program + gradient_merge_var = main_block.create_var( + name=param.name + "@GRAD@MERGE", + shape=param.shape, + dtype=param.dtype, + persistable=True, + ) + ref_process_mesh = ref_dist_attr.process_mesh + ref_dims_mapping = ref_dist_attr.dims_mapping + set_var_dist_attr( + dist_context, + gradient_merge_var, + ref_dims_mapping, + ref_process_mesh, + ) - set_var_dist_attr( - dist_context, gradient_merge_var, ref_dims_mapping, ref_process_mesh - ) + # Add persistable gradient variables in startup_program + startup_gradient_merge_var = startup_block.create_var( + name=param.name + "@GRAD@MERGE", + shape=param.shape, + dtype=param.dtype, + persistable=True, + ) + # Initial persistable gradient variables in startup_program + startup_block.append_op( + type="fill_constant", + outputs={"Out": startup_gradient_merge_var}, + attrs={ + "shape": param.shape, + "dtype": param.dtype, + "value": float(0), + }, + ) - startup_gradient_merge_var = startup_block.create_var( - name=param_name + "@GRAD@GradientMerge", - shape=param_var.shape, - dtype=param_var.dtype, - persistable=True, - ) - startup_block.append_op( - type="fill_constant", - outputs={"Out": startup_gradient_merge_var}, - attrs={ - "shape": param_var.shape, - "dtype": param_var.dtype, - "value": float(0), - }, - ) + # step3: Accumulate persistable gradient variables in main_program + grad = grad_to_params_grads[out_name][1] + assert grad is not None + # NOTE(zhaoyingli): inplace operation must be 'a = a + b', cannot be 'a = b + a' + new_grad_op = main_block._insert_op_without_sync( + index + 1, + type="elementwise_add", + inputs={'X': gradient_merge_var, 'Y': grad}, + outputs={'Out': gradient_merge_var}, + attrs={ + 'axis': -1, + 'use_mkldnn': False, + OP_ROLE_KEY: OpRole.Backward, + }, + ) - # grad_merge += grad - new_grad_op = main_block.append_op( - type="elementwise_add", - inputs={'X': grad, 'Y': gradient_merge_var}, - outputs={'Out': gradient_merge_var}, - attrs={ - 'axis': -1, - 'use_mkldnn': False, - OP_ROLE_KEY: OpRole.Backward, - }, - ) - new_params_to_grads.append([param, gradient_merge_var]) - grad_to_gradient_merge[grad.name] = gradient_merge_var.name - naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context - ) - return new_params_to_grads, grad_to_gradient_merge + # Construct new_params_grads and grad_to_gradient_merge + new_params_grads.append([param, gradient_merge_var]) + grad_to_gradient_merge[grad.name] = gradient_merge_var.name + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_grad_op, + ref_process_mesh, + ref_dims_mapping, + dist_context, + ) + + del grad_to_params_grads[out_name] + + assert ( + len(grad_to_params_grads) == 0 + ), "grad_to_param_names must be empty right now, but it has {} items".format( + len(grad_to_params_grads) + ) + main_block._sync_with_cpp() + + return new_params_grads, grad_to_gradient_merge def _create_cond_block_and_update_optimizer( @@ -229,6 +270,7 @@ def _create_cond_block_and_update_optimizer( optimize_ops_block, k_steps, avg, + dist_context, ): def true_apply_gradient(): cur_block_idx = main_program.current_block_idx @@ -236,11 +278,10 @@ def true_apply_gradient(): # cur_block's forward_block & backward_block is itself cur_block._set_forward_block_idx(cur_block_idx) - op_maker = core.op_proto_and_checker_maker if avg: - for param, new_grad in new_params_to_grads: + for _, new_grad in new_params_to_grads: # grad /= k_steps - cur_block.append_op( + scale_op = cur_block.append_op( type='scale', inputs={'X': new_grad}, outputs={'Out': new_grad}, @@ -250,7 +291,7 @@ def true_apply_gradient(): 'bias_after_scale': False, }, ) - new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize) + scale_op._set_attr(OP_ROLE_KEY, OpRole.Optimize) # append optimizer ops for opt_op_idx in range(optimize_ops_block.desc.op_size()): @@ -272,28 +313,44 @@ def true_apply_gradient(): ) # remove op_role_var - if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()): - new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName()) + if new_op_desc.has_attr(OP_ROLE_VAR_KEY): + new_op_desc.remove_attr(OP_ROLE_VAR_KEY) # op's update Grad if core.grad_var_suffix() in new_op_desc.input_arg_names(): grad_value = new_op_desc.input("Grad")[0] # TODO FIXME(xym) support fp16 - grad_merge_value = grad_value + '@GradientMerge' + grad_merge_value = grad_value + '@MERGE' new_op_desc.set_input("Grad", [grad_merge_value]) main_program.global_block()._sync_with_cpp() cur_block._sync_with_cpp() + # update serial op + for op in cur_block.ops: + if is_optimize_op(op): + dist_op = dist_context.get_dist_op_for_program(op) + if dist_op: + dist_op._serial_op = op + # clear gradient_merge_vars - for param, new_grad in new_params_to_grads: - paddle.tensor.fill_constant( - shape=new_grad.shape, - dtype=new_grad.dtype, - value=0.0, - out=new_grad, + # NOTE(zhaoyingli): Must use 'set_value' op in pir to assign 0-value for persistable var. + for _, new_grad in new_params_to_grads: + cur_block.append_op( + type="set_value", + inputs={"Input": [new_grad]}, + outputs={"Out": [new_grad]}, + attrs={ + "values": [float(0)], + "dtype": new_grad.dtype, + "shape": [1], + "axes": [], + "starts": [], + "ends": [], + "steps": [], + OP_ROLE_KEY: OpRole.Optimize, + }, ) - new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize) paddle.static.nn.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) cond_op = main_program.global_block().ops[-1] @@ -331,6 +388,7 @@ def parse_program( optimize_ops_block, k_steps, avg, + dist_context, ) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 5c63d93cb14b1..7a739147dc325 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -31,13 +31,17 @@ get_loss_op, insert_dependencies_for_two_ops, is_backward_op, + is_recompute_exclude_op, is_recompute_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_dist_op_desc_original_id, set_var_dist_attr, ) +from ..utils.log_utils import get_logger from .pass_base import PassBase, register_pass +logger = get_logger(logging.INFO) + class RecomputeState(ProgramStats): def __init__(self, block, ops): @@ -80,9 +84,13 @@ def build_states(self): if not is_recompute_op(op): self._checkpoints.extend(op.output_arg_names) - continue + if not is_recompute_exclude_op(op): + continue seg_name = op.attr('op_namescope') + seg_name = ( + seg_name if '_exclude_rc' not in seg_name else seg_name[:-11] + ) if seg_name not in self.seg_op_deps: self.seg_op_deps[seg_name] = [i] else: @@ -277,22 +285,120 @@ def _check_self(self): def _check_conflict(self, other_pass): return True + def get_ops_per_device(self, ops, all_ops_process_meshs, sr=0): + """ + Get ops and op_names of each process mesh excluding ops within the first "sr" chunks + """ + + def reset_recomupte_op(op): + if is_recompute_op(op) or is_recompute_exclude_op(op): + op._set_attr("op_namescope", "") + + all_process_meshes_count = len(all_ops_process_meshs) + ops_of_stages = [[] for _ in range(all_process_meshes_count)] + op_names_of_stages = [[] for _ in range(all_process_meshes_count)] + pushed_ops_count = 0 + reset_ops_count = 0 + chunk_id = 0 + for op_id, op in enumerate(ops): + if chunk_id // all_process_meshes_count < sr: + reset_ops_count += 1 + reset_recomupte_op(op) + if ( + op_id < len(ops) - 1 + and op.dist_attr.process_mesh + != ops[op_id + 1].dist_attr.process_mesh + ): + chunk_id += 1 + if chunk_id // all_process_meshes_count < sr: + continue + + for id, process_mesh in enumerate(all_ops_process_meshs): + if op.dist_attr.process_mesh == process_mesh: + pushed_ops_count += 1 + ops_of_stages[id].append(op) + op_names_of_stages[id].append(op.type) + assert ( + len(ops) == reset_ops_count + pushed_ops_count + ), "The sum of pushed_ops_count and reset_ops_count must be the same as lenght of ops, but the sum is {} while lenght of ops is {}".format( + reset_ops_count + pushed_ops_count, len(ops) + ) + return ops_of_stages, op_names_of_stages + def _apply_single_impl(self, main_program, startup_program, context): loss = self.get_attr("loss") no_grad_set = self.get_attr("no_grad_set") no_recompute_segments = self.get_attr("no_recompute_segments") self._dist_context = self.get_attr("dist_context") + self._sr = self.get_attr("sr", 0) + self._refined_ops_patterns = self.get_attr("refined_ops_patterns", []) # 0. get op_path which is related to loss main_block = main_program.global_block() op_path = _find_op_path(main_program, loss, no_grad_set) - # 1. build recompute state + # 1. mark exclude ops for refined-reompute according to ops-patterns(mainly linear and flash_attn) + # 1.1 get all process_meshs in op_path + all_ops_process_meshs = [] + for op in op_path: + if op.dist_attr.process_mesh not in all_ops_process_meshs: + all_ops_process_meshs.append(op.dist_attr.process_mesh) + + # 1.2 get ops_devices and op_names_devices + ops_devices, op_names_devices = self.get_ops_per_device( + op_path, all_ops_process_meshs, self._sr + ) + all_ops_len = len(op_path) + all_exclude_ops_ids = [[] for _ in op_names_devices] + # 1.3 find exclude ops for refined-reompute according to ops-patterns + for refined_ops_pattern in self._refined_ops_patterns: + num = refined_ops_pattern['num'] + num = ( + num if num >= 0 else all_ops_len + ) # 'num == -1' represents to all ops + main_ops = refined_ops_pattern['main_ops'] + pre_ops = refined_ops_pattern['pre_ops'] + suf_ops = refined_ops_pattern['suf_ops'] + main_start_id = len(pre_ops) + main_ops_len = len(main_ops) + pattern_ops = pre_ops + main_ops + suf_ops + pattern_ops_len = len(pattern_ops) + + for id, op_names_device in enumerate(op_names_devices): + pattern_count = 0 + ops_len_device = len(op_names_device) + for i in range(ops_len_device - pattern_ops_len + 1): + if ( + op_names_device[i : i + pattern_ops_len] == pattern_ops + and pattern_count < num + ): + pattern_count += 1 + all_exclude_ops_ids[id].extend( + list( + range( + i + main_start_id, + i + main_start_id + main_ops_len, + ) + ) + ) + logger.info( + f"The excluded ops in recompute segments are:\n{all_exclude_ops_ids}" + ) + # 1.4 mark exclude ops in exclude_ops_ids + for id, exclude_ops_ids in enumerate(all_exclude_ops_ids): + for op_id in exclude_ops_ids: + if is_recompute_op(ops_devices[id][op_id]): + rc_mark_str = ops_devices[id][op_id].attr("op_namescope") + ops_devices[id][op_id]._set_attr( + "op_namescope", rc_mark_str + "_exclude_rc" + ) + + # 2. build recompute state rc_state = RecomputeState(main_block, op_path) if not rc_state.is_recompute(): return - # 2. get the segments to be recomputed + # 3. get the segments to be recomputed rc_state.modify_forward_desc_for_recompute(self._dist_context) rc_state.build_states() segments = rc_state.get_recompute_segments(no_recompute_segments) @@ -300,15 +406,15 @@ def _apply_single_impl(self, main_program, startup_program, context): return for i, (idx1, idx2) in enumerate(segments): - logging.info(f"recompute segment[{i + 1}/{len(segments)}]") - logging.info( + logger.debug(f"recompute segment[{i + 1}/{len(segments)}]") + logger.debug( "segment start op: [{}]: [{}] [{}]".format( rc_state.ops[idx1].type, rc_state.ops[idx1].input_arg_names, rc_state.ops[idx1].output_arg_names, ) ) - logging.info( + logger.debug( "segment end op: [{}]: [{}] [{}]".format( rc_state.ops[idx2 - 1].type, rc_state.ops[idx2 - 1].input_arg_names, @@ -316,14 +422,15 @@ def _apply_single_impl(self, main_program, startup_program, context): ) ) - # 3. get vars that should be hold in memory + # 4. get vars that should be hold in memory + # list of var_names vars_should_be_hold = [] for segment in segments: vars_should_be_hold.extend( rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) ) cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints) - logging.info( + logger.debug( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( len(cross_vars), cross_vars @@ -335,7 +442,7 @@ def _apply_single_impl(self, main_program, startup_program, context): set(vars_should_be_hold) | set(rc_state.checkpoints) ) - # 4. get the fwd ops desc to be recomputed. + # 5. get the fwd ops desc to be recomputed. var_name_dict = {} # varname --> varname.subprog_XXX ckpt_ops_dict = {} # ckpt_op_id --> segment_descs buffer_block = main_block.program._create_block() @@ -406,7 +513,7 @@ def _apply_single_impl(self, main_program, startup_program, context): ckpt_op = op_path[segment[1] - 1] ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs] - # 5. insert recomputed fwd ops into backward parse + # 6. insert recomputed fwd ops into backward parse ops = main_block.ops loss_op = get_loss_op(main_block) loss_op_idx = _find_op_index(main_block, loss_op) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index f7b211fdc4ba4..6c3ee4d8d8e95 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1691,11 +1691,10 @@ def re_order_program(block, param_grads, dist_context): if is_optimize_op(last_op) and last_op.type in _supported_optimizer_type: # record optimizer for idx, op in reversed(list(enumerate(block.ops))): - if op.type not in _supported_optimizer_type: - break - assert len(op.input("Param")) == 1 - pname_to_op[op.input("Param")[0]] = op - remove_op_indices.append(idx) + if op.type in _supported_optimizer_type: + assert len(op.input("Param")) == 1 + pname_to_op[op.input("Param")[0]] = op + remove_op_indices.append(idx) assert len(use_order) == len(pname_to_op) # append new opts diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index d5eb98d7422de..a7b68d923be2c 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -17,13 +17,13 @@ from enum import Enum from paddle.base import core -from paddle.base.framework import Parameter, Program +from paddle.base.framework import Operator, Parameter, Program, get_flags from paddle.distributed.auto_parallel.static.utils import ( get_logger, is_backward_op, is_forward_op, - is_lr_sched_op, is_optimize_op, + use_new_executor, ) from paddle.distributed.fleet.meta_optimizers.common import OpRole @@ -35,6 +35,8 @@ core.VarDesc.VarType.FETCH_LIST, ] +logger = get_logger(logging.INFO) + # NOTE: Here stream is just a presentation with different name, # it is up to executor to create the exact streams given the name. @@ -223,7 +225,44 @@ def var_can_be_deleted(var_name, block): return var is not None and not var.persistable -def set_skip_gc_vars(num_micro_batches, type_to_program, jobs): +def prepare_ir_program(cur_prog, next_prog): + set_output_names = set() + for op in cur_prog.global_block().ops: + for arg_name in op.output_arg_names: + if var_can_be_deleted(arg_name, cur_prog.global_block()): + set_output_names.add(arg_name) + + set_input_names = set() + for op in next_prog.global_block().ops: + for arg_name in op.input_arg_names: + if var_can_be_deleted(arg_name, next_prog.global_block()): + set_input_names.add(arg_name) + + shadow_var_names = sorted(set_output_names & set_input_names) + for var_name in shadow_var_names: + shadow_op_desc = cur_prog.global_block().desc.append_op() + shadow_op_desc.set_type("shadow_output") + shadow_op_desc.set_input('x', [var_name]) + shadow_op_desc.set_output('out', ["@EMPTY@"]) + shadow_op_desc._set_attr("name", var_name) + shadow_op = Operator(cur_prog.global_block(), shadow_op_desc) + cur_prog.global_block().ops.append(shadow_op) + + data_op_desc = next_prog.global_block().desc._prepend_op() + data_op_desc.set_type("data") + data_op_desc._set_attr("shape", []) + data_op_desc._set_attr("dtype", 0) + data_op_desc._set_attr("place", 2) # GPUPlace + data_op_desc._set_attr("name", var_name) + data_op_desc.set_output("out", [var_name]) + data_op = Operator(next_prog.global_block(), data_op_desc) + next_prog.global_block().ops.insert(0, data_op) + + cur_prog._sync_with_cpp() + next_prog._sync_with_cpp() + + +def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs): """ Set `skip_gc_vars` for every job in jobs. @@ -232,6 +271,7 @@ def set_skip_gc_vars(num_micro_batches, type_to_program, jobs): and these vars cannot be gc after executing current sub_program. """ assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1" + type_to_program = dict(zip(job_types, sub_programs)) # step1: Get all vars of every sub_program that are non-persistable and not in op's no_need_buffer. type_to_required_vars = {} @@ -264,7 +304,7 @@ def set_skip_gc_vars(num_micro_batches, type_to_program, jobs): required_vars = type_to_required_vars[job_type] micro_batch_id = job.micro_batch_id() skip_gc_vars = required_vars & suffixed_required_vars[micro_batch_id] - get_logger(logging.INFO).info( + logger.debug( f"Skip gc vars for {job_type}-({micro_batch_id}): {skip_gc_vars}" ) @@ -276,6 +316,19 @@ def set_skip_gc_vars(num_micro_batches, type_to_program, jobs): job.set_skip_gc_vars(skip_gc_vars) suffixed_required_vars[micro_batch_id] |= required_vars + if get_flags("FLAGS_enable_pir_in_executor")[ + 'FLAGS_enable_pir_in_executor' + ]: + for i, type in enumerate(job_types): + if i == len(job_types) - 1: + break + next_type = job_types[i + 1] + prepare_ir_program( + type_to_program[type], type_to_program[next_type] + ) + + return type_to_program + def _create_param(dst_block, src_var): copied_kwargs = {} @@ -434,13 +487,15 @@ def _insert_sync_for_fthenb_1f1b(program): var = block.var(var_name) block._remove_op(index + offset, sync=False) offset -= 1 - block._insert_op_without_sync( - index=backward_recv_index, - type="nop", - inputs={'X': [var]}, - outputs={'Out': [var]}, - attrs={'op_role': OpRole.Backward}, - ) + if not use_new_executor(): + # NOTE: new executor will make sure gc are right without using nop op. + block._insert_op_without_sync( + index=backward_recv_index, + type="nop", + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={'op_role': OpRole.Backward}, + ) block._sync_with_cpp() @@ -478,21 +533,22 @@ def _program_for_fthenb_and_1f1b(program, enable_send_recv_overlap=False): else: _insert_sync_for_fthenb_1f1b(program) - lr_prog = Program() fwd_prog = Program() bwd_prog = Program() opt_prog = Program() + def _is_fetch_op(op): + return op.type in ["fetch", "fetch_v2"] + # split the program based on the op_role def _split_ops(block): - lr_ops = [] fwd_ops = [] bwd_ops = [] opt_ops = [] for op in src_block.ops: - if is_lr_sched_op(op): - lr_ops.append(op) - elif is_forward_op(op): + if _is_fetch_op(op): + continue + if is_forward_op(op): fwd_ops.append(op) elif is_backward_op(op): bwd_ops.append(op) @@ -502,20 +558,17 @@ def _split_ops(block): raise ValueError( "The op role: " + str(op.attr('op_role')) - + " isn't one of LRSched, Forward, Backward or Optimizer." + + " isn't one of Forward, Backward or Optimizer." ) - return lr_ops, fwd_ops, bwd_ops, opt_ops + return fwd_ops, bwd_ops, opt_ops def _add_ops_into_block(src_block, dst_block, ops): for op in ops: _create_program(src_block, dst_block, op) for idx, src_block in enumerate(program.blocks): - lr_ops, fwd_ops, bwd_ops, opt_ops = _split_ops(src_block) + fwd_ops, bwd_ops, opt_ops = _split_ops(src_block) if idx == 0: - lr_block = lr_prog.block(0) - _add_ops_into_block(src_block, lr_block, lr_ops) - fwd_block = fwd_prog.block(0) _add_ops_into_block(src_block, fwd_block, fwd_ops) @@ -525,13 +578,6 @@ def _add_ops_into_block(src_block, dst_block, ops): opt_block = opt_prog.block(0) _add_ops_into_block(src_block, opt_block, opt_ops) else: - if len(lr_ops): - lr_block = lr_prog._create_block( - parent_idx=src_block.parent_idx - ) - lr_block._set_forward_block_idx(src_block.forward_block_idx) - _add_ops_into_block(src_block, lr_block, lr_ops) - if len(fwd_ops): fwd_block = fwd_prog._create_block( parent_idx=src_block.parent_idx @@ -557,25 +603,23 @@ def _add_ops_into_block(src_block, dst_block, ops): if fetch_op.type in ["fetch", "fetch_v2"]: in_name = fetch_op.input_arg_names[0] dst_block = None - for block in [lr_block, fwd_block, bwd_block, opt_block]: + for block in [fwd_block, bwd_block, opt_block]: if block._find_var_recursive(in_name): dst_block = block break if dst_block: _create_program(src_block, dst_block, fetch_op) - lr_prog._sync_with_cpp() fwd_prog._sync_with_cpp() bwd_prog._sync_with_cpp() opt_prog._sync_with_cpp() - lr_prog._rollback() fwd_prog._rollback() bwd_prog._rollback() opt_prog._rollback() # It MUST return in this order - return [lr_prog, fwd_prog, bwd_prog, opt_prog] + return [fwd_prog, bwd_prog, opt_prog] def _add_event_dependency(recorder_op, waiter_op): diff --git a/python/paddle/distributed/passes/pipeline_pass_base.py b/python/paddle/distributed/passes/pipeline_pass_base.py index c15ca267b6fee..386524a84f799 100644 --- a/python/paddle/distributed/passes/pipeline_pass_base.py +++ b/python/paddle/distributed/passes/pipeline_pass_base.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import paddle from paddle.base import core +from ..utils.log_utils import get_logger from .pass_base import PassBase from .pass_utils import set_skip_gc_vars +logger = get_logger(logging.INFO) + class PipelinePassBase(PassBase): def __init__(self): @@ -41,7 +46,7 @@ def _partial_programs(self, program): The return value MUST be two lists, one is a list of types(str), another is a list of sub programs. For example: - return [LR, FORWARD, BACKWARD, OPT], [lr_prog, fwd_prog, bwd_prog, opt_prog] + return [FORWARD, BACKWARD, OPT], [fwd_prog, bwd_prog, opt_prog] or return [FORWARD], [fwd_prog] """ @@ -53,14 +58,24 @@ def _apply_single_impl(self, main_program, startup_program, context): to implement two interfaces above, 'create_job_list' and 'partial_programs'. """ job_types, sub_programs = self._partial_programs(main_program) + for i in range(len(job_types)): + logger.debug( + f"sub_program type: {job_types[i]}, sum_program:\n{sub_programs[i]}" + ) jobs = self._create_job_list() - type_to_program = dict(zip(job_types, sub_programs)) - set_skip_gc_vars( - self.get_attr("num_micro_batches"), type_to_program, jobs + type_to_program = set_skip_gc_vars( + self.get_attr("num_micro_batches"), job_types, sub_programs, jobs ) for type in type_to_program.keys(): - type_to_program[type] = type_to_program[type].desc + if paddle.framework.get_flags("FLAGS_enable_pir_in_executor")[ + 'FLAGS_enable_pir_in_executor' + ]: + type_to_program[type] = paddle.pir.translate_to_pir( + type_to_program[type].desc + ) + else: + type_to_program[type] = type_to_program[type].desc plan = core.Plan(jobs, type_to_program) context.set_attr("plan", plan) diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index ba17a7a50a8ff..4d6c43a865bc7 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -36,7 +36,6 @@ core.VarDesc.VarType.FETCH_LIST, ] -LR = "lr" FORWARD = "forward" BACKWARD = "backward" OPT = "optimizer" @@ -53,8 +52,6 @@ def _create_job_list(self): num_micro_batches = self.get_attr("num_micro_batches") job_list = [] - lr_job = core.Job(LR) - job_list.append(lr_job) for i in range(num_micro_batches): forward_job = core.Job(FORWARD) @@ -74,7 +71,7 @@ def _create_job_list(self): def _partial_programs(self, program): # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs. enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") - types = [LR, FORWARD, BACKWARD, OPT] + types = [FORWARD, BACKWARD, OPT] sub_program_list = _program_for_fthenb_and_1f1b( program, enable_send_recv_overlap ) @@ -219,9 +216,6 @@ def _create_job_list(self): pp_degree = self.get_attr("pp_degree") job_list = [] - lr_job = core.Job(LR) - job_list.append(lr_job) - assert ( pp_degree <= num_micro_batches ), "Num of micro batches should larger than or equal to pp degree." @@ -353,7 +347,7 @@ def _op_cost(self, op): def _partial_programs(self, program): # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs. enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") - types = [LR, FORWARD, BACKWARD, OPT] + types = [FORWARD, BACKWARD, OPT] sub_programs = _program_for_fthenb_and_1f1b( program, enable_send_recv_overlap ) @@ -379,10 +373,10 @@ def _partial_programs(self, program): ) for i in range(len(types)): - logger.info( + logger.debug( f"type = {types[i]}, sub_programs = {sub_programs[i]}\n" ) - logger.info(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") + logger.debug(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") return types, sub_programs @@ -420,9 +414,6 @@ def _create_job_list(self): pp_degree = self.get_attr("pp_degree") job_list = [] - lr_job = core.Job("lr") - job_list.append(lr_job) - assert ( 2 * (pp_degree - pp_stage) - 1 <= num_micro_batches ), "Num of micro batches should larger than 2 * (pp_degree - pp_stage) - 1." @@ -431,30 +422,30 @@ def _create_job_list(self): micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup forward_micro_batch_id = 0 - for i in range(micro_batch_in_warmup): - forward_job = core.Job("forward") + for _ in range(micro_batch_in_warmup): + forward_job = core.Job(FORWARD) forward_job.set_micro_batch_id(forward_micro_batch_id) job_list.append(forward_job) forward_micro_batch_id += 1 backward_micro_batch_id = 0 - for i in range(micro_batch_in_1f1b): - backward_job = core.Job("backward") + for _ in range(micro_batch_in_1f1b): + backward_job = core.Job(BACKWARD) backward_job.set_micro_batch_id(backward_micro_batch_id) job_list.append(backward_job) backward_micro_batch_id += 1 - forward_job = core.Job("forward") + forward_job = core.Job(FORWARD) forward_job.set_micro_batch_id(forward_micro_batch_id) job_list.append(forward_job) forward_micro_batch_id += 1 - for i in range(micro_batch_in_warmup): - backward_job = core.Job("backward") + for _ in range(micro_batch_in_warmup): + backward_job = core.Job(BACKWARD) backward_job.set_micro_batch_id(backward_micro_batch_id) job_list.append(backward_job) backward_micro_batch_id += 1 - opt_job = core.Job("optimizer") + opt_job = core.Job(OPT) job_list.append(opt_job) return job_list @@ -462,7 +453,7 @@ def _partial_programs(self, program): # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs. enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") # TODO: More function will be added later. Now it uses the same logic as FTthenB and 1F1B. - types = ["lr", "forward", "backward", "optimizer"] + types = [FORWARD, BACKWARD, OPT] sub_program_list = _program_for_fthenb_and_1f1b( program, enable_send_recv_overlap ) diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py index 970afae464030..ab82b2a48ad64 100644 --- a/python/paddle/distributed/spawn.py +++ b/python/paddle/distributed/spawn.py @@ -113,9 +113,7 @@ def _get_default_nprocs(): return core.get_custom_device_count(device.split(":")[0]) else: raise RuntimeError( - "`paddle.distributed.spawn` does not support parallel training on device `{}` now.".format( - device - ) + f"`paddle.distributed.spawn` does not support parallel training on device `{device}` now." ) @@ -131,9 +129,7 @@ def _get_default_backend(): return 'xccl' else: raise RuntimeError( - "`paddle.distributed.spawn` does not support parallel training on device `{}` now.".format( - device - ) + f"`paddle.distributed.spawn` does not support parallel training on device `{device}` now." ) diff --git a/python/paddle/distribution/bernoulli.py b/python/paddle/distribution/bernoulli.py index 7d4849fab48e7..152306aea31f7 100644 --- a/python/paddle/distribution/bernoulli.py +++ b/python/paddle/distribution/bernoulli.py @@ -212,6 +212,7 @@ def rsample(self, shape, temperature=1.0): .. code-block:: python >>> import paddle + >>> paddle.seed(1) >>> from paddle.distribution import Bernoulli >>> rv = Bernoulli(paddle.full((1), 0.3)) @@ -231,28 +232,26 @@ def rsample(self, shape, temperature=1.0): [100, 2, 2] >>> # `rsample` has to be followed by a `sigmoid` - >>> # doctest: +SKIP >>> rv = Bernoulli(0.3) >>> rsample = rv.rsample([3, ]) >>> rsample_sigmoid = paddle.nn.functional.sigmoid(rsample) - >>> print(rsample, rsample_sigmoid) - Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True, - [[-0.88315082], - [-0.62347704], - [-0.31513220]]) - Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.29252526], - [0.34899110], - [0.42186251]]) + >>> print(rsample) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [-1.46112013, -0.01239836, -1.32765460]) + >>> print(rsample_sigmoid) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.18829606, 0.49690047, 0.20954758]) >>> # The smaller the `temperature`, the distribution of `rsample` closer to `sample`, with `probs` of 0.3. >>> print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=1.0)).sum()) + >>> # doctest: +SKIP('output will be different') Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 361.06829834) + 365.63122559) + >>> # doctest: -SKIP >>> print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=0.1)).sum()) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 288.66418457) + 320.15057373) """ name = self.name + '_rsample' if not in_dynamic_mode(): diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index b6484e3f21d56..9d5664dc28f4d 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -64,14 +64,12 @@ class Categorical(distribution.Distribution): >>> cat = Categorical(x) >>> cat2 = Categorical(y) - >>> # doctest: +SKIP >>> paddle.seed(1000) # on CPU device >>> print(cat.sample([2,3])) Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, [[0, 1, 5], [3, 4, 5]]) - >>> # doctest: -SKIP >>> print(cat.entropy()) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 1.77528250) diff --git a/python/paddle/fft.py b/python/paddle/fft.py index 9600f2159abf6..dc38b60ac5db5 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -913,9 +913,7 @@ def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None): if axes is not None: if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".format( - axes - ) + f"Invalid FFT argument axes ({axes}), it should be a sequence of 2 integers." ) return fftn(x, s, axes, norm, name) @@ -981,9 +979,7 @@ def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None): if axes is not None: if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".format( - axes - ) + f"Invalid FFT argument axes ({axes}), it should be a sequence of 2 integers." ) return ifftn(x, s, axes, norm, name) @@ -1043,9 +1039,7 @@ def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): if axes is not None: if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".format( - axes - ) + f"Invalid FFT argument axes ({axes}), it should be a sequence of 2 integers." ) return rfftn(x, s, axes, norm, name) @@ -1097,9 +1091,7 @@ def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): if axes is not None: if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".format( - axes - ) + f"Invalid FFT argument axes ({axes}), it should be a sequence of 2 integers." ) return irfftn(x, s, axes, norm, name) @@ -1144,9 +1136,7 @@ def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): if axes is not None: if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".format( - axes - ) + f"Invalid FFT argument axes ({axes}), it should be a sequence of 2 integers." ) return hfftn(x, s, axes, norm, name) @@ -1205,9 +1195,7 @@ def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): if axes is not None: if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".format( - axes - ) + f"Invalid FFT argument axes ({axes}), it should be a sequence of 2 integers." ) return ihfftn(x, s, axes, norm, name) diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 82315aa72d7d9..ecddb82c9a375 100755 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -37,6 +37,7 @@ _apply_pass, _create_tensor, _current_expected_place, + _current_expected_place_, _dygraph_tracer, _get_paddle_place, _global_flags, diff --git a/python/paddle/framework/dtype.py b/python/paddle/framework/dtype.py index 57a3cb81d00fe..aec82ef7729b5 100644 --- a/python/paddle/framework/dtype.py +++ b/python/paddle/framework/dtype.py @@ -123,4 +123,10 @@ def finfo(dtype): float32 """ + import paddle + + if paddle.base.framework.in_pir_mode() and isinstance( + dtype, paddle.pir.core.DataType + ): + dtype = paddle.base.framework.paddle_type_to_proto_type[dtype] return core_finfo(dtype) diff --git a/python/paddle/geometric/math.py b/python/paddle/geometric/math.py index 4ba60f5d63a54..9281dc1232cb4 100644 --- a/python/paddle/geometric/math.py +++ b/python/paddle/geometric/math.py @@ -15,7 +15,7 @@ from paddle import _C_ops from paddle.base.data_feeder import check_variable_and_dtype from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode __all__ = [] @@ -52,7 +52,7 @@ def segment_sum(data, segment_ids, name=None): [4. 5. 6.]] """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "SUM") else: check_variable_and_dtype( @@ -111,7 +111,7 @@ def segment_mean(data, segment_ids, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "MEAN") else: check_variable_and_dtype( @@ -169,7 +169,7 @@ def segment_min(data, segment_ids, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "MIN") else: check_variable_and_dtype( @@ -227,7 +227,7 @@ def segment_max(data, segment_ids, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "MAX") else: check_variable_and_dtype( diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index e0ec592d7010b..ef9fcef37ace7 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -22,7 +22,7 @@ ) from paddle.base.framework import Variable from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode from .utils import ( convert_out_size_to_list, @@ -127,8 +127,8 @@ def send_u_recv( # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. - if in_dynamic_mode(): - out_size = convert_out_size_to_list(out_size) + if in_dynamic_or_pir_mode(): + out_size = convert_out_size_to_list(out_size, 'graph_send_recv') return _C_ops.send_u_recv( x, src_index, dst_index, reduce_op.upper(), out_size ) @@ -312,8 +312,8 @@ def send_ue_recv( # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. - if in_dynamic_mode(): - out_size = convert_out_size_to_list(out_size) + if in_dynamic_or_pir_mode(): + out_size = convert_out_size_to_list(out_size, 'graph_send_ue_recv') return _C_ops.send_ue_recv( x, y, @@ -472,7 +472,7 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): message_op = 'mul' y = 1.0 / (y + 1e-12) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.send_uv(x, y, src_index, dst_index, message_op.upper()) else: helper = LayerHelper("graph_send_uv", **locals()) diff --git a/python/paddle/geometric/message_passing/utils.py b/python/paddle/geometric/message_passing/utils.py index cbcd8478ab3aa..2566bc0f7074f 100644 --- a/python/paddle/geometric/message_passing/utils.py +++ b/python/paddle/geometric/message_passing/utils.py @@ -19,7 +19,7 @@ from paddle.base.framework import Variable -def convert_out_size_to_list(out_size): +def convert_out_size_to_list(out_size, op_type): """ Convert out_size(int, np.int32, np.int64, Variable) to list in imperative mode. @@ -28,6 +28,17 @@ def convert_out_size_to_list(out_size): out_size = [0] elif isinstance(out_size, (int, np.int32, np.int64)): out_size = [out_size] + elif isinstance(out_size, (Variable, paddle.pir.OpResult)): + out_size.stop_gradient = True + check_dtype( + out_size.dtype, + 'out_size', + ['int32', 'int64'], + 'op_type', + '(When type of out_size in' + op_type + ' is Variable.)', + ) + if convert_dtype(out_size.dtype) == 'int64': + out_size = paddle.cast(out_size, 'int32') else: out_size = [int(out_size)] return out_size diff --git a/python/paddle/hapi/callbacks.py b/python/paddle/hapi/callbacks.py index 1de70114db7e4..d2ed7238d52c4 100644 --- a/python/paddle/hapi/callbacks.py +++ b/python/paddle/hapi/callbacks.py @@ -1305,9 +1305,7 @@ def on_eval_end(self, logs=None): return except Exception as e: warnings.warn( - 'There are something wrong when get learning_rate from optimizer: {}.'.format( - e - ) + f'There are something wrong when get learning_rate from optimizer: {e}.' ) return diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index fcae6e4120ac8..ba211749a57be 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -85,7 +85,6 @@ def flops(net, input_size, custom_ops=None, print_detail=False): ... [1, 1, 28, 28], ... custom_ops= {nn.LeakyReLU: count_leaky_relu}, ... print_detail=True) - >>> # doctest: +SKIP >>> print(FLOPs) 's flops has been counted 's flops has been counted @@ -106,7 +105,6 @@ def flops(net, input_size, custom_ops=None, print_detail=False): +--------------+-----------------+-----------------+--------+--------+ Total Flops: 347560 Total Params: 61610 347560 - >>> # doctest: -SKIP """ if isinstance(net, nn.Layer): # If net is a dy2stat model, net.forward is StaticFunction instance, @@ -242,9 +240,7 @@ def add_hooks(m): else: if m_type not in types_collection: print( - "Cannot find suitable count function for {}. Treat it as zero FLOPs.".format( - m_type - ) + f"Cannot find suitable count function for {m_type}. Treat it as zero FLOPs." ) if flops_fn is not None: diff --git a/python/paddle/hapi/hub.py b/python/paddle/hapi/hub.py index a9118eb1c6cd0..c39fa57ad5681 100644 --- a/python/paddle/hapi/hub.py +++ b/python/paddle/hapi/hub.py @@ -227,7 +227,7 @@ def help(repo_dir, model, source='github', force_reload=False): - github path (str): A string with format "repo_owner/repo_name[:tag_name]" with an optional tag/branch. The default branch is `main` if not specified. - local path (str): Local repo path. + - local path (str): Local repo path. model (str): Model name. source (str): `github` | `gitee` | `local`. Default is `github`. diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 8ca5712a3036c..b33d239662726 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -2399,7 +2399,6 @@ def summary(self, input_size=None, dtype=None): >>> optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) >>> model.prepare(optim, paddle.nn.CrossEntropyLoss()) >>> params_info = model.summary() - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -2424,7 +2423,6 @@ def summary(self, input_size=None, dtype=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP """ assert ( @@ -2474,9 +2472,7 @@ def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False): assert isinstance(spec, Input) if spec.name is None: raise ValueError( - "Requires Input[{}].name != None, but receive `None` with {}.".format( - i, spec - ) + f"Requires Input[{i}].name != None, but receive `None` with {spec}." ) return out_specs diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index df5791a5fd70d..d893e342122ed 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -78,7 +78,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> lenet = LeNet() >>> params_info = paddle.summary(lenet, (1, 1, 28, 28)) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -103,7 +102,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP >>> # multi input demo >>> class LeNetMultiInput(LeNet): ... def forward(self, inputs, y): @@ -119,7 +117,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> params_info = paddle.summary(lenet_multi_input, ... [(1, 1, 28, 28), (1, 400)], ... dtypes=['float32', 'float32']) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -144,7 +141,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP >>> # list input demo >>> class LeNetListInput(LeNet): ... def forward(self, inputs): @@ -158,7 +154,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> lenet_list_input = LeNetListInput() >>> input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])] >>> params_info = paddle.summary(lenet_list_input, input=input_data) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -183,7 +178,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP >>> # dict input demo >>> class LeNetDictInput(LeNet): ... def forward(self, inputs): @@ -198,7 +192,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> input_data = {'x1': paddle.rand([1, 1, 28, 28]), ... 'x2': paddle.rand([1, 400])} >>> params_info = paddle.summary(lenet_dict_input, input=input_data) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -223,7 +216,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP """ if input_size is None and input is None: @@ -300,9 +292,7 @@ def _check_shape(shape): elif isinstance(item, numbers.Number): if item <= 0: raise ValueError( - "Expected element in input size greater than zero, but got {}".format( - item - ) + f"Expected element in input size greater than zero, but got {item}" ) new_shape.append(item) return tuple(new_shape) diff --git a/python/paddle/incubate/distributed/fleet/utils.py b/python/paddle/incubate/distributed/fleet/utils.py index 17d0a4e35e693..98945ca7092e0 100644 --- a/python/paddle/incubate/distributed/fleet/utils.py +++ b/python/paddle/incubate/distributed/fleet/utils.py @@ -430,9 +430,7 @@ def check_not_expected_ops(prog): for op in prog.global_block().ops: if op.type in not_expected_op_types and op.type not in op_types_set: logger.warning( - "find op type '{}' in program, please check if your program is pruned correctly !".format( - op.type - ) + f"find op type '{op.type}' in program, please check if your program is pruned correctly !" ) op_types_set.add(op.type) diff --git a/python/paddle/incubate/multiprocessing/__init__.py b/python/paddle/incubate/multiprocessing/__init__.py index 42c7bd7bcf75e..2498a04014d95 100644 --- a/python/paddle/incubate/multiprocessing/__init__.py +++ b/python/paddle/incubate/multiprocessing/__init__.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing # noqa: F401 - from .reductions import init_reductions __all__ = [] diff --git a/python/paddle/incubate/nn/functional/fused_dropout_add.py b/python/paddle/incubate/nn/functional/fused_dropout_add.py index d191f1682fdda..127cc91d54811 100644 --- a/python/paddle/incubate/nn/functional/fused_dropout_add.py +++ b/python/paddle/incubate/nn/functional/fused_dropout_add.py @@ -16,7 +16,7 @@ from paddle import _C_ops from paddle.base import core from paddle.common_ops_import import default_main_program -from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.framework import LayerHelper, in_dynamic_or_pir_mode def fused_dropout_add( @@ -84,7 +84,7 @@ def fused_dropout_add( "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" ) seed = None - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if default_main_program().random_seed != 0: seed = default_main_program().random_seed out, seed_offset = _C_ops.fused_dropout_add( diff --git a/python/paddle/incubate/nn/functional/fused_layer_norm.py b/python/paddle/incubate/nn/functional/fused_layer_norm.py index c649171052ee0..444ee9f695f6d 100644 --- a/python/paddle/incubate/nn/functional/fused_layer_norm.py +++ b/python/paddle/incubate/nn/functional/fused_layer_norm.py @@ -58,14 +58,15 @@ def fused_layer_norm( Examples: .. code-block:: python - # required: gpu - import paddle - - paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16) - paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32) - paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32) - epsilon = 1e-6 - paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + + >>> paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16) + >>> paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32) + >>> paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32) + >>> epsilon = 1e-6 + >>> paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) """ if in_dynamic_mode(): diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm.py b/python/paddle/incubate/nn/functional/fused_rms_norm.py index 3995cd4a4087d..9a95d99b178a7 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm.py @@ -15,7 +15,7 @@ import paddle from paddle import _C_ops -from paddle.framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode def fused_rms_norm( @@ -54,16 +54,17 @@ def fused_rms_norm( Examples: .. code-block:: python - # required: gpu - import paddle + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') - paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16) - paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) - paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) - epsilon = 1e-6 - paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) + >>> paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16) + >>> paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) + >>> paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) + >>> epsilon = 1e-6 + >>> paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) """ - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): return _C_ops.rms_norm( x, bias, @@ -77,7 +78,21 @@ def fused_rms_norm( quant_max_bound, quant_min_bound, ) - + if in_pir_mode(): + out, residual_out = _C_ops.rms_norm( + x, + bias, + residual, + norm_weight, + norm_bias, + epsilon, + begin_norm_axis, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + return (out, residual_out) if residual is not None else out helper = LayerHelper('rms_norm', **locals()) out = None if quant_scale <= 0: diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 355b5916b5ddb..52dc1b92d1580 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _legacy_C_ops +import paddle +from paddle import _C_ops, _legacy_C_ops from paddle.base import core from paddle.base.data_feeder import check_dtype, check_variable_and_dtype from paddle.base.framework import default_main_program from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode __all__ = [] @@ -56,20 +57,20 @@ def fused_feedforward( This operator only supports running on GPU. The function of the operator is consistent with the following pseudo code: - .. code-block:: python + .. code-block:: text - residual = x - if pre_layer_norm: - out = layer_norm1(x) - else: - out = x - out = linear2(dropout1(activation(linear1(src)))) - if add_residual: - out = residual + dropout2(out) - else: - out = dropout2(out) - if not pre_layer_norm: - out = layer_norm2(out) + >>> residual = x + >>> if pre_layer_norm: + ... out = layer_norm1(x) + ... else: + ... out = x + >>> out = linear2(dropout1(activation(linear1(src)))) + >>> if add_residual: + ... out = residual + dropout2(out) + ... else: + ... out = dropout2(out) + >>> if not pre_layer_norm: + ... out = layer_norm2(out) Args: @@ -110,16 +111,17 @@ def fused_feedforward( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - x = paddle.randn(shape=(1, 8, 8), dtype="float32") - linear1_weight = paddle.randn(shape=(8, 8), dtype="float32") - linear2_weight = paddle.randn(shape=(8, 8), dtype="float32") - out = F.fused_feedforward(x, linear1_weight, linear2_weight) - print(out.shape) - # (1, 8, 8) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> x = paddle.randn(shape=(1, 8, 8), dtype="float32") + >>> linear1_weight = paddle.randn(shape=(8, 8), dtype="float32") + >>> linear2_weight = paddle.randn(shape=(8, 8), dtype="float32") + >>> out = F.fused_feedforward(x, linear1_weight, linear2_weight) + >>> print(out.shape) + [1, 8, 8] """ _verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout2_rate) @@ -133,52 +135,95 @@ def fused_feedforward( 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode ) # semantic transfer - if in_dynamic_mode(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed - out, _, _, _, _, _, _, _, _, _, _ = _legacy_C_ops.fused_feedforward( - x, - None, - None, - linear1_weight, - linear1_bias, - linear2_weight, - linear2_bias, - ln1_scale, - ln1_bias, - ln2_scale, - ln2_bias, - 'pre_layer_norm', - pre_layer_norm, - 'ln1_epsilon', - ln1_epsilon, - 'ln2_epsilon', - ln2_epsilon, - 'act_method', - activation, - 'dropout1_rate', - dropout1_rate, - 'dropout2_rate', - dropout2_rate, - "is_test", - not training, - "dropout1_fix_seed", - seed is not None, - "dropout2_fix_seed", - seed is not None, - "dropout1_seed", - seed if seed is not None else 0, - "dropout2_seed", - seed if seed is not None else 0, - 'dropout1_implementation', - mode, - 'dropout2_implementation', - mode, - 'add_residual', - add_residual, - 'ring_id', - ring_id, - ) + if in_dynamic_or_pir_mode(): + if paddle.static.default_main_program().random_seed != 0: + seed = paddle.static.default_main_program().random_seed + + if in_dynamic_mode(): + out, _, _, _, _, _, _, _, _, _, _ = _legacy_C_ops.fused_feedforward( + x, + None, + None, + linear1_weight, + linear1_bias, + linear2_weight, + linear2_bias, + ln1_scale, + ln1_bias, + ln2_scale, + ln2_bias, + 'pre_layer_norm', + pre_layer_norm, + 'ln1_epsilon', + ln1_epsilon, + 'ln2_epsilon', + ln2_epsilon, + 'act_method', + activation, + 'dropout1_rate', + dropout1_rate, + 'dropout2_rate', + dropout2_rate, + "is_test", + not training, + "dropout1_fix_seed", + seed is not None, + "dropout2_fix_seed", + seed is not None, + "dropout1_seed", + seed if seed is not None else 0, + "dropout2_seed", + seed if seed is not None else 0, + 'dropout1_implementation', + mode, + 'dropout2_implementation', + mode, + 'add_residual', + add_residual, + 'ring_id', + ring_id, + ) + else: + dtype = x.dtype + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64'], 'fused_feedforward' + ) + check_dtype( + dtype, + 'dtype', + ['float16', 'float32', 'float64'], + 'fused_feedforward', + ) + + out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward( + x, + None, + None, + linear1_weight, + linear1_bias, + linear2_weight, + linear2_bias, + ln1_scale, + ln1_bias, + ln2_scale, + ln2_bias, + pre_layer_norm, + ln1_epsilon, + ln2_epsilon, + activation, + dropout1_rate, + dropout2_rate, + mode, + mode, + not training, + seed is not None, + seed is not None, + seed if seed is not None else 0, + seed if seed is not None else 0, + add_residual, + ring_id, + ) + return out helper = LayerHelper("fused_feedforward") @@ -288,9 +333,9 @@ def fused_bias_dropout_residual_layer_norm( The fused_bias_dropout_residual_layer_norm operator. The pseudo code is as follows: - .. code-block:: python + .. code-block:: text - y = layer_norm(residual + dropout(bias + x)) + >>> y = layer_norm(residual + dropout(bias + x)) Parameters: x (Tensor): The input tensor. The shape is `[*, embed\_dim]`. @@ -323,21 +368,22 @@ def fused_bias_dropout_residual_layer_norm( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - # input: [batch_size, seq_len, embed_dim] - x = paddle.rand(shape=(2, 4, 128), dtype="float32") - # residual: [batch_size, seq_len, embed_dim] - residual = paddle.rand(shape=(2, 4, 128), dtype="float32") - # linear bias: [embed_dim] - bias = paddle.rand(shape=[128], dtype="float32") - # output: [batch_size, seq_len, embed_dim] - output = F.fused_bias_dropout_residual_layer_norm( - x, residual, bias) - # [2, 4, 128] - print(output.shape) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> # input: [batch_size, seq_len, embed_dim] + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> # residual: [batch_size, seq_len, embed_dim] + >>> residual = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> # linear bias: [embed_dim] + >>> bias = paddle.rand(shape=[128], dtype="float32") + >>> # output: [batch_size, seq_len, embed_dim] + >>> output = F.fused_bias_dropout_residual_layer_norm( + ... x, residual, bias) + >>> print(output.shape) + [2, 4, 128] """ seed = None @@ -493,35 +539,35 @@ def fused_multi_head_attention( to information from different representation subspaces. This API only support self_attention. The pseudo code is as follows: - .. code-block:: python - - residual = x - if pre_layer_norm: - out = layer_norm(x) - else: - out = x - # compute q, k, v - out = matmul(out, qkv_weight) + qkv_bias - out = transpose(out, perm=[2, 0, 3, 1, 4]) - # extract q, k and v from out - q = out[0:1,::] * (head_dim ** -0.5) - k = out[1:2,::] - v = out[2:3,::] - out = matmul(q, k, transpose_y=True) - out = out + attn_mask - out = softmax(out) - out = dropout(out) - out = matmul(out, v) - # combine heads - out = transpose(out, perm=[0, 2, 1, 3]) - # project to output - out = linear(out) - if add_residual: - out = residual + dropout(out) - else: - out = dropout(out) - if not pre_layer_norm: - out = layer_norm(out) + .. code-block:: text + + >>> residual = x + >>> if pre_layer_norm: + ... out = layer_norm(x) + ... else: + ... out = x + >>> # compute q, k, v + >>> out = matmul(out, qkv_weight) + qkv_bias + >>> out = transpose(out, perm=[2, 0, 3, 1, 4]) + >>> # extract q, k and v from out + >>> q = out[0:1,::] * (head_dim ** -0.5) + >>> k = out[1:2,::] + >>> v = out[2:3,::] + >>> out = matmul(q, k, transpose_y=True) + >>> out = out + attn_mask + >>> out = softmax(out) + >>> out = dropout(out) + >>> out = matmul(out, v) + >>> # combine heads + >>> out = transpose(out, perm=[0, 2, 1, 3]) + >>> # project to output + >>> out = linear(out) + >>> if add_residual: + ... out = residual + dropout(out) + ... else: + ... out = dropout(out) + >>> if not pre_layer_norm: + ... out = layer_norm(out) Parameters: @@ -581,30 +627,31 @@ def fused_multi_head_attention( .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - # input: [batch_size, seq_len, embed_dim] - x = paddle.rand(shape=(2, 4, 128), dtype="float32") - # qkv_weight: [3, num_head, head_dim, embed_dim] - qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") - # qkv_bias: [3, num_head, head_dim] - qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") - # linear_weight: [embed_dim, embed_dim] - linear_weight = paddle.rand(shape=(128, 128), dtype="float32") - # linear_bias: [embed_dim] - linear_bias = paddle.rand(shape=[128], dtype="float32") - # self attention mask: [batch_size, num_heads, seq_len, seq_len] - attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") - - # output: [batch_size, seq_len, embed_dim] - output = F.fused_multi_head_attention( - x, qkv_weight, linear_weight, False, - None, None, None, None, 1e-5, qkv_bias, - linear_bias, None, attn_mask) - # [2, 4, 128] - print(output.shape) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> # input: [batch_size, seq_len, embed_dim] + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> # qkv_weight: [3, num_head, head_dim, embed_dim] + >>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + >>> # qkv_bias: [3, num_head, head_dim] + >>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + >>> # linear_weight: [embed_dim, embed_dim] + >>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + >>> # linear_bias: [embed_dim] + >>> linear_bias = paddle.rand(shape=[128], dtype="float32") + >>> # self attention mask: [batch_size, num_heads, seq_len, seq_len] + >>> attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") + + >>> # output: [batch_size, seq_len, embed_dim] + >>> output = F.fused_multi_head_attention( + ... x, qkv_weight, linear_weight, False, + ... None, None, None, None, 1e-5, qkv_bias, + ... linear_bias, None, attn_mask) + >>> print(output.shape) + [2, 4, 128] """ seed = None @@ -621,9 +668,9 @@ def fused_multi_head_attention( f"The rank of the x should be 3, but received {x.ndim}." ) - if in_dynamic_mode(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed + if in_dynamic_or_pir_mode(): + if paddle.static.default_main_program().random_seed != 0: + seed = paddle.static.default_main_program().random_seed # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out @@ -669,72 +716,125 @@ def fused_multi_head_attention( "When enable transpose_qkv_wb, the 1st dim of qkv_bias and 2nd dim of " "qkv_weight should be the same, i.e., embed_dim." ) - ( - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - cache_kv_out, - final_out, - ) = _legacy_C_ops.fused_attention( - x, - pre_ln_scale, - pre_ln_bias, - qkv_weight, - qkv_bias, - cache_kv, - attn_mask, - linear_weight, - linear_bias, - ln_scale, - ln_bias, - 'num_heads', - num_heads, - 'transpose_qkv_wb', - transpose_qkv_wb, - 'pre_layer_norm', - pre_layer_norm, - 'epsilon', - pre_ln_epsilon, - 'dropout_rate', - dropout_rate, - 'attn_dropout_rate', - attn_dropout_rate, - 'ln_epsilon', - ln_epsilon, - 'is_test', - not training, - 'attn_dropout_fix_seed', - seed is not None, - 'dropout_fix_seed', - seed is not None, - 'attn_dropout_seed', - seed if seed is not None else 0, - 'dropout_seed', - seed if seed is not None else 0, - 'attn_dropout_implementation', - mode, - 'dropout_implementation', - mode, - 'add_residual', - add_residual, - 'ring_id', - ring_id, - ) + if in_dynamic_mode(): + ( + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + cache_kv_out, + final_out, + ) = _legacy_C_ops.fused_attention( + x, + pre_ln_scale, + pre_ln_bias, + qkv_weight, + qkv_bias, + cache_kv, + attn_mask, + linear_weight, + linear_bias, + ln_scale, + ln_bias, + 'num_heads', + num_heads, + 'transpose_qkv_wb', + transpose_qkv_wb, + 'pre_layer_norm', + pre_layer_norm, + 'epsilon', + pre_ln_epsilon, + 'dropout_rate', + dropout_rate, + 'attn_dropout_rate', + attn_dropout_rate, + 'ln_epsilon', + ln_epsilon, + 'is_test', + not training, + 'attn_dropout_fix_seed', + seed is not None, + 'dropout_fix_seed', + seed is not None, + 'attn_dropout_seed', + seed if seed is not None else 0, + 'dropout_seed', + seed if seed is not None else 0, + 'attn_dropout_implementation', + mode, + 'dropout_implementation', + mode, + 'add_residual', + add_residual, + 'ring_id', + ring_id, + ) + else: + ( + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + cache_kv_out, + final_out, + ) = _C_ops.fused_attention( + x, + pre_ln_scale, + pre_ln_bias, + qkv_weight, + qkv_bias, + cache_kv, + attn_mask, + linear_weight, + linear_bias, + ln_scale, + ln_bias, + num_heads, + transpose_qkv_wb, + pre_layer_norm, + pre_ln_epsilon, + attn_dropout_rate, + not training, + seed is not None, + seed if seed is not None else 0, + mode, + dropout_rate, + seed is not None, + seed if seed is not None else 0, + mode, + ln_epsilon, + add_residual, + ring_id, + ) + if cache_kv is not None: return final_out, cache_kv_out return final_out @@ -906,39 +1006,39 @@ def fused_multi_transformer( This operator only supports running on GPU. The function of the transformer layer is consistent with the following pseudo code: - .. code-block:: python - - if pre_layer_norm: - out = layer_norm(x) - out = qkv_linear(out) + qkv_bias - else: - out = qkv_linear(x) + qkv_bias - out = transpose(out, perm=[2, 0, 3, 1, 4]) - # extract q, k and v from out. - q = out[0:1, ::] - k = out[1:2, ::] - v = out[2:3, ::] - out = q * k^t - out = attn_mask + out - out = softmax(out) - out = dropout(out) - out = out * v - out = transpose(out, perm=[0, 2, 1, 3]) - out = linear(out) - if pre_layer_norm: - out = x + dropout(out + bias) - else: - out = layer_norm(x + dropout(out + bias)) - - residual = out; - if pre_layer_norm: - out = ffn_layer_norm(out) - out = ffn1_linear(out) - out = dropout(activation(out + ffn1_bias)) - out = ffn2_linear(out) - out = residual + dropout(out + ffn2_bias) - if not pre_layer_norm: - out = ffn_layer_norm(out) + .. code-block:: text + + >>> if pre_layer_norm: + ... out = layer_norm(x) + ... out = qkv_linear(out) + qkv_bias + ... else: + ... out = qkv_linear(x) + qkv_bias + >>> out = transpose(out, perm=[2, 0, 3, 1, 4]) + >>> # extract q, k and v from out. + >>> q = out[0:1, ::] + >>> k = out[1:2, ::] + >>> v = out[2:3, ::] + >>> out = q * k^t + >>> out = attn_mask + out + >>> out = softmax(out) + >>> out = dropout(out) + >>> out = out * v + >>> out = transpose(out, perm=[0, 2, 1, 3]) + >>> out = linear(out) + >>> if pre_layer_norm: + ... out = x + dropout(out + bias) + ... else: + ... out = layer_norm(x + dropout(out + bias)) + + >>> residual = out; + >>> if pre_layer_norm: + ... out = ffn_layer_norm(out) + >>> out = ffn1_linear(out) + >>> out = dropout(activation(out + ffn1_bias)) + >>> out = ffn2_linear(out) + >>> out = residual + dropout(out + ffn2_bias) + >>> if not pre_layer_norm: + ... out = ffn_layer_norm(out) Args: x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`. @@ -996,48 +1096,49 @@ def fused_multi_transformer( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - # input: [batch_size, seq_len, embed_dim] - x = paddle.rand(shape=(2, 4, 128), dtype="float32") - - # ln_scale: [embed_dim], ln_bias: [embed_dim] - ln_scale = paddle.rand(shape=(128,), dtype="float32") - ln_bias = paddle.rand(shape=(128,), dtype="float32") - - # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim] - qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") - qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") - - # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim] - linear_weight = paddle.rand(shape=(128, 128), dtype="float32") - linear_bias = paddle.rand(shape=(128,), dtype="float32") - - # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim] - ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32") - ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32") - - # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim] - ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32") - ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32") - - # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim] - ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32") - ffn2_bias = paddle.rand(shape=(128,), dtype="float32") - - # self attention mask: [batch_size, 1, seq_len, seq_len] - attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32") - - # output: [batch_size, seq_len, embed_dim] - output = F.fused_multi_transformer( - x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias], - [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias], - [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias], - attn_mask=attn_mask) - # [2, 4, 128] - print(output.shape) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> # input: [batch_size, seq_len, embed_dim] + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + + >>> # ln_scale: [embed_dim], ln_bias: [embed_dim] + >>> ln_scale = paddle.rand(shape=(128,), dtype="float32") + >>> ln_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim] + >>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + >>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + + >>> # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim] + >>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + >>> linear_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim] + >>> ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32") + >>> ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim] + >>> ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32") + >>> ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32") + + >>> # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim] + >>> ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32") + >>> ffn2_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # self attention mask: [batch_size, 1, seq_len, seq_len] + >>> attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32") + + >>> # output: [batch_size, seq_len, embed_dim] + >>> output = F.fused_multi_transformer( + ... x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias], + ... [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias], + ... [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias], + ... attn_mask=attn_mask) + >>> print(output.shape) + [2, 4, 128] """ if mode not in ('downscale_in_infer', 'upscale_in_train'): raise ValueError( diff --git a/python/paddle/incubate/nn/functional/masked_multihead_attention.py b/python/paddle/incubate/nn/functional/masked_multihead_attention.py index 93b9b1419855b..9b1f3d464ab48 100644 --- a/python/paddle/incubate/nn/functional/masked_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/masked_multihead_attention.py @@ -71,21 +71,22 @@ def masked_multihead_attention( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> import paddle.incubate.nn.functional as F + >>> paddle.device.set_device('gpu') - # input: [batch_size, 3 * num_head * dim_head] - x = paddle.rand(shape=(2, 3 * 32 * 128), dtype="float32") + >>> # input: [batch_size, 3 * num_head * dim_head] + >>> x = paddle.rand(shape=(2, 3 * 32 * 128), dtype="float32") - # src_mask: [batch_size, 1, 1, sequence_length] - src_mask = paddle.rand(shape=(2, 1, 1, 10), dtype="float32") + >>> # src_mask: [batch_size, 1, 1, sequence_length] + >>> src_mask = paddle.rand(shape=(2, 1, 1, 10), dtype="float32") - # cache_kv: [2, batch_size, num_head, max_seq_len, dim_head] - cache_kv = paddle.rand(shape=(2, 2, 32, 64, 128), dtype="float32") + >>> # cache_kv: [2, batch_size, num_head, max_seq_len, dim_head] + >>> cache_kv = paddle.rand(shape=(2, 2, 32, 64, 128), dtype="float32") - output = F.masked_multihead_attention( - x, src_mask=src_mask, cache_kv=cache_kv) + >>> output = F.masked_multihead_attention( + ... x, src_mask=src_mask, cache_kv=cache_kv) """ diff --git a/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py b/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py index 06f9772628abd..9643600250f65 100644 --- a/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py +++ b/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py @@ -54,38 +54,40 @@ def variable_length_memory_efficient_attention( Examples: .. code-block:: python - # required: gpu - import math - import paddle - from paddle.incubate.nn.functional import variable_length_memory_efficient_attention - - batch = 1 - num_head = 8 - seq_len = 256 - head_size = 32 - - dtype = paddle.float16 - - query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) - key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) - value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) - seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32') - mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype) - - scale = float(1.0 / math.sqrt(head_size)) - - def naive_attention_impl(query, key, value, mask, scale): - qk_res = paddle.matmul(query, key, transpose_y=True) - attention = qk_res * scale - attention = attention + mask - softmax_result = paddle.nn.functional.softmax(attention, -1) - result = paddle.matmul(softmax_result, value) - return result - - out = naive_attention_impl(query, key, value, mask, scale) - # equals to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale) - - print(out.shape) # [batch, seq_len, num_head, head_size] + >>> # doctest: +REQUIRES(env:GPU) + >>> import math + >>> import paddle + >>> from paddle.incubate.nn.functional import variable_length_memory_efficient_attention + >>> paddle.device.set_device('gpu') + + >>> batch = 1 + >>> num_head = 8 + >>> seq_len = 256 + >>> head_size = 32 + + >>> dtype = paddle.float16 + + >>> query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) + >>> key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) + >>> value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) + >>> seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32') + >>> mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype) + + >>> scale = float(1.0 / math.sqrt(head_size)) + + >>> def naive_attention_impl(query, key, value, mask, scale): + ... qk_res = paddle.matmul(query, key, transpose_y=True) + ... attention = qk_res * scale + ... attention = attention + mask + ... softmax_result = paddle.nn.functional.softmax(attention, -1) + ... result = paddle.matmul(softmax_result, value) + ... return result + + >>> out = naive_attention_impl(query, key, value, mask, scale) + >>> # equals to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale) + + >>> print(out.shape) # [batch, seq_len, num_head, head_size] + [1, 8, 256, 32] """ if scale is None: head_size = query.shape[3] diff --git a/python/paddle/incubate/nn/layer/fused_dropout_nd.py b/python/paddle/incubate/nn/layer/fused_dropout_nd.py index ded171158fe3d..09f083da88c74 100644 --- a/python/paddle/incubate/nn/layer/fused_dropout_nd.py +++ b/python/paddle/incubate/nn/layer/fused_dropout_nd.py @@ -54,6 +54,7 @@ class FusedDropout(paddle.nn.Layer): .. code-block:: python >>> import paddle + >>> paddle.seed(2023) >>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32") >>> m = paddle.incubate.nn.FusedDropout(p=0.5) @@ -61,15 +62,15 @@ class FusedDropout(paddle.nn.Layer): >>> y_train = m(x) >>> print(y_train) Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, - [[2., 0., 6.], - [0., 0., 0.]]) + [[0., 0., 6.], + [0., 0., 0.]]) >>> m.eval() # switch the model to test phase >>> y_test = m(x) >>> print(y_test) Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, - [[1., 2., 3.], - [4., 5., 6.]]) + [[1., 2., 3.], + [4., 5., 6.]]) """ def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None): diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 871d3a28b5ca7..1feccad46e40d 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -14,17 +14,19 @@ import numpy as np -import paddle from paddle import _C_ops from paddle.base.data_feeder import ( check_dtype, check_type, check_variable_and_dtype, - convert_dtype, ) from paddle.base.framework import Variable from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode +from paddle.geometric.message_passing.utils import ( + convert_out_size_to_list, + get_out_size_tensor_inputs, +) from paddle.utils import deprecated @@ -134,89 +136,55 @@ def graph_send_recv( ) # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. - - if in_dynamic_mode(): - out_size = convert_out_size_to_list(out_size) + if in_dynamic_or_pir_mode(): + out_size = convert_out_size_to_list(out_size, 'graph_send_recv') return _C_ops.send_u_recv( x, src_index, dst_index, pool_type.upper(), out_size ) - - check_variable_and_dtype( - x, "X", ("float32", "float64", "int32", "int64"), "graph_send_recv" - ) - check_variable_and_dtype( - src_index, "Src_index", ("int32", "int64"), "graph_send_recv" - ) - check_variable_and_dtype( - dst_index, "Dst_index", ("int32", "int64"), "graph_send_recv" - ) - if out_size: - check_type( - out_size, - 'out_size', - (int, np.int32, np.int64, Variable), - 'graph_send_recv', + else: + check_variable_and_dtype( + x, "X", ("float32", "float64", "int32", "int64"), "graph_send_recv" ) - if isinstance(out_size, Variable): - check_dtype( - out_size.dtype, 'out_size', ['int32', 'int64'], 'graph_send_recv' + check_variable_and_dtype( + src_index, "Src_index", ("int32", "int64"), "graph_send_recv" + ) + check_variable_and_dtype( + dst_index, "Dst_index", ("int32", "int64"), "graph_send_recv" + ) + if out_size: + check_type( + out_size, + 'out_size', + (int, np.int32, np.int64, Variable), + 'graph_send_recv', + ) + if isinstance(out_size, Variable): + check_dtype( + out_size.dtype, + 'out_size', + ['int32', 'int64'], + 'graph_send_recv', + ) + + helper = LayerHelper("graph_send_recv", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + dst_count = helper.create_variable_for_type_inference( + dtype="int32", stop_gradient=True ) - helper = LayerHelper("graph_send_recv", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - dst_count = helper.create_variable_for_type_inference( - dtype="int32", stop_gradient=True - ) - - inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} - attrs = {"reduce_op": pool_type.upper()} - get_out_size_tensor_inputs( - inputs=inputs, attrs=attrs, out_size=out_size, op_type='graph_send_recv' - ) - - helper.append_op( - type="graph_send_recv", - inputs=inputs, - outputs={"Out": out, "Dst_count": dst_count}, - attrs=attrs, - ) - return out - - -def convert_out_size_to_list(out_size): - """ - Convert out_size(int, np.int32, np.int64, Variable) to list - in imperative mode. - """ - if out_size is None: - out_size = [0] - elif isinstance(out_size, (int, np.int32, np.int64)): - out_size = [out_size] - else: - out_size = [int(out_size)] - return out_size - + inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} + attrs = {"reduce_op": pool_type.upper()} + get_out_size_tensor_inputs( + inputs=inputs, + attrs=attrs, + out_size=out_size, + op_type='graph_send_recv', + ) -def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): - """ - Convert out_size(int, np.int32, np.int64, Variable) to inputs - and attrs in static graph mode. - """ - if out_size is None: - attrs['out_size'] = [0] - elif isinstance(out_size, (int, np.int32, np.int64)): - attrs['out_size'] = [out_size] - elif isinstance(out_size, Variable): - out_size.stop_gradient = True - check_dtype( - out_size.dtype, - 'out_size', - ['int32', 'int64'], - op_type, - '(When type of out_size in' + op_type + ' is Variable.)', + helper.append_op( + type="graph_send_recv", + inputs=inputs, + outputs={"Out": out, "Dst_count": dst_count}, + attrs=attrs, ) - if convert_dtype(out_size.dtype) == 'int64': - out_size = paddle.cast(out_size, 'int32') - inputs["Out_size"] = out_size - else: - raise TypeError("Out_size only supports Variable or int.") + return out diff --git a/python/paddle/incubate/optimizer/gradient_merge.py b/python/paddle/incubate/optimizer/gradient_merge.py index 022e4dc8fbb7b..3cd17992ef5e8 100644 --- a/python/paddle/incubate/optimizer/gradient_merge.py +++ b/python/paddle/incubate/optimizer/gradient_merge.py @@ -50,40 +50,40 @@ class GradientMergeOptimizer: Examples: .. code-block:: python - import paddle - import paddle.base as base - import numpy as np - - def gen_data(batch_size): - return {"x": np.random.random(size=(batch_size, 32)).astype('float32'), - "y": np.random.random(size=(batch_size, 1)).astype('int64')} - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.GradientMergeOptimizer(sgd, k_steps=4, avg=True) - sgd.minimize(cost) - - place = base.CPUPlace() - exe = base.Executor(place) - exe.run(base.default_startup_program()) - - for i in range(10): - cost_val = exe.run(feed=gen_data(32), - program=base.default_main_program(), - fetch_list=[cost.name]) - print("step=%d, cost=%f" % (i, cost_val[0])) + >>> import paddle + >>> import numpy as np + >>> paddle.enable_static() + + >>> def gen_data(batch_size): + ... return {"x": np.random.random(size=(batch_size, 32)).astype('float32'), + ... "y": np.random.random(size=(batch_size, 1)).astype('int64')} + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.GradientMergeOptimizer(sgd, k_steps=4, avg=True) + >>> sgd.minimize(cost) + + >>> place = paddle.CPUPlace() + >>> exe = paddle.static.Executor(place) + >>> exe.run(paddle.static.default_startup_program()) + + >>> for i in range(10): + ... cost_val = exe.run(feed=gen_data(32), + ... program=paddle.static.default_main_program(), + ... fetch_list=[cost.name]) + ... print("step=%d, cost=%f" % (i, cost_val[0])) """ GRAD_MERGE_COND_NAME = "grad_merge_cond_name" diff --git a/python/paddle/incubate/optimizer/lars_momentum.py b/python/paddle/incubate/optimizer/lars_momentum.py index 261fb8038f193..8f3f1e4e2e64b 100644 --- a/python/paddle/incubate/optimizer/lars_momentum.py +++ b/python/paddle/incubate/optimizer/lars_momentum.py @@ -63,24 +63,23 @@ class LarsMomentumOptimizer(Optimizer): Examples: .. code-block:: python - import paddle - import paddle.base as base - import numpy as np - - paddle.enable_static() - np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - inp = paddle.static.data( - name="inp", shape=[2, 2], dtype='float32') - out = paddle.static.nn.fc(inp, size=3) - out = paddle.sum(out) - optimizer = base.optimizer.LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9) - optimizer.minimize(out) - - exe = base.Executor(base.CPUPlace()) - exe.run(base.default_startup_program()) - exe.run( - feed={"inp": np_inp}, - fetch_list=[out.name]) + >>> import paddle + >>> import numpy as np + + >>> paddle.enable_static() + >>> np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + >>> inp = paddle.static.data( + ... name="inp", shape=[2, 2], dtype='float32') + >>> out = paddle.static.nn.fc(inp, size=3) + >>> out = paddle.sum(out) + >>> optimizer = paddle.incubate.optimizer.LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9) + >>> optimizer.minimize(out) + + >>> exe = paddle.static.Executor(paddle.CPUPlace()) + >>> exe.run(paddle.static.default_startup_program()) + >>> exe.run( + ... feed={"inp": np_inp}, + ... fetch_list=[out.name]) """ _velocity_acc_str = "velocity" diff --git a/python/paddle/incubate/optimizer/pipeline.py b/python/paddle/incubate/optimizer/pipeline.py index 6c0e80b1f5710..ed2263625b28f 100644 --- a/python/paddle/incubate/optimizer/pipeline.py +++ b/python/paddle/incubate/optimizer/pipeline.py @@ -48,47 +48,47 @@ class PipelineOptimizer: Examples: .. code-block:: python - import paddle - import paddle.base as base - import paddle.base.layers as layers - import numpy as np - - paddle.enable_static() - with base.device_guard("gpu:0"): - x = paddle.static.data(name='x', shape=[-1, 1], dtype='int64', lod_level=0) - y = paddle.static.data(name='y', shape=[-1, 1], dtype='int64', lod_level=0) - data_loader = base.io.DataLoader.from_generator( - feed_list=[x, y], - capacity=64, - use_double_buffer=True, - iterable=False) - - emb_x = layers.embedding(input=x, param_attr=base.ParamAttr(name="embx"), size=[10,2], is_sparse=False) - emb_y = layers.embedding(input=y, param_attr=base.ParamAttr(name="emby",learning_rate=0.9), size=[10,2], is_sparse=False) - - with base.device_guard("gpu:1"): - concat = layers.concat([emb_x, emb_y], axis=1) - fc = paddle.static.nn.fc(x=concat, name="fc", size=1, num_flatten_dims=1, bias_attr=False) - loss = paddle.mean(fc) - optimizer = paddle.optimizer.SGD(learning_rate=0.5) - optimizer = paddle.incubate.optimizer.PipelineOptimizer(optimizer) - optimizer.minimize(loss) - - def train_reader(): - for _ in range(4): - x = np.random.random(size=[1]).astype('int64') - y = np.random.random(size=[1]).astype('int64') - yield x, y - data_loader.set_sample_generator(train_reader, batch_size=1) - - place = base.CUDAPlace(0) - exe = base.Executor(place) - exe.run(base.default_startup_program()) - batch_size = 1 - data_loader.start() - exe.train_from_dataset( - base.default_main_program()) - data_loader.reset() + >>> import paddle + >>> import paddle.base as base + >>> import paddle.base.layers as layers + >>> import numpy as np + + >>> paddle.enable_static() + >>> with base.device_guard("gpu:0"): + ... x = paddle.static.data(name='x', shape=[-1, 1], dtype='int64', lod_level=0) + ... y = paddle.static.data(name='y', shape=[-1, 1], dtype='int64', lod_level=0) + ... data_loader = base.io.DataLoader.from_generator( + ... feed_list=[x, y], + ... capacity=64, + ... use_double_buffer=True, + ... iterable=False) + + ... emb_x = layers.embedding(input=x, param_attr=base.ParamAttr(name="embx"), size=[10,2], is_sparse=False) + ... emb_y = layers.embedding(input=y, param_attr=base.ParamAttr(name="emby",learning_rate=0.9), size=[10,2], is_sparse=False) + + >>> with base.device_guard("gpu:1"): + ... concat = layers.concat([emb_x, emb_y], axis=1) + ... fc = paddle.static.nn.fc(x=concat, name="fc", size=1, num_flatten_dims=1, bias_attr=False) + ... loss = paddle.mean(fc) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.5) + >>> optimizer = paddle.incubate.optimizer.PipelineOptimizer(optimizer) + >>> optimizer.minimize(loss) + + >>> def train_reader(): + ... for _ in range(4): + ... x = np.random.random(size=[1]).astype('int64') + ... y = np.random.random(size=[1]).astype('int64') + ... yield x, y + >>> data_loader.set_sample_generator(train_reader, batch_size=1) + + >>> place = paddle.CUDAPlace(0) + >>> exe = paddle.static.Executor(place) + >>> exe.run(paddle.static.default_startup_program()) + >>> batch_size = 1 + >>> data_loader.start() + >>> exe.train_from_dataset( + ... paddle.static.default_main_program()) + >>> data_loader.reset() """ def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0): @@ -746,7 +746,7 @@ def _check_stage(cur_id, prev_id): is_backward = self._is_backward_op(op) assert is_forward or is_backward, ( 'send/recv in pipeline should only be inserted in forward or backward,' - 'please check the op_role of op={}'.format(op) + f'please check the op_role of op={op}' ) if is_forward: diff --git a/python/paddle/incubate/optimizer/recompute.py b/python/paddle/incubate/optimizer/recompute.py index 9cbd8894f1889..2545115fa0d01 100644 --- a/python/paddle/incubate/optimizer/recompute.py +++ b/python/paddle/incubate/optimizer/recompute.py @@ -49,45 +49,57 @@ class RecomputeOptimizer(Optimizer): Examples: .. code-block:: python - import paddle - import paddle.base as base - import numpy as np - - paddle.enable_static() - - def gen_data(): - return {"x": np.random.random(size=(32, 32)).astype('float32'), - "y": np.random.randint(2, size=(32, 1)).astype('int64')} - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - print(input_x) - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - sgd.minimize(cost) - - print("Finished optimize") - place = base.CPUPlace() - exe = base.Executor(place) - exe.run(base.default_startup_program()) - step = 10 - - for i in range(step): - cost_val = exe.run(feed=gen_data(), - program=base.default_main_program(), - fetch_list=[cost.name]) - print("step=%d cost=%f" % (i, cost_val[0])) + >>> import paddle + >>> import numpy as np + + >>> paddle.enable_static() + + >>> def gen_data(): + ... return {"x": np.random.random(size=(32, 32)).astype('float32'), + ... "y": np.random.randint(2, size=(32, 1)).astype('int64')} + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... print(input_x) + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> sgd.minimize(cost) + + >>> print("Finished optimize") + Finished optimize + >>> place = paddle.CPUPlace() + >>> exe = paddle.static.Executor(place) + >>> exe.run(paddle.static.default_startup_program()) + >>> step = 10 + + >>> for i in range(step): + ... cost_val = exe.run(feed=gen_data(), + ... program=paddle.static.default_main_program(), + ... fetch_list=[cost.name]) + ... print("step=%d cost=%f" % (i, cost_val[0])) + var x : LOD_TENSOR.shape(-1, 32).dtype(float32).stop_gradient(True) + Finished optimize + step=0 cost=0.737203 + step=1 cost=1.308077 + step=2 cost=0.768422 + step=3 cost=1.239475 + step=4 cost=0.882643 + step=5 cost=0.738027 + step=6 cost=0.819374 + step=7 cost=0.818534 + step=8 cost=0.753692 + step=9 cost=0.787448 """ @@ -132,33 +144,34 @@ def load(self, state_dict): Examples: .. code-block:: python - import paddle - import paddle.base as base - - paddle.enable_static() - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - try: - state_dict = {} - sgd.load(state_dict) - except NotImplementedError as e: - print(e) + >>> import paddle + + >>> paddle.enable_static() + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> try: + ... state_dict = {} + ... sgd.load(state_dict) + >>> except NotImplementedError as e: + ... print(e) + load function is not supported by Recompute Optimizer for now """ raise NotImplementedError( "load function is not supported by Recompute Optimizer for now" @@ -177,42 +190,42 @@ def apply_gradients(self, params_grads): Examples: .. code-block:: python - import paddle - import paddle.base as base - import paddle.base.framework as framework - - paddle.enable_static() - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - params_grads = sgd.backward( - cost, - startup_program=None, - parameter_list=None, - no_grad_set=None) - - program = cost.block.program - with framework.program_guard(program, None): - optimize_ops = sgd.apply_gradients(params_grads) - - print("Finished apply gradients") + >>> import paddle + >>> import paddle.base.framework as framework + + >>> paddle.enable_static() + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> params_grads = sgd.backward( + ... cost, + ... startup_program=None, + ... parameter_list=None, + ... no_grad_set=None) + + >>> program = cost.block.program + >>> with framework.program_guard(program, None): + ... optimize_ops = sgd.apply_gradients(params_grads) + + >>> print("Finished apply gradients") + Finished apply gradients """ return self._optimizer.apply_gradients(params_grads=params_grads) @@ -651,36 +664,36 @@ def backward( Examples: .. code-block:: python - import paddle - import paddle.base as base - - paddle.enable_static() - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - params_grads = sgd.backward( - cost, - startup_program=None, - parameter_list=None, - no_grad_set=None) - print("Finished backward") + >>> import paddle + + >>> paddle.enable_static() + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> params_grads = sgd.backward( + ... cost, + ... startup_program=None, + ... parameter_list=None, + ... no_grad_set=None) + >>> print("Finished backward") + Finished backward """ assert ( self._checkpoints is not None @@ -733,39 +746,41 @@ def apply_optimize(self, loss, startup_program, params_grads): params_grads (list): list of (param, grad) pair to do optimization. Examples: .. code-block:: python - import paddle - import paddle.base as base - paddle.enable_static() - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - params_grads = sgd.backward( - cost, - startup_program=None, - parameter_list=None, - no_grad_set=None) - - optimize_ops = sgd.apply_optimize( - cost, startup_program=None, params_grads=params_grads) - - print("Finished apply_optimize") + >>> import paddle + + >>> paddle.enable_static() + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> params_grads = sgd.backward( + ... cost, + ... startup_program=None, + ... parameter_list=None, + ... no_grad_set=None) + + >>> optimize_ops = sgd.apply_optimize( + ... cost, startup_program=None, params_grads=params_grads) + + >>> print("Finished apply_optimize") + Finished apply_optimize """ func = ( diff --git a/python/paddle/incubate/tensor/math.py b/python/paddle/incubate/tensor/math.py index 46ef434ba44b5..24b303107adc9 100644 --- a/python/paddle/incubate/tensor/math.py +++ b/python/paddle/incubate/tensor/math.py @@ -15,7 +15,7 @@ from paddle import _C_ops from paddle.base.data_feeder import check_variable_and_dtype from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode from paddle.utils import deprecated __all__ = [] @@ -66,7 +66,7 @@ def segment_sum(data, segment_ids, name=None): [4., 5., 6.]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "SUM") else: check_variable_and_dtype( @@ -135,7 +135,7 @@ def segment_mean(data, segment_ids, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "MEAN") check_variable_and_dtype( @@ -203,7 +203,7 @@ def segment_min(data, segment_ids, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.segment_pool(data, segment_ids, "MIN") check_variable_and_dtype( @@ -271,7 +271,7 @@ def segment_max(data, segment_ids, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.segment_pool(data, segment_ids, "MAX") return out diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py index 8d9a1909f07ca..6a0f6e05a37d5 100755 --- a/python/paddle/io/__init__.py +++ b/python/paddle/io/__init__.py @@ -25,6 +25,7 @@ Sampler, SequenceSampler, Subset, + SubsetRandomSampler, TensorDataset, WeightedRandomSampler, get_worker_info, @@ -48,4 +49,5 @@ 'WeightedRandomSampler', 'random_split', 'Subset', + 'SubsetRandomSampler', ] diff --git a/python/paddle/io/dataloader/__init__.py b/python/paddle/io/dataloader/__init__.py index bb65463f70afc..aff32fd70de49 100644 --- a/python/paddle/io/dataloader/__init__.py +++ b/python/paddle/io/dataloader/__init__.py @@ -29,3 +29,4 @@ from .sampler import SequenceSampler from .sampler import RandomSampler from .sampler import WeightedRandomSampler +from .sampler import SubsetRandomSampler diff --git a/python/paddle/io/dataloader/sampler.py b/python/paddle/io/dataloader/sampler.py index 44bc545f777cd..f6bb2e41b4b8f 100644 --- a/python/paddle/io/dataloader/sampler.py +++ b/python/paddle/io/dataloader/sampler.py @@ -15,6 +15,7 @@ import numpy as np from ...framework import core +from ...tensor import randperm class Sampler: @@ -340,3 +341,45 @@ def __iter__(self): def __len__(self): mul = np.prod(self.weights.shape) // self.weights.shape[-1] return self.num_samples * mul + + +class SubsetRandomSampler(Sampler): + r""" + Randomly sample elements from a given list of indices, without replacement. + + Args: + indices (sequence): a sequence of indices + + Examples: + + .. code-block:: python + + >>> import paddle + >>> from paddle.io import SubsetRandomSampler + + >>> paddle.seed(2023) + >>> sampler = SubsetRandomSampler(indices=[1, 3, 5, 7, 9]) + + >>> for index in sampler: + ... print(index) + 9 + 3 + 7 + 5 + 1 + + """ + + def __init__(self, indices): + if len(indices) == 0: + raise ValueError( + "The length of `indices` in SubsetRandomSampler should be greater than 0." + ) + self.indices = indices + + def __iter__(self): + for i in randperm(len(self.indices)): + yield self.indices[i] + + def __len__(self) -> int: + return len(self.indices) diff --git a/python/paddle/jit/api.py b/python/paddle/jit/api.py index 3873a150b6ab4..7125e76717f3e 100644 --- a/python/paddle/jit/api.py +++ b/python/paddle/jit/api.py @@ -15,17 +15,18 @@ # Temporary disable isort to avoid circular import # This can be removed after the circular import is resolved -# isort: skip_file from __future__ import annotations +import inspect import os import pickle +import sys +import threading +import types import warnings from collections import OrderedDict -import inspect -import threading +from contextlib import contextmanager from typing import Any -import types import paddle from paddle.base import core, dygraph @@ -39,43 +40,52 @@ program_desc_tracing_guard, switch_to_static_graph, ) -from .dy2static import logging_utils -from .dy2static.convert_call_func import ( - ConversionOptions, - add_ignore_module, -) -from .dy2static.program_translator import ( - ProgramTranslator, - StaticFunction, - ASTStaticFunction, - SymbolicStaticFunction, - unwrap_decorators, -) -from paddle.jit.translated_layer import ( - TranslatedLayer, - INFER_MODEL_SUFFIX, - INFER_PARAMS_SUFFIX, - INFER_PARAMS_INFO_SUFFIX, - INFER_PROPERTY_SUFFIX, -) -from paddle.nn import Layer from paddle.base.executor import Executor, scope_guard from paddle.base.framework import ( Block, + EagerParamBase, + Parameter, Program, Variable, - Parameter, - EagerParamBase, -) -from paddle.base.framework import ( _current_expected_place, _dygraph_guard, _dygraph_tracer, + dygraph_only, ) -from paddle.base.framework import dygraph_only from paddle.base.wrapped_decorator import wrap_decorator -from paddle.static.io import save_inference_model from paddle.framework import in_dynamic_mode +from paddle.nn import Layer +from paddle.static.io import save_inference_model +from paddle.utils.environments import ( + BooleanEnvironmentVariable, + EnvironmentVariableGuard, +) + +from .dy2static import logging_utils +from .dy2static.convert_call_func import ConversionOptions, add_ignore_module +from .dy2static.program_translator import ( + ASTStaticFunction, + ProgramTranslator, + StaticFunction, + SymbolicStaticFunction, + convert_to_static, + unwrap_decorators, +) +from .translated_layer import ( + INFER_MODEL_SUFFIX, + INFER_PARAMS_INFO_SUFFIX, + INFER_PARAMS_SUFFIX, + INFER_PROPERTY_SUFFIX, + TranslatedLayer, +) + +ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True) + + +@contextmanager +def sot_mode_guard(value: bool): + with EnvironmentVariableGuard(ENV_ENABLE_SOT, value): + yield def create_program_from_desc(program_desc): @@ -165,7 +175,7 @@ def __impl__(*args, **kwargs): "We will just return dygraph output." ) return dygraph_func(*args, **kwargs) - static_func = program_translator.get_func(dygraph_func) + static_func = convert_to_static(dygraph_func) return static_func(*args, **kwargs) return __impl__ @@ -223,9 +233,7 @@ def ignore_module(modules: list[Any]): def _check_and_set_backend(backend, build_strategy): if backend not in ['CINN', None]: raise ValueError( - "The backend of to_static should be 'CINN' or None, but received {}.".format( - backend - ) + f"The backend of to_static should be 'CINN' or None, but received {backend}." ) if backend == 'CINN': build_strategy.build_cinn_pass = True @@ -236,28 +244,31 @@ def to_static( input_spec=None, build_strategy=None, backend=None, - enable_fallback=None, **kwargs, ): """ - Converts imperative dygraph APIs into declarative function APIs. Decorator + Converts dynamic graph APIs into static graph function APIs. Decorator @to_static handles the Program and Executor of static graph mode and returns - the result as dygraph Tensor(s). Users could use the returned dygraph - Tensor(s) to do imperative training, inference, or other operations. If the - decorated function calls other imperative function, the called one will be - converted into declarative function as well. + the result as dynamic graph Tensor(s). Users could use the returned dynamic + graph Tensor(s) to do dynamic graph training, inference, or other operations. + If the decorated function calls other dynamic graph function, the called one + will be converted into static graph function as well. + Args: - function (callable): callable imperative function. - input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specific the shape/dtype/name - information of each input Tensor. - build_strategy(BuildStrategy|None): This argument is used to compile the + function (callable): Callable dynamic graph function. If it used as a + decorator, the decorated function will be parsed as this parameter. + input_spec (list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to + specific the shape/dtype/name information of each input Tensor. + build_strategy (BuildStrategy|None): This argument is used to compile the converted program with the specified options, such as operators' fusion in the computational graph and memory optimization during the execution of the computational graph. For more information about build_strategy, please refer to :code:`paddle.static.BuildStrategy`. The default is None. - backend(str, Optional): Specifies compilation backend, which can be `CINN` or None. When backend is `CINN`, CINN compiler will be used to speed up training and inference. - kwargs: Support keys including `property`, set `property` to True if the fucntion is python property. - + backend(str, Optional): Specifies compilation backend, which can be `CINN` or + None. When backend is `CINN`, CINN compiler will be used to speed up + training and inference. + kwargs: Support keys including `property`, set `property` to True if the function + is python property. Returns: Tensor(s): containing the numerical result. @@ -285,24 +296,28 @@ def to_static( """ property = kwargs.get("property", False) + full_graph = kwargs.get("full_graph", None) def decorated(python_func): """ Decorates a python function into a ASTStaticFunction object. """ - nonlocal enable_fallback - if enable_fallback is None: - flag = os.environ.get("ENABLE_FALL_BACK", None) - if flag == "True": - enable_fallback = True - else: # None or True - enable_fallback = False + nonlocal full_graph + if full_graph is None: + flag = ENV_ENABLE_SOT.get() + full_graph = not flag - StaticClass = StaticFunctionClass = { - True: SymbolicStaticFunction, - False: ASTStaticFunction, - }[enable_fallback] + if sys.version_info >= (3, 12) and not full_graph: + warnings.warn( + "full_graph=False is not supported in Python 3.12+. Set full_graph=True automatically" + ) + full_graph = True + + StaticClass = { + False: SymbolicStaticFunction, + True: ASTStaticFunction, + }[full_graph] # Step 1. unwrap the function if it is already decorated. _, python_func = unwrap_decorators(python_func) @@ -1108,7 +1123,7 @@ def save(layer, path, input_spec=None, **configs): static_forward = to_static( inner_layer.forward, input_spec=inner_input_spec, - enable_fallback=False, + full_graph=True, ) concrete_program = ( static_forward.concrete_program_specify_input_spec( @@ -1146,15 +1161,13 @@ def save(layer, path, input_spec=None, **configs): static_function = to_static( static_func, input_spec=inner_input_spec, - enable_fallback=False, + full_graph=True, ) concrete_program = static_function.concrete_program if static_function._class_instance is None: warnings.warn( - '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'.format( - layer - ) + f'`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {layer} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`' ) # when save multi `StaticFunction`, all `StaticFunction` share params. diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index 522814d2d293c..115ac0a00275f 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -15,21 +15,29 @@ from .assert_transformer import AssertTransformer # noqa: F401 from .ast_transformer import DygraphToStaticAst # noqa: F401 from .convert_call_func import convert_call as Call # noqa: F401 -from .convert_operators import convert_assert as Assert # noqa: F401 -from .convert_operators import convert_attr as Attr # noqa: F401 -from .convert_operators import convert_ifelse as IfElse # noqa: F401 -from .convert_operators import convert_len as Len # noqa: F401 -from .convert_operators import convert_load as Ld # noqa: F401 -from .convert_operators import convert_logical_and as And # noqa: F401 -from .convert_operators import convert_logical_not as Not # noqa: F401 -from .convert_operators import convert_logical_or as Or # noqa: F401 -from .convert_operators import convert_pop as Pop # noqa: F401 -from .convert_operators import convert_shape as Shape # noqa: F401 -from .convert_operators import convert_shape_compare # noqa: F401 -from .convert_operators import convert_var_dtype as AsDtype # noqa: F401 -from .convert_operators import convert_while_loop as While # noqa: F401 -from .convert_operators import indexable as Indexable # noqa: F401 -from .convert_operators import unpack_by_structure as Unpack # noqa: F401 + +# isort: off +# NOTE(gouzil): isort will delete the import +# TODO(gouzil): Remove `isort: off` after adding the `combine-as-imports` configuration +from .convert_operators import ( # noqa: F401 + convert_assert as Assert, + convert_attr as Attr, + convert_ifelse as IfElse, + convert_len as Len, + convert_load as Ld, + convert_logical_and as And, + convert_logical_not as Not, + convert_logical_or as Or, + convert_pop as Pop, + convert_shape_compare, + convert_shape as Shape, + convert_var_dtype as AsDtype, + convert_while_loop as While, + indexable as Indexable, + unpack_by_structure as Unpack, +) + +# isort: on from .program_translator import convert_to_static # noqa: F401 from .static_analysis import NodeVarType, StaticAnalysisVisitor # noqa: F401 from .utils import UndefinedVar, ast_to_source_code, saw # noqa: F401 diff --git a/python/paddle/jit/dy2static/convert_call_func.py b/python/paddle/jit/dy2static/convert_call_func.py index 22ca5e756568a..8fa47657426c5 100644 --- a/python/paddle/jit/dy2static/convert_call_func.py +++ b/python/paddle/jit/dy2static/convert_call_func.py @@ -124,9 +124,7 @@ def is_unsupported(func): if func is v: translator_logger.log( 2, - "Whitelist: {} is part of built-in module and does not have to be transformed.".format( - func - ), + f"Whitelist: {func} is part of built-in module and does not have to be transformed.", ) return True @@ -142,9 +140,7 @@ def is_unsupported(func): if is_paddle_func(func): translator_logger.log( 2, - "Whitelist: {} is part of Paddle module and does not have to be transformed.".format( - func - ), + f"Whitelist: {func} is part of Paddle module and does not have to be transformed.", ) return True @@ -198,9 +194,7 @@ def convert_call(func): if options is not None and options.not_convert: translator_logger.log( 2, - "{} is not converted when it is decorated by 'paddle.jit.not_to_static'.".format( - func - ), + f"{func} is not converted when it is decorated by 'paddle.jit.not_to_static'.", ) return func @@ -280,9 +274,7 @@ def convert_call(func): # If func is not in __globals__, it does not need to be transformed # because it has been transformed before. translator_logger.warn( - "{} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is.".format( - func - ) + f"{func} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is." ) converted_call = func except AttributeError: @@ -334,9 +326,7 @@ def convert_call(func): if converted_call is None: translator_logger.warn( - "{} doesn't have to be transformed to static function, and it will be run as-is.".format( - func - ) + f"{func} doesn't have to be transformed to static function, and it will be run as-is." ) return func diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 47618392175d9..72092b10d44d5 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -19,6 +19,7 @@ from paddle.base.data_feeder import convert_dtype from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode from paddle.base.framework import Variable, core, default_main_program +from paddle.pir import OpResult from .py_layer import StaticPyLayer from .utils import ( @@ -33,7 +34,7 @@ def convert_attr(x, attr): - if isinstance(x, Variable) and attr == "size": + if isinstance(x, (Variable, OpResult)) and attr == "size": return x.size() else: return getattr(x, attr) @@ -52,6 +53,17 @@ def convert_load(x): return StaticPyLayer(x) # get the new output of the var + if isinstance(x, OpResult): + cur_block = default_main_program().current_block() + + from paddle.jit.pir_dy2static.parameter_recorder import ( + _global_inplace_map, + ) + + new_var = _global_inplace_map.get(cur_block.program, id(x)) + if new_var is not None: + return new_var + if isinstance(x, Variable): cur_block = default_main_program().current_block() @@ -69,7 +81,7 @@ def convert_load(x): def indexable(x, code=None): - if isinstance(x, Variable): + if isinstance(x, (Variable, OpResult)): return x elif hasattr(x, '__iter__'): return list(x) @@ -83,7 +95,7 @@ def indexable(x, code=None): def unpack_by_structure(target, structure): """unified unpack interface for paddle and python.""" - if isinstance(target, Variable): + if isinstance(target, (Variable, OpResult)): return _unpack_by_structure_paddle(target, structure) else: return _unpack_by_structure_python(target, structure) @@ -130,7 +142,7 @@ def convert_while_loop( # NOTE: It may be slower if cond is very expensive, but usually cond is just O(1). # If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars. pred = cond() - if isinstance(pred, Variable): + if isinstance(pred, (Variable, OpResult)): _run_paddle_while( cond, body, getter, setter, return_name_ids, push_pop_names ) @@ -195,7 +207,7 @@ def new_cond_fn(*args): def _run_py_while(cond, body, getter, setter): while True: pred = cond() - if isinstance(pred, Variable): + if isinstance(pred, (Variable, OpResult)): raise Dygraph2StaticException( "python while pred change from bool to variable." ) @@ -229,11 +241,11 @@ def convert_logical_and(x_func, y_func): if `x>1` is False, `y<1` should NOT be run. """ x_value = x_func() - if not isinstance(x_value, Variable): + if not isinstance(x_value, (Variable, OpResult)): return _run_py_logical_and(lambda: x_value, y_func) y_value = y_func() - if not isinstance(y_value, Variable): + if not isinstance(y_value, (Variable, OpResult)): return _run_py_logical_and(lambda: y_value, lambda: x_value) return _run_paddle_logical_and(x_value, y_value) @@ -247,7 +259,7 @@ def _run_paddle_logical_and(x, y): def _run_py_logical_and(x_func, y_func): x_value = x_func() - assert not isinstance(x_value, Variable) + assert not isinstance(x_value, (Variable, OpResult)) # NOTE(liym27): # 1. Returns y_func() if x_value is False; @@ -280,11 +292,11 @@ def convert_logical_or(x_func, y_func): if `x>1` is True, `y<1` should NOT be run. """ x_value = x_func() - if not isinstance(x_value, Variable): + if not isinstance(x_value, (Variable, OpResult)): return _run_py_logical_or(lambda: x_value, y_func) y_value = y_func() - if not isinstance(y_value, Variable): + if not isinstance(y_value, (Variable, OpResult)): return _run_py_logical_or(lambda: y_value, lambda: x_value) return _run_paddle_logical_or(x_value, y_value) @@ -298,7 +310,7 @@ def _run_paddle_logical_or(x, y): def _run_py_logical_or(x_func, y_func): x_value = x_func() - assert not isinstance(x_value, Variable) + assert not isinstance(x_value, (Variable, OpResult)) # NOTE(liym27): # 1. Returns y_func() if x_value is False; @@ -317,7 +329,7 @@ def convert_logical_not(x): A python bool variable or a bool Tensor. """ - if isinstance(x, Variable): + if isinstance(x, (Variable, OpResult)): return _run_paddle_logical_not(x) else: return _run_py_logical_not(x) @@ -357,7 +369,7 @@ def convert_ifelse( ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . """ - if isinstance(pred, Variable): + if isinstance(pred, (Variable, OpResult)): out = _run_paddle_cond( pred, true_fn, @@ -428,15 +440,11 @@ def new_false_fn(): "Unsupported return type of true_fn and false_fn in cond", str(e) ): raise Dygraph2StaticException( - "Your if/else have different return type. TODO: add link to modifty. {}".format( - str(e) - ) + f"Your if/else have different return type. TODO: add link to modifty. {str(e)}" ) if re.search("Incompatible return values of", str(e)): raise Dygraph2StaticException( - "Your if/else have different number of return value. TODO: add link to modifty. {}".format( - str(e) - ) + f"Your if/else have different number of return value. TODO: add link to modifty. {str(e)}" ) raise e get_args = lambda: helper.get(return_name_ids) @@ -460,7 +468,7 @@ def _remove_no_value_return_var(out): align_ret = out[0] if isinstance(align_ret, tuple): for index, item in enumerate(align_ret): - if isinstance(item, Variable) and ( + if isinstance(item, (Variable, OpResult)) and ( RETURN_NO_VALUE_VAR_NAME in item.name ): # return None @@ -473,7 +481,7 @@ def _remove_no_value_return_var(out): break for index, item in enumerate(processed_out): - if isinstance(item, Variable) and ( + if isinstance(item, (Variable, OpResult)) and ( RETURN_NO_VALUE_VAR_NAME in item.name ): processed_out = processed_out[:index] @@ -540,7 +548,7 @@ def convert_len(var): operations are added in `len` transformation, such as appending `shape_op` in var.block. """ - if isinstance(var, Variable): + if isinstance(var, (Variable, OpResult)): assert var.ndim > 0, "len() of a 0-D tensor is wrong" if var.type in [ core.VarDesc.VarType.LOD_TENSOR, @@ -560,14 +568,14 @@ def convert_len(var): % type(var) ) else: - if isinstance(var, VariableTuple): + if isinstance(var, (VariableTuple)): return var.__len__() return len(var) def convert_zip(*args): for i, arg in enumerate(args): - if isinstance(arg, Variable) and arg.shape[0] == -1: + if isinstance(arg, (Variable, OpResult)) and arg.shape[0] == -1: raise RuntimeError( "Not support zip(tensor, ...) when tensor.shape[0] == -1, " f"but found args[{str(i)}].shape[0] == -1 in 'zip'" @@ -586,7 +594,7 @@ class VariableTuple: def __init__(self, var, start=0): self.var = var self.len = convert_len(var) - if isinstance(self.len, Variable): + if isinstance(self.len, (Variable, OpResult)): self.rag = paddle.arange(start, start + self.len, 1, paddle.int64) else: self.rag = range(start, start + self.len) @@ -599,14 +607,14 @@ def __len__(self): def convert_enumerate(*args): - has_variable = any(isinstance(x, Variable) for x in args) + has_variable = any(isinstance(x, (Variable, OpResult)) for x in args) if has_variable: return VariableTuple(*args) return enumerate(*args) def convert_range(*args): - has_variable = any(isinstance(x, Variable) for x in args) + has_variable = any(isinstance(x, (Variable, OpResult)) for x in args) if has_variable: if len(args) == 1: return paddle.arange(0, args[0], 1, paddle.int64) @@ -631,7 +639,7 @@ def has_negative(list_shape): # (2) if x.shape does not contains -1, return lsit(x.shape) directly - if isinstance(x, Variable): + if isinstance(x, (Variable, OpResult)): values = list(x.shape) if has_negative(values): shape_tensor = paddle.shape(x) @@ -670,7 +678,7 @@ def convert_shape_compare(left, *args): args_len % 2 == 0 ), "Illegal input for convert_shape_compare, *args should be op(str), var, op(str), var ..." num_cmp = args_len // 2 - if isinstance(left, Variable): + if isinstance(left, (Variable, OpResult)): def reduce_compare(x, op_str, y): element_wise_result = eval("x " + op_str + " y") @@ -715,14 +723,14 @@ def reduce_compare(x, op_str, y): def cast_bool_if_necessary(var): - assert isinstance(var, Variable) + assert isinstance(var, (Variable, OpResult)) if convert_dtype(var.dtype) not in ['bool']: var = paddle.cast(var, dtype="bool") return var def convert_var_dtype(var, dtype): - if isinstance(var, Variable): + if isinstance(var, (Variable, OpResult)): src_dtype = convert_dtype(var.dtype) assert src_dtype in [ 'bool', @@ -739,9 +747,7 @@ def convert_var_dtype(var, dtype): 'bool', 'int', 'float', - ], "The casted target dtype is {}, which is not supported in type casting.".format( - dtype - ) + ], f"The casted target dtype is {dtype}, which is not supported in type casting." cast_map = { 'bool': 'bool', 'int': 'int32', @@ -756,7 +762,7 @@ def convert_assert(cond, message=""): """ A function representation of a Python ``assert`` statement. """ - if isinstance(cond, Variable): + if isinstance(cond, (Variable, OpResult)): cond = paddle.cast(cond, "bool") # NOTE: message is not used because Paddle Assert has no corresponding parameter to use. from paddle.static.nn.control_flow import Assert @@ -772,7 +778,7 @@ def convert_print(*objects, sep=' ', end='\n', file=None, flush=False): at compile time and only print the Tensor values at runtime. """ for obj in objects: - if isinstance(obj, Variable): + if isinstance(obj, (Variable, OpResult)): paddle.static.Print(obj) print(*objects, sep=sep, end=end, file=file, flush=flush) @@ -789,7 +795,7 @@ def convert_pop(target, *args): A item poped from target. """ - is_variable = isinstance(target, Variable) + is_variable = isinstance(target, (Variable, OpResult)) if is_variable: is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY diff --git a/python/paddle/jit/dy2static/error.py b/python/paddle/jit/dy2static/error.py index 96124f1369087..827f5474c82d8 100644 --- a/python/paddle/jit/dy2static/error.py +++ b/python/paddle/jit/dy2static/error.py @@ -156,7 +156,9 @@ def __init__(self): self.suggestion_dict = { ('is not initialized.', 'Hint:', 'IsInitialized'): ( "Please ensure all your sublayers are inheritted from nn.Layer.", - "Please ensure there is no tensor created explicitly depended on external data, we suggest to register it as buffer tensor. See https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/04_dygraph_to_static/export_model/principle_cn.html#parameters-buffers for details", + "Please ensure there is no tensor created explicitly depended on external data, " + + "we suggest to register it as buffer tensor. " + + "See https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/jit/principle_cn.html#buffers for details", ) } diff --git a/python/paddle/jit/dy2static/function_spec.py b/python/paddle/jit/dy2static/function_spec.py index f302dd4e2c6ca..565013b17c3ef 100644 --- a/python/paddle/jit/dy2static/function_spec.py +++ b/python/paddle/jit/dy2static/function_spec.py @@ -173,7 +173,7 @@ def args_to_input_spec(self, args, kwargs): return args_with_spec, kwargs_with_spec @switch_to_static_graph - def newir_to_static_inputs_with_spec(self, input_with_spec, main_program): + def pir_to_static_inputs_with_spec(self, input_with_spec, main_program): """ Constructs feed layer by inputs with InputSpec information for main program. @@ -450,9 +450,7 @@ def check_type_and_len(input, spec, check_length=False): real_spec.shape = input_spec.shape else: logging_utils.warn( - "input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format( - input_spec=input_spec, real_spec=real_spec - ) + f"input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} " ) return real_spec else: diff --git a/python/paddle/jit/dy2static/ifelse_transformer.py b/python/paddle/jit/dy2static/ifelse_transformer.py index 8da098959aa6e..02129b02bf103 100644 --- a/python/paddle/jit/dy2static/ifelse_transformer.py +++ b/python/paddle/jit/dy2static/ifelse_transformer.py @@ -319,9 +319,7 @@ def _valid_nonlocal_names(return_name_ids, nonlocal_names): for name in return_name_ids: if name not in nonlocal_names: raise ValueError( - "Required returned var '{}' must be in 'nonlocal' statement '', but not found.".format( - name - ) + f"Required returned var '{name}' must be in 'nonlocal' statement '', but not found." ) nonlocal_names.remove(name) diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index 04255140ae9ca..2b6cca032beae 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from copy import deepcopy import numpy as np @@ -24,7 +23,7 @@ from paddle.base.compiler import BuildStrategy from paddle.base.data_feeder import check_type, convert_dtype from paddle.base.dygraph.base import switch_to_static_graph -from paddle.base.framework import _apply_pass +from paddle.base.framework import _apply_pass, get_flags from paddle.base.unique_name import guard as UniqueNameGuard from paddle.optimizer.lr import LRScheduler @@ -215,7 +214,8 @@ def __init__( ) # program_id -> list(scope) - self._scope_cache = {} + self._pir_scope_cache = {} + self._legacy_scope_cache = {} self._hooker = None self._backend = kwargs.get('backend', None) self._grad_var_names = {} @@ -240,7 +240,6 @@ def __call__(self, inputs): self._create_scope_vec( program_id=self.program_id, use_scope_cache=True ), - self._double_grads, self._cuda_graph_vec, *attrs ) @@ -268,25 +267,23 @@ def set_hooker(self, hooker): self._hooker = hooker def _get_scope(self, program_id=None, use_scope_cache=False): - if use_scope_cache: - if program_id not in self._scope_cache: - scope = core.Scope() - self._scope_cache[program_id] = [scope] - return scope - else: - for scope in self._scope_cache[program_id]: - if scope._can_reused: - return scope - scope = core.Scope() - self._scope_cache[program_id].append(scope) - return scope + if get_flags('FLAGS_enable_pir_in_executor')[ + 'FLAGS_enable_pir_in_executor' + ]: + _scope_cache = self._pir_scope_cache else: + _scope_cache = self._legacy_scope_cache + if not use_scope_cache: return core.Scope() - - @LazyInitialized - def _double_grads(self): - # TODO: check the affects. - return None + if program_id not in _scope_cache: + _scope_cache[program_id] = [] + cached_scopes = _scope_cache[program_id] + for scope in cached_scopes: + if scope._can_reused: + return scope + scope = core.Scope() + cached_scopes.append(scope) + return scope # whole @switch_to_static_graph @@ -839,7 +836,9 @@ def _apply_inplace_pass(self, forward_program, backward_program): "mem_opt_skip_vars": forward_mem_opt_skip_vars, "for_partial_block": True, } - if not os.getenv("FLAGS_enable_new_ir_in_executor"): + if not get_flags('FLAGS_enable_pir_in_executor')[ + 'FLAGS_enable_pir_in_executor' + ]: _apply_pass( forward_program, empty_startup_program, @@ -853,7 +852,9 @@ def _apply_inplace_pass(self, forward_program, backward_program): "mem_opt_skip_vars": backward_mem_opt_skip_vars, "for_partial_block": True, } - if not os.getenv("FLAGS_enable_new_ir_in_executor"): + if not get_flags('FLAGS_enable_pir_in_executor')[ + 'FLAGS_enable_pir_in_executor' + ]: _apply_pass( backward_program, empty_startup_program, @@ -959,13 +960,10 @@ def create_out(var_id): return input_vars, out_vars, input_var_names def _create_scope_vec(self, program_id=None, use_scope_cache=False): - # Hold forward variables - tmp_scope_vec = None inner_scope = self._get_scope( program_id=program_id, use_scope_cache=use_scope_cache ) - tmp_scope_vec = [inner_scope] - return tmp_scope_vec + return [inner_scope] def _create_cuda_graph_vec(self): var = core.eager.Tensor( @@ -1056,19 +1054,6 @@ def _set_grad_type(self, params, train_program): continue param._set_grad_type(grad_var.type()) - def _remove_op_call_stack(self, main_program): - """ - Remove op's python call stack with redundant low-level error messages related to - transforamtions to avoid confusing users. - """ - assert isinstance(main_program, framework.Program) - for block in main_program.blocks: - for op in block.ops: - if op.has_attr("op_callstack"): - op._remove_attr("op_callstack") - - return main_program - def _check_params_all_inited(self, main_program): """ Check all params from main program are already initialized, see details as follows: @@ -1154,4 +1139,9 @@ def add_build_strategy_for( builded_program = paddle.static.Program() for var in program.block(0).vars.values(): builded_program.block(0)._clone_variable(var, False) + + # set back the parent_idx of blocks + for origin, current in zip(program.blocks, builded_program.blocks): + current.desc.set_parent_idx(origin.desc.parent) + return builded_program diff --git a/python/paddle/jit/dy2static/newir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py similarity index 58% rename from python/paddle/jit/dy2static/newir_partial_program.py rename to python/paddle/jit/dy2static/pir_partial_program.py index 198cc105b3ec1..81655c3440a86 100644 --- a/python/paddle/jit/dy2static/newir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -13,8 +13,6 @@ # limitations under the License. import itertools -import os -from copy import deepcopy import numpy as np @@ -23,11 +21,10 @@ from paddle import _legacy_C_ops from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard from paddle.autograd.ir_backward import grad -from paddle.base import core, framework, program_guard +from paddle.base import core, framework from paddle.base.compiler import BuildStrategy from paddle.base.data_feeder import check_type, convert_dtype from paddle.base.dygraph.base import switch_to_static_graph -from paddle.base.framework import _apply_pass from paddle.framework import use_pir_api from paddle.optimizer.lr import LRScheduler from paddle.pir import OpResult, fake_op_result, is_fake_op_result @@ -38,6 +35,20 @@ __all__ = [] +class cached_property: + """ + Descriptor to implement lazy initialization of property. + """ + + def __init__(self, function): + self.function = function + + def __get__(self, instance, cls): + val = self.function(instance) + setattr(instance, self.function.__name__, val) + return val + + class NestSequence: """ A wrapper class that easily to flatten and restore the nest structure of @@ -97,45 +108,151 @@ def __getitem__(self, item): return self.__input_list[item] -class LazyInitialized: - """ - Descriptor to implement lazy initialization of property. +class RunableProgram: + """a pir program ready for run_program_op to run. constructed by 3 parts: + - pir program (pir::Program) + - in_out_values + - input_x values ([pir::OpResult]) + - input_param values ([pir::OpResult]) + - output values ([pir::OpResult]) + - forward_backward_ranges + - forward_range (tuple(Int, Int)) | None + - backward_range (tuple(Int, Int)) | None """ - def __init__(self, function): - self.function = function + def __init__( + self, + program, + in_out_values, + grad_in_out_values=None, + forward_range=None, + backward_range=None, + ): + assert isinstance( + in_out_values, tuple + ), "in_out_values must be tuple with len == 3" + assert ( + len(in_out_values) == 3 + ), "in_out_values must be tuple with len == 3" + assert isinstance( + in_out_values[0], list + ), "in_out_values must be tuple with len == 3" + self.program = program + self.x_values, self.param_values, self.out_values = in_out_values + self.forward_range = forward_range + self.backward_range = backward_range + if self.forward_range is None: + self.forward_range = (0, len(self.program.global_block().ops)) + if self.backward_range is None: + self.backward_range = ( + len(self.program.global_block().ops), + len(self.program.global_block().ops), + ) + if grad_in_out_values is None: + grad_in_out_values = [], [], [] + ( + self.x_grad_values, + self.p_grad_values, + self.o_grad_values, + ) = grad_in_out_values + + def clone(self): + cloned_program, mapping = paddle.base.libpaddle.pir.clone_program( + self.program, self.x_values, self.param_values, self.out_values + ) + cloned_x = [mapping[x] for x in self.x_values] + cloned_param = [mapping[p] for p in self.param_values] + cloned_out = [mapping[o] for o in self.out_values] + return RunableProgram( + cloned_program, + (cloned_x, cloned_param, cloned_out), + None, + self.forward_range, + self.backward_range, + ) - def __get__(self, instance, cls): - val = self.function(instance) - setattr(instance, self.function.__name__, val) - return val + def split_forward_backward(self): + [ + fwd_prog, + bwd_prog, + ], prog_attr = paddle.base.libpaddle.pir.split_program( + self.program, + self.x_values, + self.param_values, + self.out_values, + self.x_grad_values, + self.p_grad_values, + self.o_grad_values, + list(self.forward_range), + list(self.backward_range), + ) + + return [fwd_prog, bwd_prog], prog_attr + + @cached_property + def _forward_backward_program(self): + return self.split_forward_backward() + + @property + def program_attr(self): + return self._forward_backward_program[1] + + @property + def forward_program(self): + return self._forward_backward_program[0][0] + + @property + def backward_program(self): + return self._forward_backward_program[0][1] -class ProgramInfo: +class PirPassContext: """ - A helper class to recoder Program information + PirPassContext is a class that only has staticmethod currently. + It will create a new RunableProgram after calling apply method. """ - def __init__(self): - self.op_size = { - 'fp32': -1, - 'amp': -1, - 'fp16': -1, - } - self.programs = {} - self.mode = "infer" + INPUT_OP_NAME = "pd_op.data" + PARM_OP_NAME = "builtin.get_parameter" + OUTPUT_OP_NAME = "builtin.set_parameter" + + @classmethod + def apply(cls, runable_program, build_strategy): + # TODO(Aurelius84): Currently only support infer mode, + # and we just use forward_program because backward_program + # is empty. + if not build_strategy.build_cinn_pass: + return runable_program + elif not paddle.is_compiled_with_cinn(): + raise RuntimeError( + "Please install PaddlePaddle compiled with CINN while setting build_strategy.build_cinn_pass = True." + ) - def __call__(self, key, prog_creator): - """ - Recoder infer program and op size. + fwd_program = paddle.base.libpaddle.pir.apply_pir_pass( + runable_program.forward_program + ) + in_out_values = cls._prepare_attr(fwd_program) + return RunableProgram(fwd_program, in_out_values) + + @classmethod + def _prepare_attr(cls, program): """ - assert key in ['fp32', 'amp', 'fp16'] - if key not in self.programs: - infer_prog = prog_creator(is_infer_mode=True) - self.programs[key] = infer_prog - self.op_size[key] = infer_prog.desc.global_block().op_size() + After applying Pass, we need to update the Input/Parameter/Output Value + that refer to the new program. - return self.programs[key], self.op_size[key] + NOTE: We assume that Inputs come from INPUT_OP, Params come from + PARM_OP and Output come from OUTPUT_OP. + """ + inputs, params, outputs = [], [], [] + for op in program.global_block().ops: + op_name = op.name() + if op_name == cls.INPUT_OP_NAME: + inputs.append(op.result(0)) + elif op_name == cls.PARM_OP_NAME: + params.append(op.result(0)) + elif op_name == cls.OUTPUT_OP_NAME: + outputs.append(op.operand(0).source()) + return inputs, params, outputs class PartialProgramLayerHook: @@ -189,7 +306,6 @@ def __init__( self._cuda_graph_pool_id = 0 # Set default mode to train self.training = True - self._infer_info = ProgramInfo() self._program_extra_info = {} amp_dtype, custom_white_list, custom_black_list = None, None, None @@ -218,14 +334,13 @@ def __call__(self, inputs): Execute static graph by Interpreter and Return dynamic Tensors. """ in_vars, out_vars = self._prepare(inputs) - self._cast_fp16_if_pure_fp16(in_vars) attrs = self._prepare_attributes() # self._sync_lr_value_with_scheduler() c_run_program_fn = None if use_pir_api(): - c_run_program_fn = _legacy_C_ops.newir_run_program + c_run_program_fn = _legacy_C_ops.pir_run_program else: c_run_program_fn = _legacy_C_ops.run_program c_run_program_fn( @@ -235,7 +350,6 @@ def __call__(self, inputs): self._create_scope_vec( program_id=self.program_id, use_scope_cache=True ), - self._double_grads, self._cuda_graph_vec, *attrs, ) @@ -243,6 +357,19 @@ def __call__(self, inputs): restored_nest_out = self._restore_out(out_vars) return self._remove_no_value(restored_nest_out) + @cached_property + def origin_runable_program(self): + inputs = list( + filter(lambda x: isinstance(x, OpResult), self._inputs.tolist()) + ) + outputs = list( + filter(lambda x: isinstance(x, OpResult), self._outputs.tolist()) + ) + params = self._param_values + return RunableProgram( + self._origin_main_program, (inputs, params, outputs) + ) + def _sync_lr_value_with_scheduler(self): """Update lr_var value with calculated by lr_scheduler.""" main_program = self._origin_main_program @@ -262,212 +389,49 @@ def set_hooker(self, hooker): self._hooker = hooker def _get_scope(self, program_id=None, use_scope_cache=False): - if use_scope_cache: - if program_id not in self._scope_cache: - scope = core.Scope() - self._scope_cache[program_id] = [scope] - return scope - else: - for scope in self._scope_cache[program_id]: - if scope._can_reused: - return scope - scope = core.Scope() - self._scope_cache[program_id].append(scope) - return scope - else: + if not use_scope_cache: return core.Scope() - - @LazyInitialized - def _double_grads(self): - # TODO: check the affects. - return None + if program_id not in self._scope_cache: + self._scope_cache[program_id] = [] + cached_scopes = self._scope_cache[program_id] + for scope in cached_scopes: + if scope._can_reused: + return scope + scope = core.Scope() + cached_scopes.append(scope) + return scope # whole @switch_to_static_graph def _create_program(self, is_infer_mode=False): if is_infer_mode: - infer_program = self._origin_main_program.clone( - for_test=is_infer_mode + # TODO(xiongkun) who to transfer the pruning program? + infer_program = self.origin_runable_program.clone() + infer_program = PirPassContext.apply( + infer_program, self._build_strategy ) - if self._hooker: - infer_program = self._hooker.after_infer(infer_program) + # TODO(Aurelius84): Support this later. + # if self._hooker: + # infer_program = self._hooker.after_infer(infer_program) return infer_program else: - train_program = self._append_backward_desc( - self._origin_main_program - ) + train_program = self.origin_runable_program.clone() + train_program = self._append_backward_desc(train_program) # Note: Only set grad type once after initializing train program. So we put it here. self._set_grad_type(self._params, train_program) return train_program - @switch_to_static_graph - def _create_amp_program(self, is_infer_mode=False): - amp_program = self._origin_main_program.clone(for_test=is_infer_mode) - with program_guard(amp_program): - paddle.static.amp.fp16_utils.cast_model_to_fp16( - amp_program, self._amp_list, use_fp16_guard=False, level='O1' - ) - if is_infer_mode: - if self._hooker: - amp_program = self._hooker.after_infer(amp_program) - return amp_program - else: - train_amp_program = self._append_backward_desc(amp_program) - self._set_grad_type(self._params, train_amp_program) - return train_amp_program - - @switch_to_static_graph - def _create_pure_fp16_program(self, is_infer_mode=False): - pure_fp16_program = self._origin_main_program.clone( - for_test=is_infer_mode - ) - with program_guard(pure_fp16_program): - paddle.static.amp.fp16_utils.cast_model_to_fp16( - pure_fp16_program, self._amp_list, use_fp16_guard=False - ) - - if is_infer_mode: - if self._hooker: - pure_fp16_program = self._hooker.after_infer(pure_fp16_program) - return pure_fp16_program - else: - train_pure_fp16_program = self._append_backward_desc( - pure_fp16_program - ) - self._set_grad_type(self._params, train_pure_fp16_program) - return train_pure_fp16_program - - @switch_to_static_graph - def _create_forward_backward_train_program(self): - whole_program = self._train_program - forward_end_op_index = self.get_forward_end_op_idx(whole_program) - assert forward_end_op_index >= 0 - return self._get_forward_backward_program_form( - whole_program, forward_end_op_index - ) - - @switch_to_static_graph - def _create_forward_backward_train_amp_program(self): - whole_program = self._train_amp_program - forward_end_op_index = self.get_forward_end_op_idx(whole_program) - assert forward_end_op_index >= 0 - - return self._get_forward_backward_program_form( - whole_program, forward_end_op_index - ) - - @switch_to_static_graph - def _create_forward_backward_train_pure_fp16_program(self): - whole_program = self._train_pure_fp16_program - forward_end_op_index = self.get_forward_end_op_idx(whole_program) - assert forward_end_op_index >= 0 - - return self._get_forward_backward_program_form( - whole_program, forward_end_op_index - ) - - @LazyInitialized - def _train_program(self): - return self._create_program() - - @LazyInitialized - def _infer_program(self): - program, op_size = self._infer_info('fp32', self._create_program) - return self._build_infer_program(program, op_size) - - @LazyInitialized - def _train_amp_program(self): - return self._create_amp_program() - - @LazyInitialized - def _infer_amp_program(self): - program, op_size = self._infer_info('amp', self._create_amp_program) - return self._build_infer_program(program, op_size) - - @LazyInitialized - def _train_pure_fp16_program(self): - return self._create_pure_fp16_program() - - @LazyInitialized - def _infer_pure_fp16_program(self): - program, op_size = self._infer_info( - 'fp16', self._create_pure_fp16_program - ) - return self._build_infer_program(program, op_size) - - @LazyInitialized - def _train_forward_backward_program(self): - program = self._create_forward_backward_train_program() - return program - - @LazyInitialized - def _train_amp_forward_backward_program(self): - program = self._create_forward_backward_train_amp_program() - return program - - @LazyInitialized - def _empty_backward_program_for_eval(self): - return paddle.static.Program() - - @LazyInitialized - def _train_pure_fp16_forward_backward_program(self): - program = self._create_forward_backward_train_pure_fp16_program() - return program - - @LazyInitialized + @cached_property def _train_program_id(self): - program_id = paddle.utils._hash_with_id(self._train_program, self) + program_id = paddle.utils._hash_with_id(self.train_program, self) core._set_cached_executor_build_strategy( program_id, self._build_strategy ) return program_id - @LazyInitialized + @cached_property def _infer_program_id(self): - return paddle.utils._hash_with_id(self._infer_program, self) - - @LazyInitialized - def _train_amp_program_id(self): - program_id = paddle.utils._hash_with_id(self._train_amp_program, self) - core._set_cached_executor_build_strategy( - program_id, self._build_strategy - ) - return program_id - - @LazyInitialized - def _infer_amp_program_id(self): - return paddle.utils._hash_with_id(self._infer_amp_program, self) - - @LazyInitialized - def _train_pure_fp16_program_id(self): - program_id = paddle.utils._hash_with_id( - self._train_pure_fp16_program, self - ) - core._set_cached_executor_build_strategy( - program_id, self._build_strategy - ) - return program_id - - @LazyInitialized - def _infer_pure_fp16_program_id(self): - return paddle.utils._hash_with_id(self._infer_pure_fp16_program, self) - - def get_forward_end_op_idx(self, program): - return self._program_extra_info[ - paddle.utils._hash_with_id(program, self) - ]['forward_end_op_idx'] - - def get_program_extra(self, program): - if ( - paddle.utils._hash_with_id(program, self) - not in self._program_extra_info - ): - self._program_extra_info[ - paddle.utils._hash_with_id(program, self) - ] = {} - return self._program_extra_info[ - paddle.utils._hash_with_id(program, self) - ] + return paddle.utils._hash_with_id(self.infer_program, self) @property def program(self): @@ -484,73 +448,24 @@ def program_id(self): """ Return current train or eval program hash id. """ + if _in_amp_guard() or _in_pure_fp16_guard(): + raise NotImplementedError("not implement error.") if self.training: - if _in_amp_guard(): - return self._train_amp_program_id - elif _in_pure_fp16_guard(): - return self._train_pure_fp16_program_id - else: - return self._train_program_id + return self._train_program_id else: - if _in_amp_guard(): - return self._infer_amp_program_id - elif _in_pure_fp16_guard(): - return self._infer_pure_fp16_program_id - else: - return self._infer_program_id + return self._infer_program_id - @property + @cached_property def train_program(self): - if _in_amp_guard(): - return self._train_amp_program - elif _in_pure_fp16_guard(): - return self._train_pure_fp16_program - else: - return self._train_program + if _in_amp_guard() or _in_pure_fp16_guard(): + raise NotImplementedError("not implement error.") + return self._create_program() - @property + @cached_property def infer_program(self): - if _in_amp_guard(): - return self._infer_amp_program - elif _in_pure_fp16_guard(): - return self._infer_pure_fp16_program - else: - return self._infer_program - - @property - def forward_program(self): - if self.training: - if _in_amp_guard(): - progs = self._train_amp_forward_backward_program - elif _in_pure_fp16_guard(): - progs = self._train_pure_fp16_forward_backward_program - else: - progs = self._train_forward_backward_program - return progs[0] - else: - return self.infer_program - - @property - def backward_program(self): - if self.training: - if _in_amp_guard(): - progs = self._train_amp_forward_backward_program - elif _in_pure_fp16_guard(): - progs = self._train_pure_fp16_forward_backward_program - else: - progs = self._train_forward_backward_program - return progs[1] - else: - """ - Can't just return paddle.static.Program(), because self.backward_program is a property, - whenever we call this method, a tmp Program() object is created and is gc immediatly - after executed the following line in PartialProgramLayer.__call__. - - >>> self.backward_program.desc.global_block(), - - When we access RunProgramAPI, it's possible to get an invalid backward_program address. - """ - return self._empty_backward_program_for_eval + if _in_amp_guard() or _in_pure_fp16_guard(): + raise NotImplementedError("not implement error.") + return self._create_program(is_infer_mode=True) def _verify_program(self, main_program): """ @@ -643,21 +558,17 @@ def _insert_aggregation_ops_for_var(target_program, var): _insert_aggregation_ops_for_var(target_program, _var) @switch_to_static_graph - def _append_backward_desc(self, main_program): - program = main_program - - targets = list( - filter(lambda x: isinstance(x, OpResult), self._outputs.tolist()) - ) + def _append_backward_desc(self, train_runnable_program: RunableProgram): + program = train_runnable_program.program + targets = train_runnable_program.out_values + # TODO(@zhuoge): refine the interface, use runable_program to apply passes. if self._hooker: program, targets = self._hooker.before_append_backward( program, targets ) - self._outputs = NestSequence(targets, need_check=True) - inputs = list( - filter(lambda x: isinstance(x, OpResult), self._inputs.tolist()) - ) - combined_inputs = list(itertools.chain(inputs, self._param_values)) + inputs = train_runnable_program.x_values + params = train_runnable_program.param_values + combined_inputs = list(itertools.chain(inputs, params)) forward_end_idx = len(program.global_block().ops) if targets: with backend_guard(self._backend): @@ -693,7 +604,6 @@ def _append_backward_desc(self, main_program): ) = self._hooker.after_append_backward( program, targets, forward_end_idx ) - self._outputs = NestSequence(targets, need_check=True) # TODO: add later # self.prepare_gradient_aggregation( @@ -703,24 +613,23 @@ def _append_backward_desc(self, main_program): mapping_op_result = ( lambda x: x if isinstance(x, OpResult) else fake_op_result() ) - hash_id = paddle.utils._hash_with_id(program, self) - extra_info = self._program_extra_info.get(hash_id, {}) - extra_info['forward_inputs'] = inputs - extra_info['forward_outputs'] = targets - extra_info['forward_end_op_idx'] = forward_end_idx inputs_size = len(inputs) - extra_info['forward_inputs_grads'] = list( + x_grad_value = list( map(mapping_op_result, grad_info_map[0:inputs_size]) ) - extra_info['forward_params_grads'] = list( - map(mapping_op_result, grad_info_map[inputs_size:]) + p_grad_value = list(map(mapping_op_result, grad_info_map[inputs_size:])) + o_grad_value = list(map(mapping_op_result, forward_outputs_grads)) + backward_start_op_index = forward_end_idx + 2 * len( + list(filter(lambda r: r.stop_gradient is False, self._outputs)) ) - extra_info['forward_outputs_grads'] = list( - map(mapping_op_result, forward_outputs_grads) + backward_end_op_index = len(program.global_block().ops) + return RunableProgram( + program, + (inputs, params, targets), + (x_grad_value, p_grad_value, o_grad_value), + (0, forward_end_idx), + (backward_start_op_index, backward_end_op_index), ) - self._program_extra_info[hash_id] = extra_info - - return program def _prune_unused_params(self, program): """ @@ -741,33 +650,18 @@ def _prune_unused_params(self, program): self._params = required_params self._param_values = required_param_values - def _cast_fp16_if_pure_fp16(self, in_vars): - if _in_pure_fp16_guard(): - for i, var in enumerate(in_vars): - name = var.name - if ( - self.program.global_block().has_var(name) - and self.program.global_block().var(name).dtype - == paddle.float16 - ): - in_vars[i] = var.astype('float16') - in_vars[i].name = name - def _prepare_attributes(self): attrs = [ 'forward_global_block', - self.forward_program.global_block(), + self.program.forward_program.global_block(), 'backward_global_block', - self.backward_program.global_block(), + self.program.backward_program.global_block(), 'is_test', not self.training, 'program_id', self.program_id, ] - - for key, val in self.get_program_extra(self.forward_program)[ - 'program_attr' - ].items(): + for key, val in self.program.program_attr.items(): attrs.append(key) attrs.append(val) @@ -782,141 +676,6 @@ def _prepare_attributes(self): ) return attrs - @switch_to_static_graph - def _build_infer_program(self, infer_program, forward_end_op_index): - forward_skip_vars = self._parse_skip_gc_vars(infer_program) - builded_infer_program = add_build_strategy_for( - infer_program, - 0, - forward_end_op_index, - self._build_strategy, - forward_skip_vars, - ) - self._apply_inplace_pass(builded_infer_program, None) - return builded_infer_program - - @switch_to_static_graph - def _get_forward_backward_program_form( - self, whole_program, forward_end_op_index - ): - # NOTE(dev): We apply build_strategy for backward firstly to - # avoid skipping more gc variables. - forward_inputs_grads = self.get_program_extra(whole_program)[ - 'forward_inputs_grads' - ] - forward_inputs = self.get_program_extra(whole_program)['forward_inputs'] - forward_outputs = self.get_program_extra(whole_program)[ - 'forward_outputs' - ] - forward_parameters = self._param_values - forward_outputs_grads = self.get_program_extra(whole_program)[ - 'forward_outputs_grads' - ] - forward_params_grads = self.get_program_extra(whole_program)[ - 'forward_params_grads' - ] - backward_start_op_index = forward_end_op_index + 2 * len( - list(filter(lambda r: r.stop_gradient is False, self._outputs)) - ) - backward_end_op_index = len(whole_program.global_block().ops) - # For Backward process in CINN, all param@GRAD shoule be skipped for GC, because - # they will be shared in scope and used by optimizer. - - # TODO(xiongkun): consider cinn later. - # backward_skip_vars = self._parse_skip_gc_vars( - # whole_program - # ) + self._grad_var_names.get('param', []) - ( - forward_program, - backward_program, - ), program_attr = paddle.base.libpaddle.pir.program_split( - whole_program, - forward_inputs, - forward_parameters, - forward_outputs, - forward_inputs_grads, - forward_params_grads, - forward_outputs_grads, - [0, forward_end_op_index], - [backward_start_op_index, backward_end_op_index], - ) - self.get_program_extra(forward_program)["program_attr"] = program_attr - return [forward_program, backward_program] - - def _apply_inplace_pass(self, forward_program, backward_program): - attr_types = { - "use_cuda": "bool", - "mem_opt_skip_vars": "list[str]", - "for_partial_block": "bool", - } - empty_startup_program = paddle.static.Program() - use_cuda = True if core.is_compiled_with_cuda() else False - # skip data var - forward_mem_opt_skip_vars = self._parse_skip_gc_vars( - forward_program, backward_program - ) - backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program) - if forward_program: - attrs = { - "use_cuda": use_cuda, - "mem_opt_skip_vars": forward_mem_opt_skip_vars, - "for_partial_block": True, - } - if not os.getenv("FLAGS_enable_new_ir_in_executor"): - _apply_pass( - forward_program, - empty_startup_program, - "buffer_shared_inplace_pass", - attrs, - attr_types, - ) - if backward_program: - attrs = { - "use_cuda": use_cuda, - "mem_opt_skip_vars": backward_mem_opt_skip_vars, - "for_partial_block": True, - } - if not os.getenv("FLAGS_enable_new_ir_in_executor"): - _apply_pass( - backward_program, - empty_startup_program, - "buffer_shared_inplace_pass", - attrs, - attr_types, - ) - - @LazyInitialized - def _inout_var_names(self): - """ - Returns Variable Names from self._inputs and self.outputs - """ - var_names = [] - for var in self._inputs: - if isinstance(var, paddle.base.framework.Variable): - var_names.append(var.desc.name()) - for var in self._outputs: - if isinstance(var, paddle.base.framework.Variable): - var_names.append(var.desc.name()) - return var_names - - def _parse_skip_gc_vars(self, program, backward_program=None): - """ - Parse variables that need to skip GC after execute it. - If specify backward_program, it will keep the variables used in backward. - """ - # skip data var, DO NOT ignore this deepcopy - skip_vars = deepcopy(self._inout_var_names) - for var_name, var in program.global_block().vars.items(): - if var.is_data: - skip_vars.append(var_name) - - if backward_program: - for var_name in core.parse_safe_eager_deletion_skip_vars( - backward_program.desc, True - ): - skip_vars.append(var_name) - return skip_vars - def _prepare(self, inputs): """ Prepare inputs, outputs, attrs. @@ -981,13 +740,10 @@ def create_out(var_id): return input_vars, out_vars def _create_scope_vec(self, program_id=None, use_scope_cache=False): - # Hold forward variables - tmp_scope_vec = None inner_scope = self._get_scope( program_id=program_id, use_scope_cache=use_scope_cache ) - tmp_scope_vec = [inner_scope] - return tmp_scope_vec + return [inner_scope] def _create_cuda_graph_vec(self): var = core.eager.Tensor( @@ -1063,16 +819,15 @@ def _remove_no_value(self, out_vars): return out_vars - def _set_grad_type(self, params, train_program): + def _set_grad_type(self, params, train_program: RunableProgram): # NOTE: if user set sparse gradient mode, the param's gradient # will be SelectedRows, not LoDTensor. But tracer will just # set param grad Tensor by forward Tensor(LoDTensor) # If we don't change grad_var type here, RunProgramOp need # transform SelectedRows to LoDTensor forcibly, it may not # be user wanted result. - forward_params_grads = self.get_program_extra(train_program)[ - 'forward_params_grads' - ] + forward_params_grads = train_program.p_grad_values + train_program = train_program.program for param, value in zip(params, forward_params_grads): if is_fake_op_result(value): continue @@ -1089,19 +844,6 @@ def _set_grad_type(self, params, train_program): "only support selected_row and dense_tensor grad type." ) - def _remove_op_call_stack(self, main_program): - """ - Remove op's python call stack with redundant low-level error messages related to - transforamtions to avoid confusing users. - """ - assert isinstance(main_program, framework.Program) - for block in main_program.blocks: - for op in block.ops: - if op.has_attr("op_callstack"): - op._remove_attr("op_callstack") - - return main_program - def _check_params_all_inited(self, main_program): """ Check all params from main program are already initialized, see details as follows: @@ -1144,10 +886,3 @@ def partial_program_from(concrete_program, from_method=False): concrete_program.parameters, **concrete_program.kwargs, ) - - -@switch_to_static_graph -def add_build_strategy_for( - program, start_op_index, end_op_index, build_strategy=None, skip_vars=None -): - raise NotImplementedError("Not implemented yet.") diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 65da105499b20..29fb188598dd0 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -15,7 +15,6 @@ import collections import inspect import os -import textwrap import threading import warnings import weakref @@ -43,20 +42,19 @@ get_buffers, get_parameters, ) -from .newir_partial_program import ( - PartialProgramLayerHook as PirPartialProgramLayerHook, -) from .origin_info import ( attach_origin_info, create_and_update_origin_info_map, update_op_callstack_with_origin_info, ) from .partial_program import PartialProgramLayerHook +from .pir_partial_program import ( + PartialProgramLayerHook as PirPartialProgramLayerHook, +) from .utils import ( ALREADY_D2S, NO_SHAPE_VAR_TYPE, ast_to_func, - ast_to_source_code, backend_guard, func_to_source_code, input_specs_compatible, @@ -200,7 +198,7 @@ class CacheKey: 'class_instance', 'kwargs', '_spec_names_id', - '_new_ir_flags', + '_pir_flags', ] def __init__( @@ -230,9 +228,7 @@ def __init__( self._spec_names_id = _hash_spec_names( input_args_with_spec, input_kwargs_with_spec ) - self._new_ir_flags = os.environ.get( - 'FLAGS_enable_new_ir_in_executor', None - ) + self._pir_flags = os.environ.get('FLAGS_enable_pir_in_executor', None) @classmethod def from_func_and_args(cls, function_spec, args, kwargs, class_instance): @@ -276,7 +272,7 @@ def __hash__(self): self.class_instance, with_hook, is_train, - self._new_ir_flags, + self._pir_flags, ) ) @@ -475,11 +471,9 @@ def __call__(self, *args, **kwargs): if not in_dynamic_mode(): raise RuntimeError( - "Failed to run the callable object {} decorated by '@paddle.jit.to_static', " + f"Failed to run the callable object {self.dygraph_function} decorated by '@paddle.jit.to_static', " "because it is NOT in dynamic mode. Please disable the static graph mode to enter dynamic mode with the " - "following API: paddle.disable_static().".format( - self.dygraph_function - ) + "following API: paddle.disable_static()." ) return self._perform_call(*args, **kwargs) @@ -680,8 +674,8 @@ def function_spec(self): def raise_error_template(func_str): def _raise_error(*args, **kwargs): error_template = ( - "Can't call {func} when enable_fallback=True." - "Use paddle.jit.to_static(enable_fallback=False) instead." + "Can't call {func} when full_graph=False. " + "Use paddle.jit.to_static(full_graph=True) instead." ) raise RuntimeError(error_template.format(func=func_str)) @@ -692,8 +686,8 @@ class SymbolicStaticFunction(StaticFunction): def __init__(self, function, input_spec=None, **kwargs): if input_spec is not None: warnings.warn( - "\nSymbolic Trace don't support input_spec arguments. It will not produce any effect.\n" - "1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n" + "full_graph=False don't support input_spec arguments. It will not produce any effect.\n" + "You can set full_graph=True, then you can assign input spec.\n" ) super().__init__(function, input_spec, **kwargs) self.last_call_input_spec = None @@ -800,7 +794,7 @@ def _perform_call(self, *args, **kwargs): else: logging_utils.warn( "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'" - " if you can't handle this {} yourself.".format(type(e)) + f" if you can't handle this {type(e)} yourself." ) raise e @@ -1155,7 +1149,7 @@ def __init__( @staticmethod @switch_to_static_graph - def newir_from_func_spec( + def pir_from_func_spec( func_spec, input_spec, input_kwargs_spec, class_instance, **kwargs ): """ @@ -1191,10 +1185,10 @@ def newir_from_func_spec( with ir_static.program_guard(main_program, startup_program): with _to_static_mode_guard_(is_to_static=True): # 1. Adds `paddle.static.data` layers for input if needed - static_inputs = func_spec.newir_to_static_inputs_with_spec( + static_inputs = func_spec.pir_to_static_inputs_with_spec( input_spec, main_program ) - _kwargs = func_spec.newir_to_static_inputs_with_spec( + _kwargs = func_spec.pir_to_static_inputs_with_spec( input_kwargs_spec, main_program ) if class_instance: @@ -1223,7 +1217,7 @@ def newir_from_func_spec( raise # 3. Gets all ParamBases and buffered VarBases in the function - from ..newir_dy2static.parameter_recorder import ( + from ..pir_dy2static.parameter_recorder import ( _global_parameter_recorder, ) @@ -1536,7 +1530,7 @@ def _build_once(self, cache_key): enable_fallback = enable_prim try: if use_pir_api(): - concrete_program = ConcreteProgram.newir_from_func_spec( + concrete_program = ConcreteProgram.pir_from_func_spec( func_spec=cache_key.function_spec, input_spec=cache_key.input_args_with_spec, input_kwargs_spec=cache_key.input_kwargs_with_spec, @@ -1584,12 +1578,12 @@ def _build_once(self, cache_key): ) if use_pir_api(): - from .newir_partial_program import partial_program_from + from .pir_partial_program import partial_program_from partial_program = partial_program_from( concrete_program, cache_key.class_instance is not None ) - else: # TODO(new_ir): remove later. + else: # TODO(pir): remove later. from .partial_program import partial_program_from partial_program = partial_program_from( @@ -1762,37 +1756,6 @@ def __init__(self): self.enable_to_static = True def enable(self, enable_to_static): - """ - Enable or disable the converting from imperative to static graph by - ProgramTranslator globally. - - Args: - enable_to_static (bool): True or False to enable or disable converting to static. - - Returns: - None. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - - >>> x = paddle.ones([1, 2]) - >>> x_v = prog_trans.get_output(func, x) - >>> print(x_v) - Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0., 0.]]) - """ check_type( enable_to_static, "enable_to_static", @@ -1801,274 +1764,6 @@ def enable(self, enable_to_static): ) self.enable_to_static = enable_to_static - def get_output(self, dygraph_func, *args, **kwargs): - """ - Returns the output dygraph Tensor for dygraph function. The dygraph - function will be translated into static graph function so the under - beneath numerical result will be calculated by static graph mode. - - Args: - dygraph_func (callable): the dygraph function. - *args (tuple): the input argument of dygraph_func. - **kwargs (dict): the input argument of dygraph_func. - - Returns: - Tensor or tuple of Tensors: the dygraph Tensor containing digital result. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - - >>> x = paddle.ones([1, 2]) - >>> x_v = prog_trans.get_output(func, x) - >>> print(x_v) - Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0., 0.]]) - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_output" - - if not self.enable_to_static: - # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message) - # will show up **only once**. - logging_utils.warn( - "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. " - "We will just return dygraph output. " - "Please call ProgramTranslator.enable(True) if you would like to get static output." - ) - return dygraph_func(*args, **kwargs) - try: - function_spec = FunctionSpec(dygraph_func) - cache_key = CacheKey.from_func_and_args( - function_spec, - args, - kwargs, - getattr(dygraph_func, '__self__', None), - ) - _, partial_program_layer = self._program_cache[cache_key] - - if args and isinstance(args[0], layers.Layer): - # Synchronize self.training attribute. - partial_program_layer.training = args[0].training - args = args[1:] - try: - return partial_program_layer(args) - except BaseException as e: - # NOTE: - # 1. If e is raised in compile time, e should have been attached to ERROR_DATA before; - # 2. If e raised in runtime, e should be attached to ERROR_DATA here. - if not hasattr(e, error.ERROR_DATA): - # runtime error - error.attach_error_data(e, in_runtime=True) - raise - except BaseException as e: - error_data = getattr(e, error.ERROR_DATA, None) - if error_data: - error_data.raise_new_exception() - else: - logging_utils.warn( - "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'" - " if you can't handle this {} yourself.".format(type(e)) - ) - raise e - - def get_func(self, dygraph_func): - """ - Returns a callable function which converts imperative dygraph APIs of - the input dygraph_func into declarative net-building APIs, which means - it doesn't return immediate digital result as get_output does. - Users should handle Program and Executor by themselves. - - Args: - dygraph_func (callable): the dygraph function. - - Returns: - callable: converting imperative dygraph APIs into declarative - net-building APIs. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - >>> static_func = prog_trans.get_func(func) - >>> print(callable(static_func)) - True - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_func" - - if not self.enable_to_static: - logging_utils.warn( - "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will " - "just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output." - ) - return dygraph_func - - static_func = convert_to_static(dygraph_func) - return static_func - - def get_program(self, dygraph_func, *args, **kwargs): - """ - Returns the translated static program and input/output Tensors from - dygraph function. The users can use the program to run by executor. - - Args: - dygraph_func (callable): the dygraph function. - *args (tuple): the input argument of dygraph_func. - **kwargs (dict): the input argument of dygraph_func. - - Returns: - tuple of (main_program, startup_program, inputs, outputs) whose - types are (Program, Program, list of Tensors, list of Tensors). - main_program: the converted main program. - startup_program: the converted startup program. - inputs: list of input Tensors which need to be fed. - outputs: list of output Tensors which users can fetch. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - >>> x = paddle.ones([1, 2]) - >>> main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x) - >>> print([i.name for i in inputs]) - >>> # [u'generated_tensor_0'] the feed input Tensor name representing x - >>> print([o.name for o in outputs]) - >>> # [u'_generated_var_4'] the fetch output Tensor name representing x_v - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_program" - - if not self.enable_to_static: - logging_utils.warn( - "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False." - "We will just return dygraph output. " - "Please call ProgramTranslator.enable(True) if you would like to get static output." - ) - return dygraph_func(*args, **kwargs) - - function_spec = FunctionSpec(dygraph_func) - cache_key = CacheKey.from_func_and_args( - function_spec, args, kwargs, getattr(dygraph_func, '__self__', None) - ) - concrete_program, partial_program_layer = self._program_cache[cache_key] - - # Note: concrete_program hold all input/output infos include non-Variable - input_vars = [ - var - for var in concrete_program.inputs - if isinstance(var, framework.Variable) - ] - output_vars = [ - var - for var in concrete_program.outputs - if isinstance(var, framework.Variable) - ] - - return ( - concrete_program.main_program, - concrete_program.startup_program, - input_vars, - output_vars, - ) - - def get_code(self, dygraph_func): - """ - Returns the translated static function string code from dygraph function. - - Args: - dygraph_func (callable): the dygraph function. - - Returns: - str: the string code of translated static function. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - - >>> code = prog_trans.get_code(func) - >>> print(type(code)) - - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_code" - # Gets AST from dygraph function - - unwrap_func = unwrap(dygraph_func) - raw_code = inspect.getsource(unwrap_func) - code = textwrap.dedent(raw_code) - root = gast.parse(code) - - # Transform AST - dygraph_to_static = DygraphToStaticAst() - root = dygraph_to_static.get_static_ast(root) - - # Get source_code - source_code = ast_to_source_code(root) - return source_code - - def get_program_cache(self): - """ - Returns the ProgramCache instance. This method is used by PaddlePaddle - developers to manage program cache in ProgramTranslator. Normal users - don't have to call this method. - - Returns: - ProgramCache: ProgramCache instance of ProgramTranslator. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - >>> prog_cache = prog_trans.get_program_cache() - """ - return self._program_cache - def enable_to_static(enable_to_static_bool): """ diff --git a/python/paddle/jit/dy2static/utils_helper.py b/python/paddle/jit/dy2static/utils_helper.py index 601e3241d7464..3e10eafb4a480 100644 --- a/python/paddle/jit/dy2static/utils_helper.py +++ b/python/paddle/jit/dy2static/utils_helper.py @@ -18,7 +18,7 @@ import astor import numpy as np # noqa: F401 -import paddle # noqa: F401 +import paddle from paddle import base # noqa: F401 from paddle.base import dygraph, layers # noqa: F401 from paddle.base.dygraph import to_variable # noqa: F401 @@ -183,3 +183,14 @@ def type_from_annotation(annotation): # raise warning if not found warn("Currently we don't support annotation: %s" % annotation_str) return NodeVarType.UNKNOWN + + +def set_dynamic_shape(variable, shape_list): + if paddle.base.dygraph.base.in_to_static_mode(): + assert isinstance( + variable, paddle.base.framework.Variable + ), "In to_static mode, variable must be a Variable." + variable.desc.set_shape(shape_list) + else: + # in dygraph mode, dynamic shape is not needed, just do nothing. + return diff --git a/python/paddle/jit/newir_dy2static/__init__.py b/python/paddle/jit/pir_dy2static/__init__.py similarity index 100% rename from python/paddle/jit/newir_dy2static/__init__.py rename to python/paddle/jit/pir_dy2static/__init__.py diff --git a/python/paddle/jit/newir_dy2static/parameter_recorder.py b/python/paddle/jit/pir_dy2static/parameter_recorder.py similarity index 60% rename from python/paddle/jit/newir_dy2static/parameter_recorder.py rename to python/paddle/jit/pir_dy2static/parameter_recorder.py index 2bebff160c20e..91e24b1b22997 100644 --- a/python/paddle/jit/newir_dy2static/parameter_recorder.py +++ b/python/paddle/jit/pir_dy2static/parameter_recorder.py @@ -13,7 +13,6 @@ # limitations under the License. import paddle -from paddle.base import framework from ..dy2static.program_translator import _program_hash, synchronized @@ -43,7 +42,7 @@ def get(self, program, tensor): type=tensor.type, initializer=non_used_initializer, ) - if isinstance(tensor, framework.EagerParamBase): + if isinstance(tensor, paddle.Tensor): params.add(tensor) mappings[id(tensor)] = op_result return mappings[id(tensor)] @@ -61,4 +60,49 @@ def pop(self, program): return list(params), list(params_values) +class InplaceMap: + def __init__(self): + self.params_dict = {} + + @synchronized + def add(self, program, id, param): + """use the default_program as key, append param the parameter list.""" + key = _program_hash(program) + if key not in self.params_dict: + self.params_dict[key] = {} + + params = self.params_dict[key] + params[id] = param + + def get(self, program, id): + params = self.params_dict.get(_program_hash(program)) + if params is None: + return None + if id not in params: + return None + root_var = params[id] + saved = [] + while id(root_var) in params.keys(): + saved.append(root_var) + root_var = params[id(root_var)] + for var in saved: + params[id(var)] = root_var + return root_var + + def restore_checkpoint(self, checkpoint): + # InplaceMap is a nested effect. + # when enter a block, we should save a checkpoint + # when exit a block, we should restore a checkpoint + # for example: + # if cond > 0: + # x [:] = 0 + # return x + # x[:] only effect current cond block, we should restore in false block. + self.params_dict = checkpoint + + def save_checkpoint(self): + return dict(self.params_dict.items()) + + _global_parameter_recorder = ParametersRecorder() +_global_inplace_map = InplaceMap() diff --git a/python/paddle/jit/sot/__init__.py b/python/paddle/jit/sot/__init__.py index 1b45c0c55389b..c297cc0840ece 100644 --- a/python/paddle/jit/sot/__init__.py +++ b/python/paddle/jit/sot/__init__.py @@ -18,5 +18,4 @@ add_breakpoint, add_event, ) -from .opcode_translator.skip_files import skip_function # noqa: F401 from .translate import symbolic_translate # noqa: F401 diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py index 8ea3ec28f19a4..a88338bdf2e74 100644 --- a/python/paddle/jit/sot/infer_meta.py +++ b/python/paddle/jit/sot/infer_meta.py @@ -182,7 +182,7 @@ def infer_meta_for_layer(layer, *args, **kwargs): assert isinstance( layer, paddle.nn.Layer ), f"Expect a Layer, but got {layer}." - layer = paddle.jit.to_static(layer, enable_fallback=False) + layer = paddle.jit.to_static(layer, full_graph=True) args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) diff --git a/python/paddle/jit/sot/opcode_translator/__init__.py b/python/paddle/jit/sot/opcode_translator/__init__.py index bf230190e3e11..392faa56e7126 100644 --- a/python/paddle/jit/sot/opcode_translator/__init__.py +++ b/python/paddle/jit/sot/opcode_translator/__init__.py @@ -13,3 +13,6 @@ # limitations under the License. from .transform import eval_frame_callback # noqa: F401 +from .skip_files import setup_skip_files + +setup_skip_files() diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py index 67d656f4dcd75..c99ce0c552a8a 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py +++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py @@ -14,6 +14,7 @@ from __future__ import annotations +import gc import traceback import types from typing import List, Tuple @@ -228,3 +229,5 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e finally: simulator.cleanup() + del simulator + gc.collect() diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 61f72b267b2de..185ea35867a1d 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -26,9 +26,10 @@ from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo from ...profiler import EventGuard, event_register -from ...symbolic.statement_ir import Symbol +from ...symbolic.statement_ir import Reference, Symbol from ...symbolic.symbolic_context import SymbolicTraceContext from ...utils import ( + ENV_SHOW_TRACKERS, NameGenerator, OrderedSet, inner_error_default_handler, @@ -37,9 +38,9 @@ log, log_do, map_if, - show_trackers, tmp_name_guard, ) +from ..instruction_utils import get_instructions from .guard import Guard, StringifyExpression, make_guard from .mutable_data import MutationDel, MutationNew, MutationSet from .pycode_generator import PyCodeGen @@ -241,16 +242,88 @@ def guard_fn(self) -> Guard: return make_guard(guards) - def start_compile_with_name_store(self, ret_vars, to_store_vars): + def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx): + class VariableLoader: + def __init__(self, store_var_info, pycode_gen): + self._store_var_info = store_var_info + self._pycode_gen: PyCodeGen = pycode_gen + + def load(self, var, allow_push_null=True): + if isinstance(var, NullVariable): + # PUSH_NULL is an opcode + if allow_push_null: + var.reconstruct(self._pycode_gen) + else: + # Avoid passing NULL as a parameter to the resume function + self._pycode_gen.gen_load_null_variable() + return + # only restored vars in stack, so used gen_load to process global var + self._pycode_gen.gen_load(self._store_var_info[var]) + + origin_instr = get_instructions(self.pycode_gen._origin_code) + + for instr in origin_instr[0:instr_idx]: + if ( + instr.opname == 'LOAD_FAST' + and instr.argval in self.pycode_gen._frame.f_locals.keys() + and isinstance( + self.pycode_gen._frame.f_locals[instr.argval], NullVariable + ) + ): + self.pycode_gen._frame.f_locals[instr.argval].reconstruct( + self.pycode_gen + ) + elif ( + instr.opname == 'LOAD_GLOBAL' + and instr.argval in self.pycode_gen._frame.f_globals.keys() + and isinstance( + self.pycode_gen._frame.f_globals[instr.argval], NullVariable + ) + ): + self.pycode_gen._frame.f_globals[instr.argval].reconstruct( + self.pycode_gen + ) + else: + self.pycode_gen.extend_instrs([instr]) + + nop = self.pycode_gen._add_instr("NOP") + + for instr in origin_instr: + if instr.jump_to == origin_instr[instr_idx]: + instr.jump_to = nop + + self.pycode_gen.hooks.append( + lambda: self.pycode_gen.extend_instrs( + iter(origin_instr[instr_idx + 1 :]) + ) + ) + + self.pycode_gen.gen_enable_eval_frame() + + name_gen = NameGenerator("__start_compile_saved_orig_") + + for var in stack_vars[::-1]: + store_var_info[var] = name_gen.next() + self.pycode_gen.gen_store_fast(store_var_info[var]) + + return VariableLoader(store_var_info, self.pycode_gen) + + def _build_compile_fn_with_name_store(self, ret_vars, to_store_vars): class VariableLoader: def __init__(self, index_for_load, pycode_gen): self._index_for_load = index_for_load - self._pycode_gen = pycode_gen + self._pycode_gen: PyCodeGen = pycode_gen - def load(self, var): + def load(self, var, allow_push_null=True): if isinstance(var, NullVariable): - var.reconstruct(self._pycode_gen) + # PUSH_NULL is an opcode + if allow_push_null: + var.reconstruct(self._pycode_gen) + else: + # Avoid passing NULL as a parameter to the resume function + self._pycode_gen.gen_load_null_variable() return + # all vars to be load are saved by this function, so load_fast is correct self._pycode_gen.gen_load_fast(self._index_for_load[var.id]) # var_id -> local_name mapping @@ -260,7 +333,8 @@ def load(self, var): ) self.start_compile(*(ret_vars + to_store_vars)) name_gen = NameGenerator("__start_compile_saved_") - for var in to_store_vars: + + for var in to_store_vars[::-1]: index_for_load[var.id] = name_gen.next() def _log_fn(): @@ -271,8 +345,8 @@ def _log_fn(): log_do(4, _log_fn) - for var in to_store_vars[::-1]: self.pycode_gen.gen_store_fast(index_for_load[var.id]) + return VariableLoader(index_for_load, self.pycode_gen) @event_register("start_compile", event_level=2) @@ -337,7 +411,7 @@ def start_compile(self, *ret_vars: VariableBase): self.restore_side_effects(self.side_effects.proxy_variables) self.pycode_gen.gen_enable_eval_frame() - tracker_output_path = show_trackers() + tracker_output_path = ENV_SHOW_TRACKERS.get() if tracker_output_path: from .tracker_viewer import view_tracker @@ -422,6 +496,7 @@ def get_opcode_executor_stack(): def call_layer( self, layer: PaddleLayerVariable, + weak_ref: bool, *args: VariableBase, **kwargs: VariableBase, ): @@ -438,7 +513,7 @@ def infer_meta_fn(layer, *metas, **kwmetas): def compute_fn(layer, inputs, outputs, stacks): self.sir_ctx.call_LAYER( - layer.value, + Reference(layer.value, weak_ref), inputs=inputs, outputs=outputs, stacks=stacks, @@ -547,6 +622,18 @@ def _find_tensor_outputs( Args: outputs: output variables """ + + def collect_related_dummy_tensor(var): + if isinstance(var.tracker, DummyTracker): + if isinstance(var, TensorVariable): + return [var] + else: + retval = [] + for inp in var.tracker.inputs: + retval.extend(collect_related_dummy_tensor(inp)) + return retval + return [] + output_tensors: OrderedSet[TensorVariable] = OrderedSet() # Find Tensor Variables from outputs. for output in outputs: @@ -554,6 +641,9 @@ def _find_tensor_outputs( if isinstance(output, TensorVariable): output_tensors.add(output) else: + for inp in output.tracker.inputs: + for _var in collect_related_dummy_tensor(inp): + output_tensors.add(_var) # Guard output that can not be traced. self.add_global_guarded_variable(output) # Find Tensor Variables from side effects Variables. diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 240ca8f1b889e..052b89c1cc1e1 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -30,6 +30,7 @@ from ...profiler import EventGuard, event_register from ...psdb import NO_BREAKGRAPH_CODES from ...utils import ( + ENV_MIN_GRAPH_SIZE, BreakGraphError, FallbackError, InnerError, @@ -37,7 +38,6 @@ SotUndefinedVar, log, log_do, - min_graph_size, ) from ..custom_code import CustomCode from ..instruction_utils import ( @@ -227,7 +227,6 @@ def jump_break_graph_decorator(normal_jump: Callable): def inner(self: OpcodeExecutor, instr: Instruction): result = self.stack.top if isinstance(result, TensorVariable): - self.stack.pop() # fallback when in OpcodeExecutor # raise error in OpcodeInlineExecutor log(3, "[BreakGraph] jump break graph, because if tensor\n") @@ -327,7 +326,11 @@ class OpcodeExecutorBase: """ + class EmptyCode: + pass + call_stack: list[OpcodeExecutorBase] = [] + empty_code = EmptyCode() @staticmethod def validate_value(value): @@ -352,7 +355,7 @@ def __init__(self, code: types.CodeType, graph: FunctionGraph): self._current_line: int = -1 self._instructions = get_instructions(self._code) self._graph = graph - self.new_code: types.CodeType | None = None + self.new_code: types.CodeType | None = self.empty_code self.guard_fn = None self._name = "Executor" self._call_shape: tuple[ @@ -1460,6 +1463,7 @@ def __init__(self, frame: types.FrameType, **kwargs): def cleanup(self): self._graph.pycode_gen = None Dispatcher.graph = None + self.call_stack[:] = [] @event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2) def _prepare_virtual_env(self): @@ -1505,7 +1509,39 @@ def _prepare_virtual_env(self): ) ) - def _create_resume_fn(self, index, stack_size=0): + def gen_compute_in_break_with_name_store(self, restore_names, instr_idx): + """ + branch 1: if the graph size is too small, just run in dygraph + branch 2: if the graph is big enough, create compiled_fn + + This api will generator opcodes in different situation, the generated codes + will do the same thing as origin code. + + restore_names: + the names used in resume functions, branch 2 will restore these values, + branch 1 also need these names for generating opcode, but they are not + needed to be restored + instr_idx: + the index for branch 1 to find the boundary and copy origin opcode + """ + if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get(): + store_var_info = {} + for name in restore_names: + _var = self.get_var(name) + if _var not in self.stack: + store_var_info[_var] = name + return self._graph._restore_origin_opcode( + list(self.stack), store_var_info, instr_idx + ) + else: + store_vars = list(self.stack) + for name in restore_names: + _var = self.get_var(name) + if _var not in self.stack: + store_vars.append(_var) + return self._graph._build_compile_fn_with_name_store([], store_vars) + + def _create_resume_fn(self, index, stack_size): """ Create a resume function and its inputs at the specified index. @@ -1522,7 +1558,7 @@ def _create_resume_fn(self, index, stack_size=0): return fn, inputs @fallback_when_occur_error - def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): + def _break_graph_in_jump(self, result: TensorVariable, instr: Instruction): """ Break the graph at a JUMP instruction. @@ -1532,7 +1568,10 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): """ self._graph.add_global_guarded_variable(result) - stack_size = len(self.stack) + # minus the bool value + stack_size = len(self.stack) - 1 + + # gen call static fn opcode if_fn, if_inputs = self._create_resume_fn( self.indexof(instr) + 1, stack_size ) @@ -1540,29 +1579,15 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): self.indexof(instr.jump_to), stack_size ) - # gen call static fn opcode - inputs_name = if_inputs | else_inputs - inputs_var = [ - self.get_var(name) - for name in inputs_name - if self.get_var(name) is not result - ] - ret_vars = [ - result, - ] + inputs_var - # Collect all the to store variables. - store_vars = [] - for stack_arg in self.stack: - store_vars.append(stack_arg) - for name in inputs_name: - store_vars.append(self.get_var(name)) + inputs_names = if_inputs | else_inputs - var_loader = self._graph.start_compile_with_name_store( - ret_vars, store_vars + var_loader = self.gen_compute_in_break_with_name_store( + inputs_names, self.indexof(instr) ) - # only pop the input of if/else resume fn, and keep the bool tensor result on the stack - for _ in inputs_var: - self._graph.pycode_gen.gen_pop_top() + + var_loader.load(result) + # the result is used by if opcode, and should not be input of resume_fn + self.stack.pop() # gen call if/else resume fn opcode if if_fn is not None: @@ -1570,8 +1595,10 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): if_fn, if_fn.__code__.co_name ) insert_index = len(self._graph.pycode_gen._instructions) - 1 - for stack_arg in self.stack: - var_loader.load(stack_arg) + for i, stack_arg in enumerate(self.stack): + var_loader.load( + stack_arg, allow_push_null=i >= len(self.stack) - 1 + ) for name in if_inputs: var_loader.load(self.get_var(name)) self._graph.pycode_gen.gen_call_function( @@ -1587,8 +1614,10 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): else_fn, else_fn.__code__.co_name ) jump_to = self._graph.pycode_gen._instructions[-1] - for stack_arg in self.stack: - var_loader.load(stack_arg) + for i, stack_arg in enumerate(self.stack): + var_loader.load( + stack_arg, allow_push_null=i >= len(self.stack) - 1 + ) for name in else_inputs: var_loader.load(self.get_var(name)) self._graph.pycode_gen.gen_call_function( @@ -1628,55 +1657,31 @@ def _break_graph_in_call( self.stack = origin_stack # gen call static fn opcode - ret_vars = [ - arg - for arg in self.stack - if isinstance(arg, (TensorVariable, ContainerVariable)) - ] + resume_input_name = analysis_inputs(self._instructions, index + 1) - ret_vars = ret_vars + [ - self.get_var(name) - for name in resume_input_name - if self.get_var(name) not in ret_vars - ] - # Collect all the to store variables. - store_vars = [] - for stack_arg in self.stack: - store_vars.append(stack_arg) - for name in resume_input_name: - store_vars.append(self.get_var(name)) - var_loader = self._graph.start_compile_with_name_store( - ret_vars, store_vars + var_loader = self.gen_compute_in_break_with_name_store( + resume_input_name, self.indexof(instr) ) - for _ in ret_vars: - self._graph.pycode_gen.gen_pop_top() - # gen graph break call fn opcode stack_effect = calc_stack_effect(instr) pop_n = push_n - stack_effect for i, stack_arg in enumerate(self.stack): - # Avoid passing NULL as a parameter to the resume function - if ( - isinstance(stack_arg, NullVariable) - and i < len(self.stack) - pop_n - ): - self._graph.pycode_gen.gen_load_object( - NullVariable(), f'null_var_{i}', push_null=False - ) - else: - var_loader.load(stack_arg) + var_loader.load( + stack_arg, allow_push_null=i >= len(self.stack) - pop_n + ) # gen call resume fn opcode # NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None. self._graph.pycode_gen.gen_kw_names(self._call_shape) - self._graph.pycode_gen.add_pure_instructions([instr]) + self._graph.pycode_gen.extend_instrs([instr]) self.stack.pop_n(pop_n) stack_size = len(self.stack) + push_n resume_fn, _ = self._create_resume_fn(index + 1, stack_size) + if resume_fn: self._graph.pycode_gen.gen_load_object( resume_fn, resume_fn.__code__.co_name @@ -1699,30 +1704,12 @@ def _break_graph_in_call( def transform(self): self.run() - if self.new_code is None: + if self.new_code is self.empty_code: raise InnerError("OpExecutor return a empty new_code.") - # stopped by RETURN_VALUE and has sir len is enough => disable_eval_frame - simulate_complete = bool(self.stop_state == "Return") - if simulate_complete: - if self._graph.sir_ctx.TOS.graph_size() < min_graph_size(): - raise FallbackError( - "Fallback after simulate for reasons.", - disable_eval_frame=True, - ) - else: - # if simulate stop with graph successfully, the all codes will be - # surrounded by the eval_frame triggers which exist in self.new_code - # we need not set disable_eval_frame=False here (for it already is) - return ( - CustomCode(self.new_code, True), - self.guard_fn, - ) - else: - # if return because breakgraph, need open eval_frame - return ( - CustomCode(self.new_code, False), - self.guard_fn, - ) + return ( + CustomCode(self.new_code, self.new_code is None), + self.guard_fn, + ) def _gen_loop_body_between( self, inputs: list, for_iter_idx: int, start: int, end: int @@ -1839,9 +1826,9 @@ def _break_graph_in_for_loop( log(3, "[Resumed Function]: break graph in loop create loop body as\n") log_do(3, lambda: dis.dis(loop_body_fn)) - # 0.3 create after loop part function + # 0.3 create after loop part function, minus 1 for iterator after_loop_fn, fn_inputs = self._create_resume_fn( - loop_body_end_idx, len(self.stack) + loop_body_end_idx, len(self.stack) - 1 ) total_inputs = OrderedSet(list(fn_inputs) + list(loop_body_inputs[:-1])) @@ -1852,23 +1839,17 @@ def _break_graph_in_for_loop( for name in total_inputs if name in chain(self._locals, self._cells) ] - ret_vars = [self.get_var(name) for name in ret_names] - store_vars = [ret_vars[idx] for idx in range(len(ret_names))] - store_vars.extend(iter(self.stack)) - store_vars.append(iterator.get_hold()) - var_loader = self._graph.start_compile_with_name_store( - ret_vars, store_vars - ) - for _ in ret_vars: - self._graph.pycode_gen.gen_pop_top() + var_loader = self.gen_compute_in_break_with_name_store( + ret_names, self.indexof(for_iter) + ) - # 2. restore vars - for idx in range(len(ret_names)): - var_loader.load(ret_vars[idx]) - self._graph.pycode_gen.gen_store(ret_names[idx], self._code) + # 2. restore vars with origin name + for name in ret_names: + var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_store(name, self._code) - # 3. setup vars which is created in loop + # 3. setup vars which is created in loop as Undefind undefined_names = set() for name in loop_body_inputs[:-1]: if not self.has_var(name, all_used_vars[name]): @@ -1876,12 +1857,9 @@ def _break_graph_in_for_loop( self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) self._graph.pycode_gen.gen_store(name, self._code) - # close eval_frame - # TODO: need support effective strategies - # self._graph.pycode_gen.gen_disable_eval_frame() - # 4.1 load iterator - iterator.reconstruct(self._graph.pycode_gen) + var_loader.load(iterator) + self.stack.pop() # 4.2 gen FOR_ITER and unpack data self._graph.pycode_gen.extend_instrs( @@ -1925,10 +1903,6 @@ def _break_graph_in_for_loop( for_iter.jump_to = nop jump_if_break.jump_to = nop - # open eval_frame - # TODO: need support effective strategies - # self._graph.pycode_gen.gen_enable_eval_frame() - # 8. call after_loop_fn self._graph.pycode_gen.gen_load_object( after_loop_fn, after_loop_fn.__code__.co_name @@ -2047,17 +2021,20 @@ def FOR_ITER(self, instr): try: if not isinstance(iterator, SequenceIterVariable): - raise BreakGraphError() + raise BreakGraphError( + f"Can not simulate iterator of {type(iterator)}." + ) backup_iter_idx = iterator.idx self._inline_call_for_loop(iterator, instr) self._lasti = self.indexof(instr.jump_to) except BreakGraphError as e: - log(3, f"{e}") + log(3, f"[FOR_ITER] sim for loop failed for: {e}\n") if backup_iter_idx: iterator.idx = backup_iter_idx self._graph.remove_global_guarded_variable(iterator) + self.stack.push(iterator) self._break_graph_in_for_loop(iterator, instr) return Stop(state="BreakGraph") @@ -2066,8 +2043,12 @@ def RETURN_VALUE(self, instr: Instruction): len(self.stack) == 1 ), f"Stack must have one element, but get {len(self.stack)} elements." ret_val = self.stack.pop() - self._graph.start_compile(ret_val) - self._graph.pycode_gen.gen_return() - self.new_code = self._graph.pycode_gen.gen_pycode() + if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get(): + py_codegen = PyCodeGen(self._frame) + self.new_code = py_codegen.replace_null_variable() + else: + self._graph.start_compile(ret_val) + self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() self.guard_fn = self._graph.guard_fn return Stop(state="Return") diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index d8ddb23d15fc1..29764afdca4eb 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -21,6 +21,7 @@ import random import sys import types +from functools import cached_property from typing import TYPE_CHECKING import opcode @@ -432,6 +433,7 @@ def __init__( self._f_globals = frame.f_globals self._instructions = [] self.disable_eval_frame = disable_eval_frame + self.hooks = [] if self.disable_eval_frame: self.gen_disable_eval_frame() @@ -492,16 +494,21 @@ def gen_pycode(self) -> types.CodeType: Returns: CodeType: The generated code object. """ + for hook in self.hooks: + hook() + self.hooks.clear() + self.insert_prefix_instructions() modify_instrs(self._instructions) modify_vars(self._instructions, self._code_options) new_code = gen_new_opcode( self._instructions, self._code_options, PYCODE_ATTRIBUTES ) + return new_code def gen_resume_fn_at( - self, index: int, stack_size: int = 0 + self, index: int, stack_size: int ) -> tuple[None | types.FunctionType, OrderedSet[str]]: """ Generates a resume function at the specified index in the instruction list. @@ -514,6 +521,7 @@ def gen_resume_fn_at( tuple: The resume function object and the inputs to the function. """ + self._instructions = get_instructions(self._origin_code) # TODO(dev): could give an example code here? if self._instructions[index].opname == 'RETURN_VALUE': @@ -521,6 +529,7 @@ def gen_resume_fn_at( inputs = analysis_inputs(self._instructions, index) fn_name = ResumeFnNameFactory().next() stack_arg_str = fn_name + '_stack_{}' + self._instructions = ( [ gen_instr('LOAD_FAST', argval=stack_arg_str.format(i)) @@ -537,13 +546,12 @@ def gen_resume_fn_at( + list(inputs) + [ var_name - for var_name in self._origin_code.co_varnames + for var_name in self._code_options['co_varnames'] if var_name not in inputs ] ) self.update_code_name(fn_name, is_resumed_fn=True) - new_code = self.gen_pycode() if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: raise FallbackError("Break graph in closure is not support.") @@ -551,6 +559,12 @@ def gen_resume_fn_at( return fn, inputs + @cached_property + def global_null_variable(self): + from .variables.basic import NullVariable + + return NullVariable() + def gen_disable_eval_frame(self): """ Generates instructions to disable the evaluation frame. @@ -744,6 +758,13 @@ def gen_load_object(self, obj, obj_name: str, push_null: bool = True): self._f_globals[obj_name] = obj self.gen_load_global(obj_name, push_null=push_null) + def gen_load_null_variable(self): + """ + Generate the bytecode for loading a null variable. + """ + null_var = self.global_null_variable + self.gen_load_object(null_var, "___null_var", push_null=False) + def gen_load_fast(self, name): """ Generate the bytecode for loading a local variable. @@ -1005,12 +1026,6 @@ def gen_return(self): def gen_get_iter(self): self._add_instr("GET_ITER") - def add_pure_instructions(self, instructions): - """ - add instructions and do nothing. - """ - self._instructions.extend(instructions) - def _add_instr(self, *args, **kwargs): instr = gen_instr(*args, **kwargs) self._instructions.append(instr) @@ -1048,8 +1063,17 @@ def replace_null_variable(self): ): has_null_variable = True self._frame.f_locals[instr.argval].reconstruct(self) + elif ( + instr.opname == 'LOAD_GLOBAL' + and instr.argval in self._frame.f_globals.keys() + and isinstance( + self._frame.f_globals[instr.argval], NullVariable + ) + ): + has_null_variable = True + self._frame.f_globals[instr.argval].reconstruct(self) else: - self.add_pure_instructions([instr]) + self.extend_instrs([instr]) if has_null_variable: new_code = self.gen_pycode() diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py index e7389de5b8805..e7dec76fbea78 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py @@ -206,6 +206,9 @@ def top(self, value): assert len(self) > 0, "stack is empty" self.peek[1] = value + def __contains__(self, value): + return value in self._data + def __iter__(self): return iter(self._data) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index ba0a7f51c91a0..35495c651a40e 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -335,6 +335,9 @@ def __eq__(self, var): else: return self.id == var.id + def __hash__(self): + return hash(self.id) + return SotTensor(self.id) raise BreakGraphError( diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py index 819580710beba..ecc2e3216f7e4 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py @@ -266,7 +266,7 @@ def __init__( def call_function(self, /, *args, **kwargs): if is_break_graph_tensor_methods(self.method_name): - raise BreakGraphError() + raise BreakGraphError("call break_graph_tensor_method.") return self.graph.call_tensor_method(self.method_name, *args, **kwargs) def bind(self, instance: VariableBase, name: str): @@ -522,7 +522,10 @@ def __init__( def call_function(self, /, *args, **kwargs): self.graph.add_global_guarded_variable(self) - return self.graph.call_layer(self, *args, **kwargs) + # when layer is created in forward function, we use strong ref because it can't have + # weigths and buffers, see PaddleLayerClassVariable for details. + weak_ref = not isinstance(self.tracker, CreateLayerTracker) + return self.graph.call_layer(self, weak_ref, *args, **kwargs) def make_stringify_guard(self) -> list[StringifyExpression]: if isinstance(self.tracker, CreateLayerTracker): @@ -740,10 +743,18 @@ class PaddleLayerClassVariable(ClassVariable): def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker): super().__init__(class_, graph, tracker) + def check_no_weight_and_buffers(self, paddle_layer): + has_parameters = len(paddle_layer.parameters()) > 0 + has_buffers = len(paddle_layer.buffers()) > 0 + return not has_parameters and not has_buffers + def call_function(self, /, *args, **kwargs): input_py_args = [var.get_py_value() for var in args] input_py_kwargs = {k: v.get_py_value() for k, v in kwargs.items()} new_layer = self.value(*input_py_args, **input_py_kwargs) + assert self.check_no_weight_and_buffers( + new_layer + ), "You have created a layer in to_static function which may have Potential bugs. please create it in __init__/main function." return PaddleLayerVariable( new_layer, self.graph, CreateLayerTracker(self, args, kwargs) ) diff --git a/python/paddle/jit/sot/opcode_translator/skip_files.py b/python/paddle/jit/sot/opcode_translator/skip_files.py index 7753309debce9..ca7f3552ad6ac 100644 --- a/python/paddle/jit/sot/opcode_translator/skip_files.py +++ b/python/paddle/jit/sot/opcode_translator/skip_files.py @@ -19,7 +19,6 @@ import copy import copyreg import dataclasses -import distutils import enum import functools import importlib @@ -55,8 +54,6 @@ import paddle -from ..utils import log - NEED_SKIP_THIRD_PARTIY_MODULES = { abc, collections, @@ -94,7 +91,6 @@ codecs, uuid, setuptools, - distutils, warnings, } @@ -105,6 +101,11 @@ NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_compile) NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_parse) +if sys.version_info < (3, 12): + import distutils + + NEED_SKIP_THIRD_PARTIY_MODULES.add(distutils) + def _strip_init_py(s): return re.sub(r"__init__.py$", "", s) @@ -131,47 +132,15 @@ def _module_dir(m: types.ModuleType): f"^({'|'.join(map(re.escape, skip_file_names))})" ) -customed_skip_code = set() - no_skip_code = {paddle.nn.Sequential.forward.__code__} +with_graph_codes = ( + paddle.nn.Layer.__call__.__code__, + paddle.nn.Layer._dygraph_call_func.__code__, +) + -def need_skip_path(filepath: str) -> bool: - """ - Check if the file should be skipped and not transcribed. - - Args: - filepath: The path of the file to check. - - Returns: - bool: True if the file should be skipped. - """ - if not filepath.startswith("<"): - filepath = os.path.abspath(filepath) - return bool(skip_file_name_re.match(filepath)) - - -def skip_function(function): - customed_skip_code.add(function.__code__) - return function - - -def need_skip(frame): - pycode = frame.f_code - if pycode in no_skip_code: - return False - if pycode in customed_skip_code: - log(3, f"Skip frame by code: {pycode}\n") - return True - filename = pycode.co_filename - if sys.version_info >= (3, 11) and filename.startswith(" CustomCode: with EventGuard( f"eval_frame_callback: {frame.f_code.co_name}", event_level=2 ): - # is generator - if frame.f_code.co_flags & 0x20 > 0: - return CustomCode(None, True) + log_format( + 2, "[eval_frame_callback] start to translate: {}\n", frame.f_code + ) + log_do(4, partial(print_locals, frame)) + + log_format(3, "[transform] OriginCode: {}\n", frame.f_code.co_name) + log_do(3, lambda: dis.dis(frame.f_code)) + + custom_code = OpcodeExecutorCache()(frame, **kwargs) - # NOTE(SigureMo): Temporary fallback when code has exception handling. - if sys.version_info >= (3, 11) and frame.f_code.co_exceptiontable: - log( + if custom_code.code is None: + log_format( 3, - f"[eval_frame_callback] {frame.f_code} has co_exceptiontable\n", + "[transform] NewCode (same as origin code): {}\n", + frame.f_code.co_name, ) - return CustomCode(None, False) - - if need_skip(frame): - log(3, f"[eval_frame_callback] skip {frame.f_code}\n") - custom_code = CustomCode(None, False) - new_code = frame.f_code else: - log( - 2, f"[eval_frame_callback] start to translate: {frame.f_code}\n" - ) - log_do(4, partial(print_locals, frame)) - - log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n") - log_do(3, lambda: dis.dis(frame.f_code)) - - custom_code = OpcodeExecutorCache()(frame, **kwargs) - - if custom_code.code is None: - log( - 3, - "[transform] NewCode (same as origin code): " - + frame.f_code.co_name - + "\n", - ) - new_code = frame.f_code - else: - log( - 3, - "[transform] NewCode: " + custom_code.code.co_name + "\n", - ) - log_do(3, lambda: dis.dis(custom_code.code)) - new_code = custom_code.code - - # just check those codes which need open eval_frame - if ( - custom_code.disable_eval_frame is False - and CodeStatus().is_code_without_graph(new_code) - ): - log( + log_format( 3, - "[eval_frame_callback] Code has no graph, block it.\n", + "[transform] NewCode: {}\n", + custom_code.code.co_name, ) - return CustomCode(None, True) + log_do(3, lambda: dis.dis(custom_code.code)) return custom_code diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py index 8fa7444ff0684..b189f9ce2278d 100644 --- a/python/paddle/jit/sot/symbolic/compile_cache.py +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -14,18 +14,23 @@ from __future__ import annotations +import inspect from typing import TYPE_CHECKING import paddle +from paddle.amp.auto_cast import amp_state +from paddle.base.data_feeder import convert_dtype +from paddle.framework import _dygraph_tracer from ..profiler import EventGuard from ..utils import ( Cache, - CodeStatus, GraphLogger, Singleton, StepInfoManager, + log, log_do, + map_if, ) from .interpreter import compile_sir @@ -33,6 +38,14 @@ from .symbolic_context import SymbolicTraceContext +def trace_back_frames(): + frame = inspect.currentframe() + while frame.f_back is not None: + frame = frame.f_back + code = frame.f_code + paddle.framework.core.sot_set_with_graph(code) + + def clear_eager_tensor_name(output_tensors): for output_tensor in output_tensors: output_tensor.name = "" @@ -49,15 +62,39 @@ def __init__(self, compiled_fn, SIR): self.concrete_program = None self.SIR = SIR # for debug + def amp_cast_inputs(self, args, kwargs): + """Prepare inputs for amp, cast float16 into float32 if needed.""" + current_amp_state = amp_state() + if current_amp_state is None: + return args, kwargs + # skip if not gpu / xpu / custom place + tracer = _dygraph_tracer() + if not ( + tracer._expected_place.is_gpu_place() + or tracer._expected_place.is_xpu_place() + or tracer._expected_place.is_custom_place() + ): + return args, kwargs + amp_dtype = convert_dtype(current_amp_state["dtype"]) + log(3, f"[AMP] Cast {amp_dtype} into float32\n") + return map_if( + (args, kwargs), + pred=lambda x: isinstance(x, paddle.Tensor) + and convert_dtype(x.dtype) == amp_dtype, + true_fn=lambda x: x.cast(paddle.float32), + false_fn=lambda x: x, + ) + def __call__(self, *args, **kwargs): with EventGuard(f"FallbackWrapper: {self.SIR.name}"): if StepInfoManager().need_back_trace: - CodeStatus().trace_back_frames() + trace_back_frames() log_do( 2, lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), ) + args, kwargs = self.amp_cast_inputs(args, kwargs) log_do( 4, lambda: print( @@ -137,7 +174,7 @@ def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): compile_sir(context, sir_name), build_strategy=build_strategy, backend=backend, - enable_fallback=False, + full_graph=True, ), context.get_sir(sir_name), ) diff --git a/python/paddle/jit/sot/symbolic/statement_ir.py b/python/paddle/jit/sot/symbolic/statement_ir.py index 11a08f36acd9d..1e0ab465e0bd8 100644 --- a/python/paddle/jit/sot/symbolic/statement_ir.py +++ b/python/paddle/jit/sot/symbolic/statement_ir.py @@ -22,12 +22,26 @@ import weakref from typing import Any, Callable -import paddle from paddle.utils import is_sequence, map_structure from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend +class Reference: # to unify weak_ref and strong_ref + def __init__(self, value, is_weak): + self.is_weak = is_weak + if is_weak is True: + self.ref = weakref.ref(value) + else: + self.ref = value + + def __call__(self): + if self.is_weak is True: + return self.ref() + else: + return self.ref + + class Symbol: """ Symbol is used to distinguish a string and a `math variable`. @@ -139,7 +153,7 @@ def __init__( class LayerStatement(Statement): def __init__( self, - layer: paddle.nn.Layer, + layer: Reference, # Reference of paddle.nn.Layer inputs: list[Symbol], outputs: list[Symbol], stacks: list[str], @@ -147,7 +161,7 @@ def __init__( super().__init__( "layer", layer.__class__.__name__, inputs, outputs, stacks ) - self.layer = weakref.ref(layer) + self.layer = layer class StatementIR: diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index a1f26ea622772..02fc91e62873b 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -12,7 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .code_status import CodeStatus # noqa: F401 +from .envs import ( # noqa: F401 + ENV_CLEAN_CODE, + ENV_COST_MODEL, + ENV_MIN_GRAPH_SIZE, + ENV_SHOW_TRACKERS, + ENV_SOT_LOG_LEVEL, + ENV_STRICT_MODE, + cost_model_guard, + strict_mode_guard, + min_graph_size_guard, +) from .exceptions import ( # noqa: F401 BreakGraphError, FallbackError, @@ -35,7 +45,6 @@ SotUndefinedVar, StepInfoManager, StepState, - cost_model, count_if, current_tmp_name_records, execute_time, @@ -52,11 +61,10 @@ list_find_index_by_id, log, log_do, + log_format, map_if, map_if_extend, meta_str, - min_graph_size, no_eval_frame, - show_trackers, tmp_name_guard, ) diff --git a/python/paddle/jit/sot/utils/code_status.py b/python/paddle/jit/sot/utils/code_status.py deleted file mode 100644 index 007e77f634004..0000000000000 --- a/python/paddle/jit/sot/utils/code_status.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from enum import Enum - -import paddle - -from .utils import Singleton, log - - -class CodeState(Enum): - UNKNOW = 1 - WITH_GRAPH = 2 - WITHOUT_GRAPH = 3 - - -class CodeInfo: - def __init__(self): - self.state = CodeState.UNKNOW - self.counter = 0 - - def __repr__(self): - return f"state: {self.state}, counter: {self.counter}" - - -@Singleton -class CodeStatus: - WITH_GRAPH_API = [ - paddle.nn.Layer.__call__.__code__, - paddle.nn.Layer._dygraph_call_func.__code__, - ] - - def __init__(self): - self.code_map = {} - self.setup_code_map() - - def setup_code_map(self): - for code in self.WITH_GRAPH_API: - info = CodeInfo() - info.state = CodeState.WITH_GRAPH - self.code_map[code] = info - - def clear(self): - self.code_map.clear() - self.setup_code_map() - - def is_code_without_graph(self, code): - if code not in self.code_map: - info = CodeInfo() - self.code_map[code] = info - else: - info = self.code_map[code] - - if info.state == CodeState.WITHOUT_GRAPH: - return True - if info.state == CodeState.UNKNOW: - info.counter += 1 - if info.counter >= 10: - log( - 3, - f"[CodeStatus] Switch state to WITHOUT_GRAPH for {code}\n", - ) - info.state = CodeState.WITHOUT_GRAPH - return False - - def trace_back_frames(self): - frame = inspect.currentframe() - while frame.f_back is not None: - frame = frame.f_back - code = frame.f_code - if code in self.code_map: - info = self.code_map[code] - if info.state != CodeState.WITH_GRAPH: - log( - 3, - f"[CodeStatus] Switch state to WITH_GRAPH for {code}\n", - ) - info.state = CodeState.WITH_GRAPH diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py new file mode 100644 index 0000000000000..a7d8ceafb7f0c --- /dev/null +++ b/python/paddle/jit/sot/utils/envs.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager + +from paddle.utils.environments import ( + BooleanEnvironmentVariable, + EnvironmentVariableGuard, + IntegerEnvironmentVariable, + StringEnvironmentVariable, +) + +ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False) +ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) +ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) +ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) +ENV_SHOW_TRACKERS = StringEnvironmentVariable("SHOW_TRACKERS", "") +ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False) + + +@contextmanager +def cost_model_guard(value: bool): + with EnvironmentVariableGuard(ENV_COST_MODEL, value): + yield + + +@contextmanager +def strict_mode_guard(value: bool): + with EnvironmentVariableGuard(ENV_STRICT_MODE, value): + yield + + +@contextmanager +def min_graph_size_guard(value: int): + with EnvironmentVariableGuard(ENV_MIN_GRAPH_SIZE, value): + yield diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index ad4ff3faaa4dc..b4980688f0834 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -16,7 +16,6 @@ import builtins import inspect -import os import time import types import weakref @@ -32,6 +31,12 @@ from paddle.framework import Program from paddle.utils import flatten, map_structure +from .envs import ( + ENV_CLEAN_CODE, + ENV_COST_MODEL, + ENV_SOT_LOG_LEVEL, + ENV_STRICT_MODE, +) from .paddle_api_config import ( break_graph_set, paddle_api_list, @@ -41,14 +46,6 @@ T = TypeVar("T") -def cost_model(): - return os.environ.get("COST_MODEL", "True") == "True" - - -def min_graph_size(): - return int(os.environ.get("MIN_GRAPH_SIZE", 10)) - - class Singleton(Generic[T]): def __init__(self, cls: type[T]): self._cls = cls @@ -119,17 +116,23 @@ def next(self): def log(level, *args): - cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + cur_level = ENV_SOT_LOG_LEVEL.get() if level <= cur_level: print(*args, end="") def log_do(level, fn): - cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + cur_level = ENV_SOT_LOG_LEVEL.get() if level <= cur_level: fn() +def log_format(level, str, *args): + cur_level = ENV_SOT_LOG_LEVEL.get() + if level <= cur_level: + print(str.format(*args), end="") + + def no_eval_frame(func): def no_eval_frame_func(*args, **kwargs): old_cb = paddle.framework.core.set_eval_frame(None) @@ -287,15 +290,11 @@ def meta_str(shape, dtype, stop_gradient): def is_strict_mode(): - return os.environ.get("STRICT_MODE", "0") == "1" - - -def show_trackers() -> str | None: - return os.environ.get("SHOW_TRACKERS", None) + return ENV_STRICT_MODE.get() def is_clean_code() -> bool: - return os.environ.get('CLEAN_CODE', "False") == "True" + return ENV_CLEAN_CODE.get() def list_find_index_by_id(li: list[Any], item: Any) -> int: @@ -623,7 +622,9 @@ class StepInfo: def __init__(self): self.step_count = -1 self.state = ( - StepState.COLLECT_INFO if cost_model() else StepState.RUN_SOT + StepState.COLLECT_INFO + if ENV_COST_MODEL.get() + else StepState.RUN_SOT ) self.dyn_time_costs = [] self.avg_dyn_time = 0 diff --git a/python/paddle/jit/translated_layer.py b/python/paddle/jit/translated_layer.py index 766e72e0553e8..a5070da21d734 100644 --- a/python/paddle/jit/translated_layer.py +++ b/python/paddle/jit/translated_layer.py @@ -336,7 +336,6 @@ def __init__(self, program_desc): # input, output, persistable, double_grads var info self._input_descs = [] self._output_descs = [] - self._double_grad_descs = [] self._persistable_names = [] self._grad_var_names = {} @@ -409,10 +408,6 @@ def output_descs(self): def persistable_names(self): return self._persistable_names - @property - def double_grad_descs(self): - return self._double_grad_descs - @property def scope(self): return self._inner_scope @@ -465,12 +460,6 @@ def _preprocess(self, program_desc): for op_idx in reversed(ops_to_remove): root_block._remove_op(op_idx, op_idx + 1) - for i in range(program_desc.num_blocks()): - block_desc = program_desc.block(i) - for var_desc in block_desc.all_vars(): - if "@GRAD" in var_desc.name(): - self._double_grad_descs.append(var_desc) - # 2. Input processing, reverse feed vars self._input_descs.reverse() @@ -512,6 +501,11 @@ def _preprocess(self, program_desc): @switch_to_static_graph def _append_scale_to_output(self, program): + # 0. scale don't support bool output, we skip append scale for it + for out_desc in self._output_descs: + if out_desc.dtype() == core.VarDesc.VarType.BOOL: + return + # 1. append scale & save var scale_output_vars = [] with framework.program_guard(program): @@ -949,17 +943,6 @@ def _run_dygraph(instance, input, program_holder): # hold forward variables tmp_scope_vec = [program_holder.scope] - double_grad_vars = [] - for var_desc in program_holder.double_grad_descs: - var = core.eager.Tensor( - dtype=var_desc.dtype(), - dims=var_desc.shape(), - name=var_desc.name(), - type=var_desc.type(), - persistable=False, - ) - double_grad_vars.append(var) - # 2. run program by op trace_program = ( program_holder.infer_program @@ -1021,7 +1004,6 @@ def _run_dygraph(instance, input, program_holder): _valid_vars(persistable_vars), _valid_vars(output_vars), tmp_scope_vec, - _valid_vars(double_grad_vars), None, *attrs, ) diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py index 2760b448a7027..e87ab0068ff2c 100644 --- a/python/paddle/metric/metrics.py +++ b/python/paddle/metric/metrics.py @@ -17,10 +17,10 @@ import numpy as np import paddle -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from ..base.data_feeder import check_variable_and_dtype -from ..base.framework import _create_tensor +from ..base.framework import _create_tensor, in_pir_mode from ..base.layer_helper import LayerHelper from ..framework import in_dynamic_mode @@ -92,6 +92,7 @@ class Metric(metaclass=abc.ABCMeta): ... correct = pred == label ... return paddle.cast(correct, dtype='float32') ... + With the :code:`compute`, we split some calculations to OPs (which may run on GPU devices, will be faster), and only fetch 1 tensor with shape as [N, 5] instead of 2 tensors with shapes as [N, 10] and [N, 1]. @@ -807,6 +808,10 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ) return _acc + elif in_pir_mode(): + topk_out, topk_indices = paddle.topk(input, k=k) + _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label) + return _acc helper = LayerHelper("accuracy", **locals()) check_variable_and_dtype( diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 1ef27639abd13..54d84d5b74931 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -12,200 +12,153 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: import all neural network related api under this directory, -# including layers, linear, conv, rnn etc. -from .layer.container import LayerList # noqa: F401 -from .layer.container import ParameterList # noqa: F401 -from .layer.container import Sequential # noqa: F401 - -from .clip import ClipGradByGlobalNorm # noqa: F401 -from .clip import ClipGradByNorm # noqa: F401 -from .clip import ClipGradByValue # noqa: F401 -from .decode import BeamSearchDecoder # noqa: F401 -from .decode import dynamic_decode # noqa: F401 -from .layer.activation import CELU # noqa: F401 -from .layer.activation import ELU # noqa: F401 -from .layer.activation import GELU # noqa: F401 -from .layer.activation import Tanh # noqa: F401 -from .layer.activation import Hardshrink # noqa: F401 -from .layer.activation import Hardswish # noqa: F401 -from .layer.activation import Hardtanh # noqa: F401 -from .layer.activation import PReLU # noqa: F401 -from .layer.activation import ReLU # noqa: F401 -from .layer.activation import ReLU6 # noqa: F401 -from .layer.activation import SELU # noqa: F401 -from .layer.activation import Silu # noqa: F401 -from .layer.activation import LeakyReLU # noqa: F401 -from .layer.activation import Sigmoid # noqa: F401 -from .layer.activation import Hardsigmoid # noqa: F401 -from .layer.activation import LogSigmoid # noqa: F401 -from .layer.activation import Softmax # noqa: F401 -from .layer.activation import Softmax2D # noqa: F401 -from .layer.activation import Softplus # noqa: F401 -from .layer.activation import Softshrink # noqa: F401 -from .layer.activation import Softsign # noqa: F401 -from .layer.activation import Swish # noqa: F401 -from .layer.activation import Mish # noqa: F401 -from .layer.activation import Tanhshrink # noqa: F401 -from .layer.activation import ThresholdedReLU # noqa: F401 -from .layer.activation import LogSoftmax # noqa: F401 -from .layer.activation import Maxout # noqa: F401 -from .layer.activation import RReLU # noqa: F401 -from .layer.common import Pad1D # noqa: F401 -from .layer.common import Pad2D # noqa: F401 -from .layer.common import ZeroPad2D # noqa: F401 -from .layer.common import Pad3D # noqa: F401 -from .layer.common import CosineSimilarity # noqa: F401 -from .layer.common import Embedding # noqa: F401 -from .layer.common import Linear # noqa: F401 -from .layer.common import Identity # noqa: F401 -from .layer.common import Flatten # noqa: F401 -from .layer.common import Upsample # noqa: F401 -from .layer.common import UpsamplingNearest2D # noqa: F401 -from .layer.common import UpsamplingBilinear2D # noqa: F401 -from .layer.common import Bilinear # noqa: F401 -from .layer.common import Dropout # noqa: F401 -from .layer.common import Dropout2D # noqa: F401 -from .layer.common import Dropout3D # noqa: F401 -from .layer.common import AlphaDropout # noqa: F401 -from .layer.common import Unfold # noqa: F401 -from .layer.common import Fold # noqa: F401 -from .layer.common import Unflatten # noqa: F401 -from .layer.pooling import AvgPool1D # noqa: F401 -from .layer.pooling import AvgPool2D # noqa: F401 -from .layer.pooling import AvgPool3D # noqa: F401 -from .layer.pooling import MaxPool1D # noqa: F401 -from .layer.pooling import MaxPool2D # noqa: F401 -from .layer.pooling import MaxPool3D # noqa: F401 -from .layer.pooling import MaxUnPool1D # noqa: F401 -from .layer.pooling import MaxUnPool2D # noqa: F401 -from .layer.pooling import MaxUnPool3D # noqa: F401 -from .layer.pooling import AdaptiveAvgPool1D # noqa: F401 -from .layer.pooling import AdaptiveAvgPool2D # noqa: F401 -from .layer.pooling import AdaptiveAvgPool3D # noqa: F401 -from .layer.pooling import AdaptiveMaxPool1D # noqa: F401 -from .layer.pooling import AdaptiveMaxPool2D # noqa: F401 -from .layer.pooling import AdaptiveMaxPool3D # noqa: F401 - -from .layer.conv import Conv1D # noqa: F401 -from .layer.conv import Conv2D # noqa: F401 -from .layer.conv import Conv3D # noqa: F401 -from .layer.conv import Conv1DTranspose # noqa: F401 -from .layer.conv import Conv2DTranspose # noqa: F401 -from .layer.conv import Conv3DTranspose # noqa: F401 - -from .layer.loss import BCEWithLogitsLoss # noqa: F401 -from .layer.loss import CrossEntropyLoss # noqa: F401 -from .layer.loss import HSigmoidLoss # noqa: F401 -from .layer.loss import MSELoss # noqa: F401 -from .layer.loss import L1Loss # noqa: F401 -from .layer.loss import NLLLoss # noqa: F401 -from .layer.loss import PoissonNLLLoss # noqa: F401 -from .layer.loss import BCELoss # noqa: F401 -from .layer.loss import KLDivLoss # noqa: F401 -from .layer.loss import MarginRankingLoss # noqa: F401 -from .layer.loss import MultiLabelSoftMarginLoss -from .layer.loss import CTCLoss # noqa: F401 -from .layer.loss import RNNTLoss # noqa: F401 -from .layer.loss import SmoothL1Loss # noqa: F401 -from .layer.loss import HingeEmbeddingLoss # noqa: F401 -from .layer.loss import CosineEmbeddingLoss # noqa: F401 -from .layer.loss import MultiMarginLoss -from .layer.loss import TripletMarginWithDistanceLoss -from .layer.loss import TripletMarginLoss -from .layer.loss import SoftMarginLoss -from .layer.loss import GaussianNLLLoss - -from .layer.norm import BatchNorm # noqa: F401 -from .layer.norm import SyncBatchNorm # noqa: F401 -from .layer.norm import GroupNorm # noqa: F401 -from .layer.norm import LayerNorm # noqa: F401 -from .layer.norm import SpectralNorm # noqa: F401 -from .layer.norm import InstanceNorm1D # noqa: F401 -from .layer.norm import InstanceNorm2D # noqa: F401 -from .layer.norm import InstanceNorm3D # noqa: F401 -from .layer.norm import BatchNorm1D # noqa: F401 -from .layer.norm import BatchNorm2D # noqa: F401 -from .layer.norm import BatchNorm3D # noqa: F401 -from .layer.norm import LocalResponseNorm # noqa: F401 - -from .layer.rnn import RNNCellBase # noqa: F401 -from .layer.rnn import SimpleRNNCell # noqa: F401 -from .layer.rnn import LSTMCell # noqa: F401 -from .layer.rnn import GRUCell # noqa: F401 -from .layer.rnn import RNN # noqa: F401 -from .layer.rnn import BiRNN # noqa: F401 -from .layer.rnn import SimpleRNN # noqa: F401 -from .layer.rnn import LSTM # noqa: F401 -from .layer.rnn import GRU # noqa: F401 - -from .layer.transformer import MultiHeadAttention # noqa: F401 -from .layer.transformer import TransformerEncoderLayer # noqa: F401 -from .layer.transformer import TransformerEncoder # noqa: F401 -from .layer.transformer import TransformerDecoderLayer # noqa: F401 -from .layer.transformer import TransformerDecoder # noqa: F401 -from .layer.transformer import Transformer # noqa: F401 -from .layer.distance import PairwiseDistance # noqa: F401 - -from .layer.vision import PixelShuffle # noqa: F401 -from .layer.vision import PixelUnshuffle # noqa: F401 -from .layer.vision import ChannelShuffle # noqa: F401 -from .layer.container import LayerDict # noqa: F401 - -from .layer.layers import Layer # noqa: F401 - -from .utils.spectral_norm_hook import spectral_norm +from . import functional, initializer, quant, utils # noqa: F401 +from .clip import ClipGradByGlobalNorm, ClipGradByNorm, ClipGradByValue +from .decode import BeamSearchDecoder, dynamic_decode # TODO: remove loss, keep it for too many used in unittests from .layer import loss # noqa: F401 - -from . import utils # noqa: F401 -from . import functional # noqa: F401 -from . import initializer # noqa: F401 -from . import quant # noqa: F401 - -# TODO: remove 'diag_embed', 'remove_weight_norm', 'weight_norm' months later. -from paddle.utils import deprecated - - -@deprecated( - since="2.0.0", - update_to="paddle.nn.functional.diag_embed", - level=1, - reason="diag_embed in paddle.nn will be removed in future", +from .layer.activation import ( + CELU, + ELU, + GELU, + SELU, + Hardshrink, + Hardsigmoid, + Hardswish, + Hardtanh, + LeakyReLU, + LogSigmoid, + LogSoftmax, + Maxout, + Mish, + PReLU, + ReLU, + ReLU6, + RReLU, + Sigmoid, + Silu, + Softmax, + Softmax2D, + Softplus, + Softshrink, + Softsign, + Swish, + Tanh, + Tanhshrink, + ThresholdedReLU, ) -def diag_embed(*args): - ''' - alias name of paddle.nn.functional.diag_embed - ''' - return functional.diag_embed(*args) - - -@deprecated( - since="2.0.0", - update_to="paddle.nn.utils.remove_weight_norm", - level=1, - reason="remove_weight_norm in paddle.nn will be removed in future", +from .layer.common import ( + AlphaDropout, + Bilinear, + CosineSimilarity, + Dropout, + Dropout2D, + Dropout3D, + Embedding, + Flatten, + Fold, + Identity, + Linear, + Pad1D, + Pad2D, + Pad3D, + Unflatten, + Unfold, + Upsample, + UpsamplingBilinear2D, + UpsamplingNearest2D, + ZeroPad2D, ) -def remove_weight_norm(*args): - ''' - alias name of paddle.nn.utils.remove_weight_norm - ''' - return utils.remove_weight_norm(*args) - -@deprecated( - since="2.0.0", - update_to="paddle.nn.utils.weight_norm", - level=1, - reason="weight_norm in paddle.nn will be removed in future", +# TODO: import all neural network related api under this directory, +# including layers, linear, conv, rnn etc. +from .layer.container import LayerDict, LayerList, ParameterList, Sequential +from .layer.conv import ( + Conv1D, + Conv1DTranspose, + Conv2D, + Conv2DTranspose, + Conv3D, + Conv3DTranspose, ) -def weight_norm(*args): - ''' - alias name of paddle.nn.utils.weight_norm - ''' - return utils.weight_norm(*args) - +from .layer.distance import PairwiseDistance +from .layer.layers import Layer +from .layer.loss import ( + BCELoss, + BCEWithLogitsLoss, + CosineEmbeddingLoss, + CrossEntropyLoss, + CTCLoss, + GaussianNLLLoss, + HingeEmbeddingLoss, + HSigmoidLoss, + KLDivLoss, + L1Loss, + MarginRankingLoss, + MSELoss, + MultiLabelSoftMarginLoss, + MultiMarginLoss, + NLLLoss, + PoissonNLLLoss, + RNNTLoss, + SmoothL1Loss, + SoftMarginLoss, + TripletMarginLoss, + TripletMarginWithDistanceLoss, +) +from .layer.norm import ( + BatchNorm, + BatchNorm1D, + BatchNorm2D, + BatchNorm3D, + GroupNorm, + InstanceNorm1D, + InstanceNorm2D, + InstanceNorm3D, + LayerNorm, + LocalResponseNorm, + SpectralNorm, + SyncBatchNorm, +) +from .layer.pooling import ( + AdaptiveAvgPool1D, + AdaptiveAvgPool2D, + AdaptiveAvgPool3D, + AdaptiveMaxPool1D, + AdaptiveMaxPool2D, + AdaptiveMaxPool3D, + AvgPool1D, + AvgPool2D, + AvgPool3D, + MaxPool1D, + MaxPool2D, + MaxPool3D, + MaxUnPool1D, + MaxUnPool2D, + MaxUnPool3D, +) +from .layer.rnn import ( + GRU, + LSTM, + RNN, + BiRNN, + GRUCell, + LSTMCell, + RNNCellBase, + SimpleRNN, + SimpleRNNCell, +) +from .layer.transformer import ( + MultiHeadAttention, + Transformer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from .layer.vision import ChannelShuffle, PixelShuffle, PixelUnshuffle +from .utils.spectral_norm_hook import spectral_norm # noqa: F401 __all__ = [ 'BatchNorm', diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py index 13742ae6d9be8..994800ccfe500 100644 --- a/python/paddle/nn/clip.py +++ b/python/paddle/nn/clip.py @@ -643,6 +643,12 @@ def __init__( self.group_name = group_name assert isinstance(auto_skip_clip, bool) self.auto_skip_clip = auto_skip_clip + # TODO(zhiqiu): Now, in dygraph mode async_add_n is always used. + # However, in static mode, it is only used in auto_parallel mode + # by setting self._async_add_n to True. The reason is that there + # are so many hard code depends on `add_n` in the legacy static + # manual hybrid-parallel. + self._async_add_n = None def __str__(self): return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm) @@ -749,6 +755,13 @@ def _static_clip(self, params_grads): sum_square_list_fp16 = [] sum_square_list_bf16 = [] sum_square_list_fp32 = [] + + def _add_n(var_list): + if self._async_add_n: + return paddle.stack(var_list).sum() + else: + return paddle.add_n(var_list) + with framework.name_scope('gradient_clip'): for p, g in params_grads: if g is None: @@ -794,7 +807,7 @@ def _static_clip(self, params_grads): global_norm_var = [] if len(sum_square_list_fp16) > 0: - global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16) + global_norm_var_fp16 = _add_n(sum_square_list_fp16) if ( sum_square_list_fp32 or sum_square_list @@ -806,7 +819,7 @@ def _static_clip(self, params_grads): else: global_norm_var.append(global_norm_var_fp16) if len(sum_square_list_bf16) > 0: - global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16) + global_norm_var_bf16 = _add_n(sum_square_list_bf16) if ( sum_square_list_fp32 or sum_square_list @@ -818,7 +831,7 @@ def _static_clip(self, params_grads): else: global_norm_var.append(global_norm_var_bf16) if len(sum_square_list_fp32) > 0: - global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32) + global_norm_var_fp32 = _add_n(sum_square_list_fp32) if sum_dtype == 'float32': global_norm_var.append(global_norm_var_fp32) else: @@ -827,11 +840,11 @@ def _static_clip(self, params_grads): ) if len(sum_square_list) > 0: # fp64 - global_norm_var_other_dtype = paddle.add_n(sum_square_list) + global_norm_var_other_dtype = _add_n(sum_square_list) global_norm_var.append(global_norm_var_other_dtype) global_norm_var = ( - paddle.add_n(global_norm_var) + _add_n(global_norm_var) if len(global_norm_var) > 1 else global_norm_var[0] ) @@ -919,9 +932,12 @@ def _process_context(self, context, param, grad): self.context = context def _create_operators(self, param, grad): + def async_add_n(var_list): + return paddle.stack(var_list).sum() + group_scale_name = self.group_name + "_scale" if group_scale_name not in self.context: - group_norm_var = paddle.add_n(self.context[self.group_name]) + group_norm_var = async_add_n(self.context[self.group_name]) group_norm_var = paddle.sqrt(x=group_norm_var) clip_var = self.context[self.group_name + "_clip"] group_scale_var = paddle.divide( diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 608587becd952..efe3c2adc910e 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -15,130 +15,138 @@ # TODO: import all neural network related api under this directory, # including layers, linear, conv, rnn etc. -from .activation import celu # noqa: F401 -from .activation import elu # noqa: F401 -from .activation import elu_ # noqa: F401 -from .activation import gelu # noqa: F401 -from .activation import hardshrink # noqa: F401 -from .activation import hardtanh # noqa: F401 -from .activation import hardtanh_ # noqa: F401 -from .activation import hardsigmoid # noqa: F401 -from .activation import hardswish # noqa: F401 -from .activation import leaky_relu # noqa: F401 -from .activation import leaky_relu_ # noqa: F401 -from .activation import log_sigmoid # noqa: F401 -from .activation import maxout # noqa: F401 -from .activation import prelu # noqa: F401 -from .activation import relu # noqa: F401 -from .activation import relu_ # noqa: F401 -from .activation import relu6 # noqa: F401 -from .activation import selu # noqa: F401 -from .activation import sigmoid # noqa: F401 -from .activation import silu # noqa: F401 -from .activation import softmax # noqa: F401 -from .activation import softmax_ # noqa: F401 -from .activation import softplus # noqa: F401 -from .activation import softshrink # noqa: F401 -from .activation import softsign # noqa: F401 -from .activation import swish # noqa: F401 -from .activation import mish # noqa: F401 -from .activation import tanh # noqa: F401 -from .activation import tanh_ # noqa: F401 -from .activation import tanhshrink # noqa: F401 -from .activation import thresholded_relu # noqa: F401 -from .activation import thresholded_relu_ # noqa: F401 -from .activation import log_softmax # noqa: F401 -from .activation import glu # noqa: F401 -from .activation import gumbel_softmax # noqa: F401 -from .activation import rrelu # noqa: F401 -from .common import dropout # noqa: F401 -from .common import dropout2d # noqa: F401 -from .common import dropout3d # noqa: F401 -from .common import alpha_dropout # noqa: F401 -from .common import label_smooth # noqa: F401 -from .common import pad # noqa: F401 -from .common import zeropad2d # noqa: F401 -from .common import cosine_similarity # noqa: F401 -from .common import unfold # noqa: F401 -from .common import fold -from .common import interpolate # noqa: F401 -from .common import upsample # noqa: F401 -from .common import bilinear # noqa: F401 -from .common import class_center_sample # noqa: F401 -from .conv import conv1d # noqa: F401 -from .conv import conv1d_transpose # noqa: F401 -from .common import linear # noqa: F401 -from .conv import conv2d # noqa: F401 -from .conv import conv2d_transpose # noqa: F401 -from .conv import conv3d # noqa: F401 -from .conv import conv3d_transpose # noqa: F401 -from .distance import pairwise_distance # noqa: F401 -from .extension import diag_embed # noqa: F401 -from .extension import sequence_mask -from .loss import binary_cross_entropy # noqa: F401 -from .loss import binary_cross_entropy_with_logits # noqa: F401 -from .loss import cross_entropy # noqa: F401 -from .loss import dice_loss # noqa: F401 -from .loss import hsigmoid_loss # noqa: F401 -from .loss import kl_div # noqa: F401 -from .loss import l1_loss # noqa: F401 -from .loss import log_loss # noqa: F401 -from .loss import margin_ranking_loss # noqa: F401 -from .loss import mse_loss # noqa: F401 -from .loss import nll_loss # noqa: F401 -from .loss import poisson_nll_loss # noqa: F401 -from .loss import npair_loss # noqa: F401 -from .loss import sigmoid_focal_loss # noqa: F401 -from .loss import smooth_l1_loss # noqa: F401 -from .loss import softmax_with_cross_entropy # noqa: F401 -from .loss import margin_cross_entropy # noqa: F401 -from .loss import square_error_cost # noqa: F401 -from .loss import ctc_loss # noqa: F401 -from .loss import rnnt_loss # noqa: F401 -from .loss import hinge_embedding_loss # noqa: F401 -from .loss import cosine_embedding_loss # noqa: F401 -from .loss import multi_margin_loss -from .loss import multi_label_soft_margin_loss -from .loss import triplet_margin_with_distance_loss -from .loss import triplet_margin_loss -from .loss import soft_margin_loss -from .loss import gaussian_nll_loss - -from .norm import batch_norm # noqa: F401 -from .norm import instance_norm # noqa: F401 -from .norm import layer_norm # noqa: F401 -from .norm import local_response_norm # noqa: F401 -from .norm import normalize # noqa: F401 -from .pooling import avg_pool1d # noqa: F401 -from .pooling import avg_pool2d # noqa: F401 -from .pooling import avg_pool3d # noqa: F401 -from .pooling import max_pool1d # noqa: F401 -from .pooling import max_pool2d # noqa: F401 -from .pooling import max_pool3d # noqa: F401 - -from .pooling import adaptive_max_pool1d # noqa: F401 -from .pooling import adaptive_max_pool2d # noqa: F401 -from .pooling import adaptive_max_pool3d # noqa: F401 -from .pooling import adaptive_avg_pool1d # noqa: F401 -from .pooling import adaptive_avg_pool2d # noqa: F401 -from .pooling import adaptive_avg_pool3d # noqa: F401 -from .pooling import max_unpool1d # noqa: F401 -from .pooling import max_unpool2d # noqa: F401 -from .pooling import max_unpool3d # noqa: F401 - -from .vision import affine_grid # noqa: F401 -from .vision import grid_sample # noqa: F401 -from .vision import pixel_shuffle # noqa: F401 -from .vision import pixel_unshuffle # noqa: F401 -from .vision import channel_shuffle # noqa: F401 -from .input import one_hot # noqa: F401 -from .input import embedding # noqa: F401 -from .extension import gather_tree # noqa: F401 -from .extension import temporal_shift # noqa: F401 - +from .activation import ( + celu, + elu, + elu_, + gelu, + glu, + gumbel_softmax, + hardshrink, + hardsigmoid, + hardswish, + hardtanh, + hardtanh_, + leaky_relu, + leaky_relu_, + log_sigmoid, + log_softmax, + maxout, + mish, + prelu, + relu, + relu6, + relu_, + rrelu, + selu, + sigmoid, + silu, + softmax, + softmax_, + softplus, + softshrink, + softsign, + swish, + tanh, + tanh_, + tanhshrink, + thresholded_relu, + thresholded_relu_, +) +from .common import ( + alpha_dropout, + bilinear, + class_center_sample, + cosine_similarity, + dropout, + dropout2d, + dropout3d, + fold, + interpolate, + label_smooth, + linear, + pad, + unfold, + upsample, + zeropad2d, +) +from .conv import ( + conv1d, + conv1d_transpose, + conv2d, + conv2d_transpose, + conv3d, + conv3d_transpose, +) +from .distance import pairwise_distance +from .extension import diag_embed, gather_tree, sequence_mask, temporal_shift +from .flash_attention import ( # noqa: F401 + scaled_dot_product_attention, + sdp_kernel, +) +from .input import embedding, one_hot +from .loss import ( + binary_cross_entropy, + binary_cross_entropy_with_logits, + cosine_embedding_loss, + cross_entropy, + ctc_loss, + dice_loss, + gaussian_nll_loss, + hinge_embedding_loss, + hsigmoid_loss, + kl_div, + l1_loss, + log_loss, + margin_cross_entropy, + margin_ranking_loss, + mse_loss, + multi_label_soft_margin_loss, + multi_margin_loss, + nll_loss, + npair_loss, + poisson_nll_loss, + rnnt_loss, + sigmoid_focal_loss, + smooth_l1_loss, + soft_margin_loss, + softmax_with_cross_entropy, + square_error_cost, + triplet_margin_loss, + triplet_margin_with_distance_loss, +) +from .norm import ( + batch_norm, + instance_norm, + layer_norm, + local_response_norm, + normalize, +) +from .pooling import ( + adaptive_avg_pool1d, + adaptive_avg_pool2d, + adaptive_avg_pool3d, + adaptive_max_pool1d, + adaptive_max_pool2d, + adaptive_max_pool3d, + avg_pool1d, + avg_pool2d, + avg_pool3d, + max_pool1d, + max_pool2d, + max_pool3d, + max_unpool1d, + max_unpool2d, + max_unpool3d, +) from .sparse_attention import sparse_attention -from .flash_attention import scaled_dot_product_attention -from .flash_attention import sdp_kernel +from .vision import ( + affine_grid, + channel_shuffle, + grid_sample, + pixel_shuffle, + pixel_unshuffle, +) __all__ = [ 'celu', @@ -183,7 +191,6 @@ 'log_softmax', 'glu', 'gumbel_softmax', - 'diag_embed', 'sequence_mask', 'dropout', 'dropout2d', diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index c74748793a4e9..7360c14a5214c 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -21,9 +21,8 @@ from ...base.framework import convert_np_dtype_to_dtype_ from ...base.layer_helper import LayerHelper from ...tensor.manipulation import chunk -from ...tensor.math import tanh # noqa: F401 -from ...tensor.math import tanh_ # noqa: F401 -from ...tensor.ops import sigmoid # noqa: F401 +from ...tensor.math import tanh, tanh_ # noqa: F401 +from ...tensor.ops import sigmoid __all__ = [] @@ -61,7 +60,7 @@ def celu(x, alpha=1.0, name=None): """ if alpha == 0: raise ZeroDivisionError("alpha cannot be 0 for celu") - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.celu(x, alpha) else: check_variable_and_dtype( @@ -243,7 +242,7 @@ def hardshrink(x, threshold=0.5, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.hardshrink(x, threshold) else: check_variable_and_dtype( @@ -297,7 +296,7 @@ def hardtanh(x, min=-1.0, max=1.0, name=None): [-1. , 0.30000001, 1. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.hardtanh(x, min, max) else: check_variable_and_dtype( @@ -364,7 +363,7 @@ def hardsigmoid(x, slope=0.1666667, offset=0.5, name=None): [0. , 1. , 0.66666669]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.hardsigmoid(x, slope, offset) else: check_variable_and_dtype( @@ -418,7 +417,7 @@ def hardswish(x, name=None): Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [-0. , 5. , 0.66666669]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.hardswish(x) else: check_variable_and_dtype( @@ -484,7 +483,7 @@ def leaky_relu(x, negative_slope=0.01, name=None): [-0.02000000, 0. , 1. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.leaky_relu(x, negative_slope) else: check_variable_and_dtype( @@ -594,7 +593,7 @@ def prelu(x, weight, data_format="NCHW", name=None): ), "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." mode = 'channel' - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.prelu(x, weight, data_format, mode) else: check_variable_and_dtype( @@ -696,9 +695,7 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None): if lower < 0 or lower > 1: raise ValueError( - "The lower value must be no less than zero or greater than one. Received: {}.".format( - lower - ) + f"The lower value must be no less than zero or greater than one. Received: {lower}." ) if upper < lower: @@ -813,7 +810,7 @@ def log_sigmoid(x, name=None): [-0.31326166, -0.12692805, -0.04858733, -0.01814996]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logsigmoid(x) else: check_variable_and_dtype( @@ -942,7 +939,7 @@ def relu6(x, name=None): [0. , 0.30000001, 6. ]) """ threshold = 6.0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.relu6(x) check_variable_and_dtype( @@ -1298,7 +1295,7 @@ def softplus(x, beta=1, threshold=20, name=None): [0.51301527, 0.59813893, 0.74439669, 0.85435522]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.softplus(x, beta, threshold) else: check_variable_and_dtype( @@ -1365,7 +1362,7 @@ def softshrink(x, threshold=0.5, name=None): f"The threshold must be no less than zero. Received: {threshold}." ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.softshrink(x, threshold) else: check_variable_and_dtype( @@ -1391,7 +1388,7 @@ def softsign(x, name=None): softsign(x) = \frac{x}{1 + |x|} Parameters: - x (Tensor): The input Tensor with data type float32, float64. + x (Tensor): The input Tensor with data type float32, float64, complex64 or complex128. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: @@ -1409,7 +1406,7 @@ def softsign(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.28571430, -0.16666666, 0.09090909, 0.23076925]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.softsign(x) check_variable_and_dtype( @@ -1448,7 +1445,7 @@ def swish(x, name=None): Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.23840584, 0. , 0.73105860]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.swish(x) else: check_variable_and_dtype( @@ -1683,7 +1680,7 @@ def log_softmax(x, axis=-1, dtype=None, name=None): if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if dtype is not None: x = _C_ops.cast(x, dtype) return _C_ops.log_softmax(x, axis) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 74468e719eaca..ef5955d8ac08e 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -410,7 +410,7 @@ def interpolate( 'The x and size should satisfy rank(x) - 2 == len(size).' ) - if isinstance(size, Variable): + if isinstance(size, (Variable, paddle.pir.OpResult)): size = size.cast("int32") # static mode only support int32 if size.ndim != 1: raise ValueError( @@ -432,7 +432,7 @@ def interpolate( ) if resample == 'AREA': - if isinstance(size, (list, tuple, Variable)): + if isinstance(size, (list, tuple, Variable, paddle.pir.OpResult)): if len(size) == 0: raise ValueError("output size can not be empty") if size is None: @@ -491,7 +491,10 @@ def _is_list_or_turple_(data): if out_shape is not None and scale is not None: raise ValueError("Only one of size or scale_factor should be defined.") if out_shape is not None: - if isinstance(out_shape, Variable) and not in_dynamic_mode(): + if ( + isinstance(out_shape, (Variable, paddle.pir.OpResult)) + and not in_dynamic_mode() + ): out_shape.stop_gradient = True inputs['OutSize'] = out_shape else: @@ -509,7 +512,7 @@ def _is_list_or_turple_(data): # Validate the shape contain_var = False for dim_idx, dim_size in enumerate(out_shape): - if isinstance(dim_size, Variable): + if isinstance(dim_size, (Variable, paddle.pir.OpResult)): contain_var = True continue assert ( @@ -520,18 +523,25 @@ def _is_list_or_turple_(data): new_size_tensor = [] size_list = [] for dim in out_shape: - if isinstance(dim, Variable): + if isinstance(dim, (Variable, paddle.pir.OpResult)): dim.stop_gradient = True new_size_tensor.append(dim) size_list.append(-1) else: assert isinstance(dim, int) - temp_out = helper.create_variable_for_type_inference( - 'int32' - ) - paddle.tensor.fill_constant( - [1], 'int32', dim, force_cpu=True, out=temp_out - ) + if in_pir_mode(): + temp_out = paddle.tensor.fill_constant( + [1], 'int32', dim, force_cpu=True + ) + else: + temp_out = ( + helper.create_variable_for_type_inference( + 'int32' + ) + ) + paddle.tensor.fill_constant( + [1], 'int32', dim, force_cpu=True, out=temp_out + ) new_size_tensor.append(temp_out) size_list.append(dim) inputs['SizeTensor'] = new_size_tensor @@ -579,7 +589,7 @@ def _is_list_or_turple_(data): scale = float(scale) else: scale = list(scale.numpy()) - if isinstance(scale, Variable): + if isinstance(scale, (Variable, paddle.pir.OpResult)): scale.stop_gradient = True inputs["Scale"] = scale elif isinstance(scale, (float, int, numpy.ndarray)): @@ -604,7 +614,7 @@ def _is_list_or_turple_(data): "Attr(scale)'s type should be float, int, list, tuple, or Tensor." ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): attr_list = [] for k, v in attrs.items(): attr_list.append(k) @@ -1155,9 +1165,7 @@ def get_attrs(prog, dropout_prob, is_test, seed): dropout_prob, Variable ) and not dropout_prob.shape != [1]: raise TypeError( - "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}".format( - p.shape - ) + f"Required p.shape == [1] if type(p) is Variable, but received p.shape = {p.shape}" ) attrs = { 'dropout_prob': dropout_prob, @@ -1196,7 +1204,7 @@ def get_attrs(prog, dropout_prob, is_test, seed): # get mask shape input_shape = x.shape - if not in_dynamic_or_pir_mode(): + if not in_dynamic_mode(): input_shape_tensor = paddle.shape(x) drop_axes = [axis] if isinstance(axis, int) else list(axis) if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1: @@ -1212,7 +1220,7 @@ def get_attrs(prog, dropout_prob, is_test, seed): ) ) mask_shape = [1] * len(input_shape) - if not in_dynamic_or_pir_mode(): + if not in_dynamic_mode(): for i in drop_axes: mask_shape[i] = input_shape_tensor[i] else: @@ -2050,7 +2058,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): if epsilon > 1.0 or epsilon < 0.0: raise ValueError("The value of epsilon must be between 0 and 1.") - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.label_smooth(label, prior_dist, float(epsilon)) check_variable_and_dtype( @@ -2228,7 +2236,7 @@ class centers and the shape of sampled_class_center will be [num_positive_class_ if (seed is None or seed == 0) and default_main_program().random_seed != 0: seed = default_main_program().random_seed - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.class_center_sample( label, num_classes, diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index 6caf0370366f4..df83721599354 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode -from paddle.base.framework import _global_flags +from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode, pir +from paddle.base.framework import _global_flags, in_dynamic_or_pir_mode from paddle.device import ( get_all_custom_device_type, is_compiled_with_cuda, @@ -102,9 +102,7 @@ def _update_padding_nd(padding, channel_last, num_dims): padding = convert_to_list(padding, num_dims, 'padding') if not all(p >= 0 for p in padding): raise ValueError( - "Invalid padding, all value should be larger than or equal to 0, but received: {}".format( - padding - ) + f"Invalid padding, all value should be larger than or equal to 0, but received: {padding}" ) return padding, padding_algorithm @@ -126,7 +124,7 @@ def _conv_nd( name=None, ): # Due to the poor performance of NHWC, we transpose the input to NCHW. - if in_dynamic_mode() and op_type == "conv2d": + if in_dynamic_or_pir_mode() and op_type == "conv2d": pre_bias = _C_ops.conv2d( x, weight, @@ -155,7 +153,7 @@ def _conv_nd( else: return pre_bias - if in_dynamic_mode() and op_type == "depthwise_conv2d": + if in_dynamic_or_pir_mode() and op_type == "depthwise_conv2d": pre_bias = _C_ops.depthwise_conv2d( x, weight, @@ -174,7 +172,7 @@ def _conv_nd( else: return pre_bias - if in_dynamic_mode() and op_type == "conv3d": + if in_dynamic_or_pir_mode() and op_type == "conv3d": pre_bias = _C_ops.conv3d( x, weight, @@ -437,9 +435,7 @@ def conv1d( padding = [0] + padding else: raise ValueError( - "The size of padding's dimension should be 1 or 2. But got padding={}".format( - padding - ) + f"The size of padding's dimension should be 1 or 2. But got padding={padding}" ) stride = [1] + convert_to_list(stride, 1, 'stride') dilation = [1] + convert_to_list(dilation, 1, 'dilation') @@ -463,7 +459,7 @@ def conv1d( squeeze_aixs = -3 if channel_last else -2 x = unsqueeze(x, axis=[squeeze_aixs]) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if l_type == 'conv2d': out = _C_ops.conv2d( x, @@ -918,9 +914,7 @@ def conv1d_transpose( ) if groups <= 0: raise ValueError( - "The groups of conv1d_transpose should be greater than 0. Received groups: {}".format( - groups - ) + f"The groups of conv1d_transpose should be greater than 0. Received groups: {groups}" ) if num_channels % groups != 0: raise ValueError( @@ -1202,9 +1196,7 @@ def conv2d_transpose( ) if groups <= 0: raise ValueError( - "The groups of conv2d_transpose should be greater than 0. Received groups: {}".format( - groups - ) + f"The groups of conv2d_transpose should be greater than 0. Received groups: {groups}" ) if num_channels % groups != 0: raise ValueError( @@ -1241,7 +1233,7 @@ def conv2d_transpose( output_size = convert_to_list(output_size, 2, 'output_size') elif isinstance(output_size, int): output_size = convert_to_list(output_size, 2, 'output_size') - elif isinstance(output_size, Variable): + elif isinstance(output_size, (Variable, pir.OpResult)): check_dtype( output_size.dtype, 'output_size', @@ -1273,7 +1265,7 @@ def conv2d_transpose( op_type = 'depthwise_conv2d_transpose' use_cudnn = False - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): op = ( _C_ops.conv2d_transpose if op_type == 'conv2d_transpose' @@ -1699,9 +1691,7 @@ def conv3d_transpose( ) if groups <= 0: raise ValueError( - "The groups of conv3d_transpose should be greater than 0. Received groups: {}".format( - groups - ) + f"The groups of conv3d_transpose should be greater than 0. Received groups: {groups}" ) if num_channels % groups != 0: raise ValueError( diff --git a/python/paddle/nn/functional/distance.py b/python/paddle/nn/functional/distance.py index dc69092daed08..113df166a027a 100644 --- a/python/paddle/nn/functional/distance.py +++ b/python/paddle/nn/functional/distance.py @@ -14,7 +14,7 @@ import paddle from paddle import _C_ops -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode from ...base.data_feeder import check_type, check_variable_and_dtype from ...base.layer_helper import LayerHelper @@ -67,13 +67,11 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None): Tensor(shape=[2], dtype=float64, place=Place(cpu), stop_gradient=True, [4.99999860, 4.99999860]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): sub = _C_ops.subtract(x, y) # p_norm op has not used epsilon, so change it to the following. if epsilon != 0.0: - epsilon = paddle.base.dygraph.base.to_variable( - [epsilon], dtype=sub.dtype - ) + epsilon = paddle.to_tensor([epsilon], dtype=sub.dtype) sub = _C_ops.add(sub, epsilon) return _C_ops.p_norm(sub, p, -1, 0.0, keepdim, False) diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 757c9059efdd6..f52a133433120 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -14,141 +14,26 @@ # TODO: define the extention functions -import numpy as np -from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode +from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode, tensor +from paddle.utils import deprecated -from ...base.data_feeder import ( - check_dtype, - check_type, - check_variable_and_dtype, -) +from ...base.data_feeder import check_type, check_variable_and_dtype from ...base.layer_helper import LayerHelper from ...common_ops_import import Variable from ...framework import convert_np_dtype_to_dtype_, core -from ...tensor.creation import assign __all__ = [] +@deprecated( + since="2.5.2", + update_to="paddle.diag_embed", + level=1, + reason="diag_embed in paddle.nn.functional will be removed in future", +) def diag_embed(input, offset=0, dim1=-2, dim2=-1): - """ - Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) - are filled by ``input``. By default, a 2D plane formed by the last two dimensions - of the returned tensor will be selected. - - The argument ``offset`` determines which diagonal is generated: - - - If offset = 0, it is the main diagonal. - - If offset > 0, it is above the main diagonal. - - If offset < 0, it is below the main diagonal. - - Args: - input(Tensor|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64. - offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal). - dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2. - dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1. - - Returns: - Tensor, the output data type is the same as input data type. - - Examples: - .. code-block:: python - - >>> import paddle - >>> import paddle.nn.functional as F - - >>> diag_embed_input = paddle.arange(6) - - >>> diag_embed_output1 = F.diag_embed(diag_embed_input) - >>> print(diag_embed_output1) - Tensor(shape=[6, 6], dtype=int64, place=Place(cpu), stop_gradient=True, - [[0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [0, 0, 2, 0, 0, 0], - [0, 0, 0, 3, 0, 0], - [0, 0, 0, 0, 4, 0], - [0, 0, 0, 0, 0, 5]]) - - >>> diag_embed_output2 = F.diag_embed(diag_embed_input, offset=-1, dim1=0,dim2=1 ) - >>> print(diag_embed_output2) - Tensor(shape=[7, 7], dtype=int64, place=Place(cpu), stop_gradient=True, - [[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0], - [0, 0, 2, 0, 0, 0, 0], - [0, 0, 0, 3, 0, 0, 0], - [0, 0, 0, 0, 4, 0, 0], - [0, 0, 0, 0, 0, 5, 0]]) - - >>> diag_embed_input_2dim = paddle.reshape(diag_embed_input,[2,3]) - >>> print(diag_embed_input_2dim) - Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, - [[0, 1, 2], - [3, 4, 5]]) - >>> diag_embed_output3 = F.diag_embed(diag_embed_input_2dim,offset= 0, dim1=0, dim2=2 ) - >>> print(diag_embed_output3) - Tensor(shape=[3, 2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, - [[[0, 0, 0], - [3, 0, 0]], - [[0, 1, 0], - [0, 4, 0]], - [[0, 0, 2], - [0, 0, 5]]]) - """ - if not isinstance(input, Variable): - input = assign(input) - - if in_dynamic_mode(): - return _C_ops.diag_embed(input, offset, dim1, dim2) - - inputs = {'Input': [input]} - attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} - - def __check_input(input, offset, dim1, dim2): - check_dtype( - input.dtype, - 'Input', - ['int32', 'int64', 'float16', 'float32', 'float64'], - 'diag_embed', - ) - - input_shape = list(input.shape) - assert len(input_shape) >= 1, ( - "Input must be at least 1-dimensional, " - "But received Input's dimensional: %s.\n" % len(input_shape) - ) - - assert np.abs(dim1) <= len(input_shape), ( - "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" - % (-(len(input_shape) + 1), len(input_shape), dim1) - ) - - assert np.abs(dim2) <= len(input_shape), ( - "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" - % (-(len(input_shape) + 1), len(input_shape), dim2) - ) - - dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 + 1 - dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 + 1 - assert dim1_ != dim2_, ( - "dim1 and dim2 cannot be the same dimension." - "But received dim1 = %d, dim2 = %d\n" % (dim1, dim2) - ) - - __check_input(input, offset, dim1, dim2) - helper = LayerHelper("diag_embed", **locals()) - - out = helper.create_variable_for_type_inference(dtype=input.dtype) - - helper.append_op( - type='diag_embed', - inputs={'Input': [input]}, - attrs={'offset': offset, 'dim1': dim1, 'dim2': dim2}, - outputs={'Out': [out]}, - ) - out.stop_gradient = True - return out + return tensor.diag_embed(input, offset, dim1, dim2) def sequence_mask(x, maxlen=None, dtype='int64', name=None): diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 11b85df5d1377..98da4e717feb3 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -45,6 +45,15 @@ def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True): g_enable_mem_efficient = original_enable_mem_efficient +# special for XPU device +def get_triangle_upper_mask(x): + mask = paddle.full_like(x, -1e4) + mask.stop_gradient = True + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + def _math_attention( query, key, @@ -65,11 +74,19 @@ def _math_attention( product = paddle.matmul( x=query * (head_dim**-0.5), y=key, transpose_y=True ) - weights = ( - paddle.incubate.softmax_mask_fuse_upper_triangle(product) - if causal - else F.softmax(product) - ) + + if not causal: + weights = F.softmax(product) + else: + # special for XPU device + place = paddle.get_device() + if "xpu" in place: + # softmax_mask_fuse_upper_triangle is not supported on XPU, use plain implementation + mask = get_triangle_upper_mask(product) + product = product + mask + weights = F.softmax(product) + else: + weights = paddle.incubate.softmax_mask_fuse_upper_triangle(product) if dropout_rate > 0.0: weights = F.dropout( weights, dropout_rate, training=training, mode="upscale_in_train" @@ -183,10 +200,22 @@ def flash_attention( >>> import paddle - >>> paddle.seed(1) + >>> paddle.seed(2023) >>> q = paddle.rand((1, 128, 2, 16)) >>> output = paddle.nn.functional.flash_attention.flash_attention(q, q, q, 0.9, False, False) + >>> print(output) + (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569, + 0.42132431, 0.39157745], + [0.76687670, 0.65837246, 0.69117945, ..., 0.82817286, + 0.76690865, 0.71485823]], + ..., + [[0.71662450, 0.57275224, 0.57053083, ..., 0.48108247, + 0.53336465, 0.54540104], + [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, + 0.40526697, 0.60541755]]]]), None) + """ head_dim = query.shape[3] sdp_func_name = _select_sdp(head_dim) @@ -340,11 +369,12 @@ def flash_attn_unpadded( .. code-block:: python >>> import paddle - >>> paddle.seed(1) - >>> q = paddle.rand((1, 128, 2, 16)) + >>> paddle.seed(2023) + >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') + >>> cu = paddle.arange(0, 384, 128, dtype='int32') + >>> qq = paddle.reshape(q, [256, 8, 16]) + >>> output = paddle.nn.functional.flash_attention.flash_attn_unpadded(qq, qq, qq, cu, cu, 128, 128, 0.25, 0.0, False, False) - >>> output = paddle.nn.functional.flash_attention.flash_attn_unpadded(q, q, q, 0.9, False, False) - >>> print(output) """ if in_dynamic_mode(): ( @@ -461,7 +491,7 @@ def scaled_dot_product_attention( Examples: .. code-block:: python - >>> # doctest: +SKIP() + >>> # doctest: +SKIP('bfloat need V100 compile') >>> import paddle >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16) >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False) diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index e38797a1115ae..c13ebf986e475 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -17,7 +17,7 @@ from ...base.data_feeder import check_variable_and_dtype from ...base.layer_helper import LayerHelper from ...common_ops_import import Variable -from ...framework import in_dynamic_mode, in_dynamic_or_pir_mode +from ...framework import in_dynamic_or_pir_mode __all__ = [] @@ -89,7 +89,7 @@ def one_hot(x, num_classes, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.one_hot(x, num_classes) else: check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2') diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e74e67d83f88e..5faba8b2f3131 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -22,7 +22,7 @@ from paddle.utils import deprecated from ...base.data_feeder import check_variable_and_dtype -from ...base.framework import _current_expected_place +from ...base.framework import _current_expected_place, in_pir_mode from ...base.layer_helper import LayerHelper from ...common_ops_import import Variable from ...tensor.manipulation import reshape @@ -713,7 +713,7 @@ def binary_cross_entropy_with_logits( logit, label, weight=None, reduction='mean', pos_weight=None, name=None ): r""" - Combine the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer. + Combine the sigmoid layer and the :ref:`api_paddle_nn_BCELoss` layer. This measures the element-wise probability error in classification tasks in which each class is independent. @@ -1337,13 +1337,13 @@ def l1_loss(input, label, reduction='mean', name=None): check_variable_and_dtype( input, 'input', - ['float32', 'float64', 'int32', 'int64'], + ['float32', 'float64', 'int32', 'int64', 'float16'], 'l1_loss', ) check_variable_and_dtype( label, 'label', - ['float32', 'float64', 'int32', 'int64'], + ['float32', 'float64', 'int32', 'int64', 'float16'], 'l1_loss', ) @@ -2935,24 +2935,31 @@ def cross_entropy( ['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'], 'softmax_cross_entropy', ) - attrs = { - 'soft_label': soft_label, - 'ignore_index': ignore_index, - 'numeric_stable_mode': True, - 'axis': axis, - 'use_softmax': use_softmax, - } - helper = LayerHelper('softmax_with_cross_entropy', **locals()) - softmax = helper.create_variable_for_type_inference(dtype=input.dtype) - out = helper.create_variable_for_type_inference(dtype=input.dtype) + if in_pir_mode(): + softmax, out = _C_ops.cross_entropy_with_softmax( + input, label, soft_label, use_softmax, True, ignore_index, axis + ) + else: + attrs = { + 'soft_label': soft_label, + 'ignore_index': ignore_index, + 'numeric_stable_mode': True, + 'axis': axis, + 'use_softmax': use_softmax, + } + helper = LayerHelper('softmax_with_cross_entropy', **locals()) + softmax = helper.create_variable_for_type_inference( + dtype=input.dtype + ) + out = helper.create_variable_for_type_inference(dtype=input.dtype) - outputs = {'Softmax': softmax, 'Loss': out} - helper.append_op( - type='softmax_with_cross_entropy', - inputs={'Logits': input, 'Label': label}, - outputs=outputs, - attrs=attrs, - ) + outputs = {'Softmax': softmax, 'Loss': out} + helper.append_op( + type='softmax_with_cross_entropy', + inputs={'Logits': input, 'Label': label}, + outputs=outputs, + attrs=attrs, + ) if weight is not None: check_variable_and_dtype( @@ -3036,19 +3043,21 @@ def cross_entropy( if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = paddle.sum(mask, name=name) - ret = out_sum / (count + (count == 0.0)) + ret = out_sum / (count + paddle.equal(count, 0.0)) else: mask = paddle.cast(mask, weight_gather_reshape.dtype) weight_ignored = paddle.multiply( mask, weight_gather_reshape ) weight_sum = paddle.sum(weight_ignored, name=name) - ret = out_sum / (weight_sum + (weight_sum == 0.0)) + ret = out_sum / (weight_sum + paddle.equal(weight_sum, 0.0)) return ret elif weight is not None: out_sum = paddle.sum(out, name=name) total_weight = paddle.sum(weight_gather_reshape) - return out_sum / (total_weight + (total_weight == 0.0)) + return out_sum / ( + total_weight + paddle.equal(total_weight, 0.0) + ) else: return paddle.mean(out, name=name) @@ -3301,7 +3310,7 @@ def multi_label_soft_margin_loss( if reduction not in ['sum', 'mean', 'none']: raise ValueError( "'reduction' in 'multi_label_soft_margin_loss' should be 'sum', 'mean' or 'none', " - "but received {}.".format(reduction) + f"but received {reduction}." ) if not (input.shape == label.shape): diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 704eb880c516c..0460a823b2017 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -91,9 +91,7 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None): ) if len(x.shape) == 1 and axis != 0 and axis != -1: raise ValueError( - "Axis must be 0 or -1 when x is a 1-D tensor, but received axis = {}".format( - axis - ) + f"Axis must be 0 or -1 when x is a 1-D tensor, but received axis = {axis}" ) attrs = { @@ -116,8 +114,8 @@ def batch_norm( x, running_mean, running_var, - weight, - bias, + weight=None, + bias=None, training=False, momentum=0.9, epsilon=1e-05, @@ -134,8 +132,8 @@ def batch_norm( x(Tesnor): input value. It's data type should be float32, float64. running_mean(Tensor): running mean. running_var(Tensor): running variance. - weight(Tensor): The weight tensor of batch_norm, can not be None. - bias(Tensor): The bias tensor of batch_norm can not be None. + weight(Tensor, optional): The weight tensor of batch_norm. Default: None. + bias(Tensor, optional): The bias tensor of batch_norm. Default: None. epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Default False. momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. @@ -194,7 +192,7 @@ def batch_norm( else: trainable_statistics = not use_global_stats - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( x, running_mean, @@ -229,12 +227,14 @@ def batch_norm( inputs = { "X": [x], - "Scale": [weight], - "Bias": [bias], "Mean": [running_mean], "Variance": [running_var], } + if weight: + inputs['Scale'] = [weight] + if bias: + inputs['Bias'] = [bias] helper = LayerHelper('batch_norm', **locals()) from paddle.base.data_feeder import convert_dtype diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 6f111c61cb507..07bdc48243314 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -15,7 +15,11 @@ import numpy as np from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode -from paddle.base.framework import Variable, in_dygraph_mode +from paddle.base.framework import ( + Variable, + in_dygraph_mode, + in_dynamic_or_pir_mode, +) from ...base.data_feeder import check_type, check_variable_and_dtype from ...base.layer_helper import LayerHelper @@ -167,9 +171,7 @@ def _expand_low_nd_padding(padding): padding = [0] + padding else: raise ValueError( - "The size of padding's dimmention should be 1 or 2. But got padding={}".format( - padding - ) + f"The size of padding's dimmention should be 1 or 2. But got padding={padding}" ) return padding @@ -185,7 +187,7 @@ def avg_pool1d( ): """ This API implements average pooling 1d operation, - See more details in :ref:`api_nn_pooling_AvgPool1d` . + See more details in :ref:`api_paddle_nn_AvgPool1d` . Args: x (Tensor): The input tensor of pooling operator which is a 3-D tensor with @@ -308,7 +310,7 @@ def avg_pool2d( ): """ This API implements average pooling 2d operation. - See more details in :ref:`api_nn_pooling_AvgPool2d` . + See more details in :ref:`api_paddle_nn_AvgPool2d` . Args: x (Tensor): The input tensor of pooling operator which is a 4-D tensor with @@ -372,7 +374,7 @@ def avg_pool2d( padding, 2, channel_last, ceil_mode=ceil_mode ) - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): output = _C_ops.pool2d( x, kernel_size, @@ -441,7 +443,7 @@ def avg_pool3d( ): """ This API implements average pooling 3d operation. - See more details in :ref:`api_nn_pooling_AvgPool3d` . + See more details in :ref:`api_paddle_nn_AvgPool3d` . Args: x (Tensor): The input tensor of pooling operator, which is a 5-D tensor with @@ -568,7 +570,7 @@ def max_pool1d( ): """ This API implements max pooling 1d opereation. - See more details in :ref:`api_nn_pooling_MaxPool1d` . + See more details in :ref:`api_paddle_nn_MaxPool1d` . Args: x (Tensor): The input tensor of pooling operator which is a 3-D tensor with @@ -1180,7 +1182,7 @@ def max_pool2d( ): """ This API implements max pooling 2d operation. - See more details in :ref:`api_nn_pooling_MaxPool2d` . + See more details in :ref:`api_paddle_nn_MaxPool2d` . Args: x (Tensor): The input tensor of pooling operator which is a 4-D tensor with @@ -1254,7 +1256,7 @@ def max_pool2d( "When setting return_mask to true, data_format must be set to NCHW in API:max_pool2d" ) - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool2d_with_index( x, kernel_size, stride, padding, False, False @@ -1480,7 +1482,7 @@ def adaptive_avg_pool1d(x, output_size, name=None): Adaptive average pooling 1d operation on :attr:`x` according to :attr:`output_size`. Notes: - See more details in :ref:`api_nn_pooling_AdaptiveAvgPool1d` . + See more details in :ref:`api_paddle_nn_AdaptiveAvgPool1d` . Args: x (Tensor): The input Tensor of pooling, which is a 3-D tensor with shape :math:`[N, C, L]`, where :math:`N` is batch size, :math:`C` is the number of channels and :math:`L` is the length of the feature. The data type is float32 or float64. @@ -1518,8 +1520,9 @@ def adaptive_avg_pool1d(x, output_size, name=None): pool_size = [1] + convert_to_list(output_size, 1, 'pool_size') x = unsqueeze(x, [2]) - if in_dygraph_mode(): - x = x._use_gpudnn(False) + if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + x = x._use_gpudnn(False) pool_out = _C_ops.pool2d( x, pool_size, @@ -1647,8 +1650,9 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None): elif _contain_var(output_size): output_size = _convert_to_tensor_list(output_size) - if in_dygraph_mode(): - x = x._use_gpudnn(False) + if in_dynamic_or_pir_mode(): + if in_dygraph_mode(): + x = x._use_gpudnn(False) return _C_ops.pool2d( x, output_size, @@ -1662,7 +1666,6 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None): True, "EXPLICIT", ) - else: l_type = 'pool2d' check_variable_and_dtype( @@ -1775,8 +1778,9 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): if output_size[2] is None: output_size[2] = in_w - if in_dygraph_mode(): - x = x._use_gpudnn(False) + if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + x = x._use_gpudnn(False) return _C_ops.pool3d( x, output_size, @@ -1821,7 +1825,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): """ This API implements adaptive max pooling 1d operation. - See more details in :ref:`api_nn_pooling_AdaptiveMaxPool1d` . + See more details in :ref:`api_paddle_nn_AdaptiveMaxPool1d` . Args: x (Tensor): The input tensor of pooling operator, which is a 3-D tensor @@ -1917,7 +1921,7 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): """ This operation applies a 2D adaptive max pooling on input tensor. - See more details in :ref:`api_nn_pooling_AdaptiveMaxPool2d` . + See more details in :ref:`api_paddle_nn_AdaptiveMaxPool2d` . Args: x (Tensor): The input tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type can be float16, float32, float64, int32 or int64. @@ -1996,14 +2000,13 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): "adaptive": True, }, ) - # return (pool_out, mask) if return_mask else pool_out - return pool_out + return (pool_out, mask) if return_mask else pool_out def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): """ This operation applies a 3D adaptive max pooling on input tensor. - See more details in :ref:`api_nn_pooling_AdaptiveMaxPool3d` . + See more details in :ref:`api_paddle_nn_AdaptiveMaxPool3d` . Args: x (Tensor): The input tensor of adaptive max pool3d operator, which is a 5-D tensor. The data type can be float32, float64. diff --git a/python/paddle/nn/initializer/__init__.py b/python/paddle/nn/initializer/__init__.py index c1e0866ad8f06..09691ebe8ffa3 100644 --- a/python/paddle/nn/initializer/__init__.py +++ b/python/paddle/nn/initializer/__init__.py @@ -13,37 +13,26 @@ # limitations under the License. # TODO: define the initializers to create a Parameter in neural network -from ...base.initializer import set_global_initializer # noqa: F401 - -from .Bilinear import Bilinear # noqa: F401 - -from .constant import Constant # noqa: F401 - -from .kaiming import KaimingNormal # noqa: F401 -from .kaiming import KaimingUniform # noqa: F401 - -from .xavier import XavierNormal # noqa: F401 -from .xavier import XavierUniform # noqa: F401 - -from .assign import Assign # noqa: F401 - -from .normal import Normal # noqa: F401 -from .normal import TruncatedNormal # noqa: F401 - -from .uniform import Uniform # noqa: F401 - -from .orthogonal import Orthogonal # noqa: F401 - -from .dirac import Dirac # noqa: F401 - -from .initializer import Initializer, calculate_gain # noqa: F401 -from .uniform import UniformInitializer # noqa: F401 +from ...base.initializer import set_global_initializer +from .assign import NumpyArrayInitializer # noqa: F401 +from .assign import Assign +from .Bilinear import Bilinear from .constant import ConstantInitializer # noqa: F401 -from .normal import NormalInitializer # noqa: F401 -from .normal import TruncatedNormalInitializer # noqa: F401 -from .xavier import XavierInitializer # noqa: F401 +from .constant import Constant +from .dirac import Dirac +from .initializer import Initializer, calculate_gain # noqa: F401 from .kaiming import MSRAInitializer # noqa: F401 -from .assign import NumpyArrayInitializer # noqa: F401 +from .kaiming import KaimingNormal, KaimingUniform +from .normal import ( # noqa: F401 + Normal, + NormalInitializer, + TruncatedNormal, + TruncatedNormalInitializer, +) +from .orthogonal import Orthogonal +from .uniform import Uniform, UniformInitializer # noqa: F401 +from .xavier import XavierInitializer # noqa: F401 +from .xavier import XavierNormal, XavierUniform __all__ = [ 'Bilinear', diff --git a/python/paddle/nn/initializer/normal.py b/python/paddle/nn/initializer/normal.py index c1bcb89f676f7..3a05bbed121f3 100644 --- a/python/paddle/nn/initializer/normal.py +++ b/python/paddle/nn/initializer/normal.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _C_ops +from paddle import _C_ops, pir from ...base import core, framework, unique_name from ...base.data_feeder import check_variable_and_dtype -from ...base.framework import _current_expected_place, in_dygraph_mode +from ...base.framework import ( + _current_expected_place, + in_dygraph_mode, + in_pir_mode, +) from .initializer import Initializer __all__ = [] @@ -54,7 +58,7 @@ def forward(self, var, block=None): """ block = self._check_block(block) - assert isinstance(block, framework.Block) + assert isinstance(block, (framework.Block, pir.Block)) check_variable_and_dtype( var, @@ -78,7 +82,17 @@ def forward(self, var, block=None): ) out_var._share_underline_tensor_to(var) return None - + elif in_pir_mode(): + place = _current_expected_place() + out_var = _C_ops.gaussian( + var.shape, + self._mean, + self._std_dev, + self._seed, + var.dtype, + place, + ) + return out_var else: op = block.append_op( type="gaussian_random", diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index f83b8454456ff..9271c5ecc10e1 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -14,95 +14,99 @@ # TODO: define activation functions of neural network -from . import rnn # noqa: F401 -from . import transformer # noqa: F401 -from . import container # noqa: F401 - -from .activation import CELU # noqa: F401 -from .activation import PReLU # noqa: F401 -from .activation import ReLU # noqa: F401 -from .activation import ReLU6 # noqa: F401 -from .activation import LeakyReLU # noqa: F401 -from .activation import Sigmoid # noqa: F401 -from .activation import Softmax # noqa: F401 -from .activation import LogSoftmax # noqa: F401 -from .activation import RReLU # noqa: F401 -from .activation import Softmax2D # noqa: F401 -from .common import Bilinear # noqa: F401 -from .common import Pad1D # noqa: F401 -from .common import Pad2D # noqa: F401 -from .common import ZeroPad2D # noqa: F401 -from .common import Pad3D # noqa: F401 -from .common import CosineSimilarity # noqa: F401 -from .common import Embedding # noqa: F401 -from .common import Linear # noqa: F401 -from .common import Identity # noqa: F401 -from .common import Flatten # noqa: F401 -from .common import Upsample # noqa: F401 -from .common import Dropout # noqa: F401 -from .common import Dropout2D # noqa: F401 -from .common import Dropout3D # noqa: F401 -from .common import AlphaDropout # noqa: F401 -from .common import UpsamplingBilinear2D # noqa: F401 -from .common import UpsamplingNearest2D # noqa: F401 -from .common import Fold # noqa: F401 -from .common import Unflatten # noqa: F401 - -from .pooling import AvgPool1D # noqa: F401 -from .pooling import AvgPool2D # noqa: F401 -from .pooling import AvgPool3D # noqa: F401 -from .pooling import MaxPool1D # noqa: F401 -from .pooling import MaxPool2D # noqa: F401 -from .pooling import MaxPool3D # noqa: F401 -from .pooling import AdaptiveAvgPool1D # noqa: F401 -from .pooling import AdaptiveAvgPool2D # noqa: F401 -from .pooling import AdaptiveAvgPool3D # noqa: F401 -from .pooling import AdaptiveMaxPool1D # noqa: F401 -from .pooling import AdaptiveMaxPool2D # noqa: F401 -from .pooling import AdaptiveMaxPool3D # noqa: F401 -from .pooling import MaxUnPool1D # noqa: F401 -from .pooling import MaxUnPool2D # noqa: F401 -from .pooling import MaxUnPool3D # noqa: F401 -from .conv import Conv1D # noqa: F401 -from .conv import Conv2D # noqa: F401 -from .conv import Conv3D # noqa: F401 -from .conv import Conv1DTranspose # noqa: F401 -from .conv import Conv2DTranspose # noqa: F401 -from .conv import Conv3DTranspose # noqa: F401 -from .loss import BCEWithLogitsLoss # noqa: F401 -from .loss import CrossEntropyLoss # noqa: F401 -from .loss import MSELoss # noqa: F401 -from .loss import L1Loss # noqa: F401 -from .loss import NLLLoss # noqa: F401 -from .loss import PoissonNLLLoss # noqa: F401 -from .loss import BCELoss # noqa: F401 -from .loss import KLDivLoss # noqa: F401 -from .loss import MarginRankingLoss # noqa: F401 -from .loss import MultiLabelSoftMarginLoss -from .loss import CTCLoss # noqa: F401 -from .loss import RNNTLoss # noqa: F401 -from .loss import SmoothL1Loss # noqa: F401 -from .loss import HingeEmbeddingLoss # noqa: F401 -from .loss import TripletMarginWithDistanceLoss -from .loss import TripletMarginLoss -from .loss import SoftMarginLoss -from .loss import MultiMarginLoss -from .loss import GaussianNLLLoss - -from .norm import BatchNorm1D # noqa: F401 -from .norm import BatchNorm2D # noqa: F401 -from .norm import BatchNorm3D # noqa: F401 -from .norm import SyncBatchNorm # noqa: F401 -from .norm import GroupNorm # noqa: F401 -from .norm import LayerNorm # noqa: F401 -from .norm import SpectralNorm # noqa: F401 -from .norm import LocalResponseNorm # noqa: F401 - -from .vision import PixelShuffle # noqa: F401 -from .vision import PixelUnshuffle # noqa: F401 -from .vision import ChannelShuffle # noqa: F401 -from .distance import PairwiseDistance # noqa: F401 +from . import container, rnn, transformer # noqa: F401 +from .activation import ( # noqa: F401 + CELU, + LeakyReLU, + LogSoftmax, + PReLU, + ReLU, + ReLU6, + RReLU, + Sigmoid, + Softmax, + Softmax2D, +) +from .common import ( # noqa: F401 + AlphaDropout, + Bilinear, + CosineSimilarity, + Dropout, + Dropout2D, + Dropout3D, + Embedding, + Flatten, + Fold, + Identity, + Linear, + Pad1D, + Pad2D, + Pad3D, + Unflatten, + Upsample, + UpsamplingBilinear2D, + UpsamplingNearest2D, + ZeroPad2D, +) from .container import LayerDict # noqa: F401 -from .layers import Layer +from .conv import ( # noqa: F401 + Conv1D, + Conv1DTranspose, + Conv2D, + Conv2DTranspose, + Conv3D, + Conv3DTranspose, +) +from .distance import PairwiseDistance # noqa: F401 +from .layers import Layer # noqa: F401 +from .loss import ( # noqa: F401 + BCELoss, + BCEWithLogitsLoss, + CrossEntropyLoss, + CTCLoss, + GaussianNLLLoss, + HingeEmbeddingLoss, + KLDivLoss, + L1Loss, + MarginRankingLoss, + MSELoss, + MultiLabelSoftMarginLoss, + MultiMarginLoss, + NLLLoss, + PoissonNLLLoss, + RNNTLoss, + SmoothL1Loss, + SoftMarginLoss, + TripletMarginLoss, + TripletMarginWithDistanceLoss, +) +from .norm import ( # noqa: F401 + BatchNorm1D, + BatchNorm2D, + BatchNorm3D, + GroupNorm, + LayerNorm, + LocalResponseNorm, + SpectralNorm, + SyncBatchNorm, +) +from .pooling import ( # noqa: F401 + AdaptiveAvgPool1D, + AdaptiveAvgPool2D, + AdaptiveAvgPool3D, + AdaptiveMaxPool1D, + AdaptiveMaxPool2D, + AdaptiveMaxPool3D, + AvgPool1D, + AvgPool2D, + AvgPool3D, + MaxPool1D, + MaxPool2D, + MaxPool3D, + MaxUnPool1D, + MaxUnPool2D, + MaxUnPool3D, +) +from .vision import ChannelShuffle, PixelShuffle, PixelUnshuffle # noqa: F401 __all__ = [] diff --git a/python/paddle/nn/layer/layers.py b/python/paddle/nn/layer/layers.py index 791b5549ee7a2..126535870c7b4 100644 --- a/python/paddle/nn/layer/layers.py +++ b/python/paddle/nn/layer/layers.py @@ -1652,10 +1652,8 @@ def _remove_if_exist(*dicts): # conservative code. if in_to_static_mode() and _buffers[name] is None: raise RuntimeError( - 'In Dy2stat, self.{0} is a buffer and self.{0} is ' - 'not allowed to be set to Variable when self.{0} is None.'.format( - name - ) + f'In Dy2stat, self.{name} is a buffer and self.{name} is ' + f'not allowed to be set to Variable when self.{name} is None.' ) elif ( _buffers[name] is None diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index f8382ab13fe0e..6335db905e1fd 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1484,7 +1484,7 @@ def __init__(self, weight=None, reduction="mean", name=None): if reduction not in ['sum', 'mean', 'none']: raise ValueError( "'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', " - "but received {}.".format(reduction) + f"but received {reduction}." ) self.weight = weight self.reduction = reduction diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 9944a4b481126..2f40653d193bc 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -37,7 +37,13 @@ from ...base import dygraph_utils from ...base.data_feeder import check_variable_and_dtype -from ...framework import ParamAttr, _global_flags, get_default_dtype, no_grad +from ...framework import ( + ParamAttr, + _global_flags, + get_default_dtype, + in_dynamic_or_pir_mode, + no_grad, +) from .. import functional as F from ..functional import batch_norm, instance_norm, layer_norm from ..initializer import Constant, Normal @@ -721,46 +727,25 @@ def __init__( param_shape = [num_features] # create parameter - if weight_attr is False: - self.weight = self.create_parameter( - attr=None, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(1.0), - ) - self.weight.stop_gradient = True - else: + if weight_attr is not False: self.weight = self.create_parameter( attr=self._weight_attr, shape=param_shape, dtype=self._dtype, default_initializer=Constant(1.0), ) - self.weight.stop_gradient = ( - self._weight_attr is not None - and self._weight_attr.learning_rate == 0.0 - ) - if bias_attr is False: - self.bias = self.create_parameter( - attr=None, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(0.0), - is_bias=True, - ) - self.bias.stop_gradient = True else: + self.weight = None + if bias_attr is not False: self.bias = self.create_parameter( attr=self._bias_attr, shape=param_shape, dtype=self._dtype, is_bias=True, ) - self.bias.stop_gradient = ( - self._bias_attr is not None - and self._bias_attr.learning_rate == 0.0 - ) + else: + self.bias = None moving_mean_name = None moving_variance_name = None @@ -986,10 +971,6 @@ def __init__( self._act = act self._use_mkldnn = _global_flags()["FLAGS_use_mkldnn"] - assert ( - bias_attr is not False - ), "bias_attr should not be False in batch_norm." - if dtype == "float16": self._dtype = "float32" else: @@ -998,25 +979,24 @@ def __init__( param_shape = [num_channels] # create parameter - self.weight = self.create_parameter( - attr=self._param_attr, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(1.0), - ) - self.weight.stop_gradient = ( - use_global_stats and self._param_attr.learning_rate == 0.0 - ) - - self.bias = self.create_parameter( - attr=self._bias_attr, - shape=param_shape, - dtype=self._dtype, - is_bias=True, - ) - self.bias.stop_gradient = ( - use_global_stats and self._param_attr.learning_rate == 0.0 - ) + if param_attr is not False: + self.weight = self.create_parameter( + attr=self._param_attr, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0), + ) + else: + self.weight = None + if bias_attr is not False: + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=param_shape, + dtype=self._dtype, + is_bias=True, + ) + else: + self.bias = None self._mean = self.create_parameter( attr=ParamAttr( @@ -1076,7 +1056,7 @@ def __init__( self._trainable_statistics = trainable_statistics def forward(self, input): - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm( input, self._mean, @@ -1092,9 +1072,13 @@ def forward(self, input): ) if self._act is None: return batch_norm_out - return dygraph_utils._append_activation_in_dygraph( - batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn - ) + if in_dynamic_mode(): + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn + ) + else: + act_op = getattr(_C_ops, self._act) + return act_op(input) else: # create output # mean and mean_out share the same memory @@ -1600,6 +1584,24 @@ def __init__( None, name, ) + param_shape = [num_features] + if weight_attr is False: + self.weight = self.create_parameter( + attr=None, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0), + ) + self.weight.stop_gradient = True + if bias_attr is False: + self.bias = self.create_parameter( + attr=None, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(0.0), + is_bias=True, + ) + self.bias.stop_gradient = True def _check_data_format(self): if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']: diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 676134b4efc14..3f56960e94c4c 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -783,28 +783,65 @@ def __init__( ) ) std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = self.create_parameter( - (hidden_size, input_size), - weight_ih_attr, - default_initializer=I.Uniform(-std, std), - ) - self.weight_hh = self.create_parameter( - (hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std), - ) - self.bias_ih = self.create_parameter( - (hidden_size,), - bias_ih_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std), - ) - self.bias_hh = self.create_parameter( - (hidden_size,), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std), - ) + if weight_ih_attr is not False: + self.weight_ih = self.create_parameter( + (hidden_size, input_size), + weight_ih_attr, + default_initializer=I.Uniform(-std, std), + ) + else: + self.weight_ih = self.create_parameter( + (hidden_size, input_size), + None, + default_initializer=I.Constant(1.0), + ) + self.weight_ih.stop_gradient = True + + if weight_hh_attr is not False: + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std), + ) + else: + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + None, + default_initializer=I.Constant(1.0), + ) + self.weight_hh.stop_gradient = True + + if bias_ih_attr is not False: + self.bias_ih = self.create_parameter( + (hidden_size,), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std), + ) + else: + self.bias_ih = self.create_parameter( + (hidden_size,), + None, + is_bias=True, + default_initializer=I.Constant(0.0), + ) + self.bias_ih.stop_gradient = True + + if bias_hh_attr is not False: + self.bias_hh = self.create_parameter( + (hidden_size,), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std), + ) + else: + self.bias_hh = self.create_parameter( + (hidden_size,), + None, + is_bias=True, + default_initializer=I.Constant(0.0), + ) + self.bias_hh.stop_gradient = True self.input_size = input_size self.hidden_size = hidden_size @@ -946,28 +983,63 @@ def __init__( ) ) std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = self.create_parameter( - (4 * hidden_size, input_size), - weight_ih_attr, - default_initializer=I.Uniform(-std, std), - ) - self.weight_hh = self.create_parameter( - (4 * hidden_size, proj_size or hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std), - ) - self.bias_ih = self.create_parameter( - (4 * hidden_size,), - bias_ih_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std), - ) - self.bias_hh = self.create_parameter( - (4 * hidden_size,), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std), - ) + if weight_ih_attr is not False: + self.weight_ih = self.create_parameter( + (4 * hidden_size, input_size), + weight_ih_attr, + default_initializer=I.Uniform(-std, std), + ) + else: + self.weight_ih = self.create_parameter( + (4 * hidden_size, input_size), + None, + default_initializer=I.Constant(1.0), + ) + self.weight_ih.stop_gradient = True + if weight_hh_attr is not False: + self.weight_hh = self.create_parameter( + (4 * hidden_size, proj_size or hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std), + ) + else: + self.weight_hh = self.create_parameter( + (4 * hidden_size, proj_size or hidden_size), + None, + default_initializer=I.Constant(1.0), + ) + self.weight_hh.stop_gradient = True + if bias_ih_attr is not False: + self.bias_ih = self.create_parameter( + (4 * hidden_size,), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std), + ) + else: + self.bias_ih = self.create_parameter( + (4 * hidden_size,), + None, + is_bias=True, + default_initializer=I.Constant(0.0), + ) + self.bias_ih.stop_gradient = True + if bias_hh_attr is not False: + self.bias_hh = self.create_parameter( + (4 * hidden_size,), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std), + ) + else: + self.bias_hh = self.create_parameter( + (4 * hidden_size,), + None, + is_bias=True, + default_initializer=I.Constant(0.0), + ) + self.bias_hh.stop_gradient = True + self.proj_size = proj_size if proj_size: self.weight_ho = self.create_parameter( @@ -1115,28 +1187,64 @@ def __init__( ) ) std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = self.create_parameter( - (3 * hidden_size, input_size), - weight_ih_attr, - default_initializer=I.Uniform(-std, std), - ) - self.weight_hh = self.create_parameter( - (3 * hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std), - ) - self.bias_ih = self.create_parameter( - (3 * hidden_size,), - bias_ih_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std), - ) - self.bias_hh = self.create_parameter( - (3 * hidden_size,), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std), - ) + if weight_ih_attr is not False: + self.weight_ih = self.create_parameter( + (3 * hidden_size, input_size), + weight_ih_attr, + default_initializer=I.Uniform(-std, std), + ) + else: + self.weight_ih = self.create_parameter( + (3 * hidden_size, input_size), + None, + default_initializer=I.Constant(1.0), + ) + self.weight_ih.stop_gradient = True + if weight_hh_attr is not False: + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std), + ) + else: + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + None, + default_initializer=I.Constant(1.0), + ) + self.weight_hh.stop_gradient = True + + if bias_ih_attr is not False: + self.bias_ih = self.create_parameter( + (3 * hidden_size,), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std), + ) + else: + self.bias_ih = self.create_parameter( + (3 * hidden_size,), + None, + is_bias=True, + default_initializer=I.Constant(0.0), + ) + self.bias_ih.stop_gradient = True + + if bias_hh_attr is not False: + self.bias_hh = self.create_parameter( + (3 * hidden_size,), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std), + ) + else: + self.bias_hh = self.create_parameter( + (3 * hidden_size,), + None, + is_bias=True, + default_initializer=I.Constant(0.0), + ) + self.bias_hh.stop_gradient = True self.hidden_size = hidden_size self.input_size = input_size diff --git a/python/paddle/nn/quant/__init__.py b/python/paddle/nn/quant/__init__.py index 85d9650ce400f..0a9ef677b200e 100644 --- a/python/paddle/nn/quant/__init__.py +++ b/python/paddle/nn/quant/__init__.py @@ -12,22 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .functional_layers import FloatFunctionalLayer # noqa: F401 -from .functional_layers import add # noqa: F401 -from .functional_layers import subtract # noqa: F401 -from .functional_layers import multiply # noqa: F401 -from .functional_layers import divide # noqa: F401 -from .functional_layers import reshape # noqa: F401 -from .functional_layers import transpose # noqa: F401 -from .functional_layers import concat # noqa: F401 -from .functional_layers import flatten # noqa: F401 -from .functional_layers import matmul # noqa: F401 -from .quantized_linear import weight_only_linear # noqa: F401 -from .quantized_linear import llm_int8_linear # noqa: F401 -from .quantized_linear import weight_quantize # noqa: F401 -from .quantized_linear import weight_dequantize # noqa: F401 +from . import qat # noqa: F401 +from .functional_layers import ( # noqa: F401 + FloatFunctionalLayer, + add, + concat, + divide, + flatten, + matmul, + multiply, + reshape, + subtract, + transpose, +) from .quant_layers import QuantStub # noqa: F401 -from . import qat +from .quantized_linear import ( + llm_int8_linear, + weight_dequantize, + weight_only_linear, + weight_quantize, +) from .stub import Stub __all__ = [ diff --git a/python/paddle/nn/quant/format.py b/python/paddle/nn/quant/format.py index caa8d3d542cf6..943f91a164355 100644 --- a/python/paddle/nn/quant/format.py +++ b/python/paddle/nn/quant/format.py @@ -18,7 +18,8 @@ import paddle from paddle import _legacy_C_ops as _C_ops from paddle.framework import in_dynamic_mode -from paddle.nn import Layer + +from ..layer.layers import Layer class LinearQuanterDequanter(Layer): diff --git a/python/paddle/nn/quant/functional_layers.py b/python/paddle/nn/quant/functional_layers.py index 3a0fafe6b6ad1..670984fe4f9c7 100644 --- a/python/paddle/nn/quant/functional_layers.py +++ b/python/paddle/nn/quant/functional_layers.py @@ -13,7 +13,7 @@ # limitations under the License. from ...tensor import linalg, manipulation, math -from .. import Layer +from ..layer.layers import Layer __all__ = [] diff --git a/python/paddle/nn/quant/lsq.py b/python/paddle/nn/quant/lsq.py index 4fa8f55266a38..a4adc4f05b412 100644 --- a/python/paddle/nn/quant/lsq.py +++ b/python/paddle/nn/quant/lsq.py @@ -17,10 +17,11 @@ import paddle from paddle.autograd import PyLayer from paddle.framework import ParamAttr -from paddle.nn import Layer from paddle.nn.initializer import Constant from paddle.utils import unique_name +from ..layer.layers import Layer + def round(x): sign = paddle.sign(x) diff --git a/python/paddle/nn/quant/qat/conv.py b/python/paddle/nn/quant/qat/conv.py index 1cf33a8bcb344..2bb3fefe1d642 100644 --- a/python/paddle/nn/quant/qat/conv.py +++ b/python/paddle/nn/quant/qat/conv.py @@ -14,9 +14,9 @@ """ Layers used for QAT. """ -from paddle.nn import Layer from paddle.nn import functional as F +from ...layer.layers import Layer from ..format import ConvertibleQuantedLayer diff --git a/python/paddle/nn/quant/qat/linear.py b/python/paddle/nn/quant/qat/linear.py index 39b177f2c2495..2b350912e26dc 100644 --- a/python/paddle/nn/quant/qat/linear.py +++ b/python/paddle/nn/quant/qat/linear.py @@ -13,9 +13,9 @@ # limitations under the License. -from paddle.nn import Layer from paddle.nn import functional as F +from ...layer.layers import Layer from ..format import ConvertibleQuantedLayer diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index a83cbef801f94..e55deb9a1d080 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -20,12 +20,13 @@ from paddle.base.framework import _create_tensor from paddle.base.log_helper import get_logger from paddle.framework import ParamAttr, core -from paddle.nn import Layer from paddle.nn import functional as F from paddle.nn.initializer import Constant from paddle.nn.quant.lsq import FakeQuantActLSQPlus, FakeQuantWeightLSQPlus from paddle.utils import unique_name +from ..layer.layers import Layer + __all__ = [ 'FakeQuantAbsMax', 'FakeQuantMovingAverageAbsMax', diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index b60a9e6818c8e..8f962da6b6766 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -164,7 +164,7 @@ def weight_only_linear( 'weight': [weight], 'weight_scale': [weight_scale], } - if bias: + if bias is not None: inputs["bias"] = [bias] attrs = {'weight_dtype': weight_dtype} diff --git a/python/paddle/nn/quant/stub.py b/python/paddle/nn/quant/stub.py index 7e75889a4a037..314319caa707d 100644 --- a/python/paddle/nn/quant/stub.py +++ b/python/paddle/nn/quant/stub.py @@ -14,7 +14,7 @@ """ Define stub used in quantization.""" -from paddle.nn import Layer +from ..layer.layers import Layer class Stub(Layer): diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py index 2d255055d8cf5..2f6b76db52008 100644 --- a/python/paddle/nn/utils/__init__.py +++ b/python/paddle/nn/utils/__init__.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .clip_grad_norm_ import clip_grad_norm_ +from .clip_grad_value_ import clip_grad_value_ from .spectral_norm_hook import spectral_norm -from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401 -from .transform_parameters import ( +from .transform_parameters import ( # noqa: F401 + _stride_column, parameters_to_vector, vector_to_parameters, - _stride_column, -) # noqa: F401 -from .clip_grad_norm_ import clip_grad_norm_ # noqa: F401 -from .clip_grad_value_ import clip_grad_value_ # noqa: F401 +) +from .weight_norm_hook import remove_weight_norm, weight_norm __all__ = [ 'weight_norm', diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py index 7cb628565cff9..8db65d61bb5ba 100644 --- a/python/paddle/nn/utils/transform_parameters.py +++ b/python/paddle/nn/utils/transform_parameters.py @@ -121,6 +121,7 @@ def parameters_to_vector(parameters, name=None): ) for i, param in enumerate(parameters): _inplace_reshape_dygraph(param, origin_shapes[i]) + out.stop_gradient = False return out diff --git a/python/paddle/onnx/__init__.py b/python/paddle/onnx/__init__.py index 8853e78bf3d80..879907beab073 100644 --- a/python/paddle/onnx/__init__.py +++ b/python/paddle/onnx/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .export import export # noqa: F401 +from .export import export __all__ = ['export'] diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py index af86573905273..bf8d63b217123 100644 --- a/python/paddle/optimizer/__init__.py +++ b/python/paddle/optimizer/__init__.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .optimizer import Optimizer # noqa: F401 -from .adagrad import Adagrad # noqa: F401 -from .adam import Adam # noqa: F401 -from .adamw import AdamW # noqa: F401 -from .adamax import Adamax # noqa: F401 -from .rmsprop import RMSProp # noqa: F401 -from .adadelta import Adadelta # noqa: F401 -from .sgd import SGD # noqa: F401 -from .momentum import Momentum # noqa: F401 -from .lamb import Lamb # noqa: F401 -from .lbfgs import LBFGS # noqa: F401 from . import lr # noqa: F401 +from .adadelta import Adadelta +from .adagrad import Adagrad +from .adam import Adam +from .adamax import Adamax +from .adamw import AdamW +from .lamb import Lamb +from .lbfgs import LBFGS +from .momentum import Momentum +from .optimizer import Optimizer +from .rmsprop import RMSProp +from .sgd import SGD __all__ = [ 'Optimizer', diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 37a46f53707f1..12db1f607da02 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -46,6 +46,7 @@ 'OneCycleLR', 'CyclicLR', 'LinearLR', + 'CosineAnnealingWarmRestarts', ] @@ -225,9 +226,7 @@ def set_state_dict(self, state_dict): self.__dict__[key] = state_dict[key] else: raise RuntimeError( - "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".format( - key - ) + f"Please check whether state_dict is correct for optimizer. Can't find [ {key} ] in state_dict" ) if len(state_dict) > len(self.keys): warnings.warn( @@ -2349,6 +2348,168 @@ def get_lr(self): return self.last_lr * factor +class CosineAnnealingWarmRestarts(LRScheduler): + r""" + Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts `_. + + Args: + learning_rate (float): Initial learning rate. + T_0 (int): Number of iterations for the first restart. + T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of last epoch. Default: -1, means initial learning rate. + verbose (bool, optional): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Returns: + ``CosineAnnealingWarmRestarts`` instance to schedule learning rate. + + Examples: + .. code-block:: python + :name: code-example1 + + >>> import paddle + >>> import numpy as np + >>> # train on default dynamic graph mode + >>> linear = paddle.nn.Linear(10, 10) + >>> scheduler = paddle.optimizer.lr.CosineAnnealingWarmRestarts(learning_rate=0.5, T_0=1, T_mult=2, verbose=True) + >>> adam = paddle.optimizer.Adam(learning_rate=scheduler, parameters=linear.parameters()) + >>> for epoch in range(10): + ... for batch_id in range(10): + ... x = paddle.uniform([10, 10]) + ... out = linear(x) + ... loss = paddle.mean(out) + ... loss.backward() + ... adam.step() + ... adam.clear_grad() + ... scheduler.step(epoch) # You should update learning rate each step + + .. code-block:: python + :name: code-example2 + + >>> import paddle + >>> import numpy as np + >>> paddle.enable_static() + >>> main_prog = paddle.static.Program() + >>> start_prog = paddle.static.Program() + >>> with paddle.static.program_guard(main_prog, start_prog): + ... x = paddle.static.data(name='x', shape=[None, 4, 5]) + ... y = paddle.static.data(name='y', shape=[None, 4, 5]) + ... z = paddle.static.nn.fc(x, 100) + ... loss = paddle.mean(z) + ... scheduler = paddle.optimizer.lr.CosineAnnealingWarmRestarts(learning_rate=0.5, T_0=1, T_mult=2,verbose=True) + ... sgd = paddle.optimizer.SGD(learning_rate=scheduler) + ... sgd.minimize(loss) + >>> exe = paddle.static.Executor() + >>> exe.run(start_prog) + >>> for epoch in range(10): + ... for batch_id in range(10): + ... out = exe.run( + ... main_prog, + ... feed={ + ... 'x': np.random.randn(3, 4, 5).astype('float32'), + ... 'y': np.random.randn(3, 4, 5).astype('float32') + ... }, + ... fetch_list=loss.name) + ... scheduler.step(epoch) # You should update learning rate each step + """ + + def __init__( + self, + learning_rate, + T_0, + T_mult=1, + eta_min=0, + last_epoch=-1, + verbose=False, + ): + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError(f"Expected positive integer T_0, but got {T_0}") + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return ( + self.eta_min + + (self.base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + ) + + def step(self, epoch=None): + """ + step should be called after `optimizer.step()` . It will update the learning rate in optimizer. + The new learning rate will take effect on next epoch. + + Args: + epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1. + + Returns: + None + + Examples: + Please refer to the example of current LRScheduler. + """ + + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError( + f"Expected non-negative epoch, but got {epoch}" + ) + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), + self.T_mult, + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + self.last_lr = self.get_lr() + if self.verbose: + print( + 'Epoch {}: {} set learning rate to {}.'.format( + self.last_epoch, self.__class__.__name__, self.last_lr + ) + ) + + def autoincreased_step_counter(counter_name=None, begin=1, step=1): """ :api_attr: Static Graph diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index ced22d571b8e4..0752b5894b335 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -312,11 +312,11 @@ def state_dict(self): Examples: .. code-block:: python - import paddle - emb = paddle.nn.Embedding(10, 10) + >>> import paddle + >>> emb = paddle.nn.Embedding(10, 10) - adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) - state_dict = adam.state_dict() + >>> adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) + >>> state_dict = adam.state_dict() ''' state_dict = {} @@ -1154,7 +1154,7 @@ def _create_optimization_pass( end = len(target_block.ops) return target_block._slice_ops(start, end) - def _new_ir_create_optimization_pass( + def _pir_create_optimization_pass( self, parameters_and_grads, param_group_idx=0 ): """Add optimization operators to update gradients to tensors. @@ -1416,7 +1416,7 @@ def _apply_optimize( params_grads['params'], self.regularization ) if in_pir_mode(): - optimize_ops = self._new_ir_create_optimization_pass( + optimize_ops = self._pir_create_optimization_pass( params_grads, param_group_idx=param_group_idx ) else: diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py index 145eb103918bf..b2a51d97cef90 100644 --- a/python/paddle/pir/__init__.py +++ b/python/paddle/pir/__init__.py @@ -13,30 +13,29 @@ # limitations under the License. from paddle.base.libpaddle.pir import ( # noqa: F401 - Program, Block, Operation, - Value, OpOperand, OpResult, + PassManager, + Program, + Type, + Value, + check_unregistered_ops, fake_op_result, is_fake_op_result, - Type, -) -from paddle.base.libpaddle.pir import ( # noqa: F401 - translate_to_new_ir, - translate_to_new_ir_with_param_map, + register_paddle_dialect, + reset_insertion_point_to_end, + reset_insertion_point_to_start, set_global_program, set_insertion_point, - reset_insertion_point_to_start, - reset_insertion_point_to_end, - check_unregistered_ops, - register_paddle_dialect, - PassManager, + translate_to_pir, + translate_to_pir_with_param_map, ) -from . import core +from . import core # noqa: F401 -from .math_op_patch import monkey_patch_opresult +from .math_op_patch import monkey_patch_opresult # noqa: F401 +from .program_patch import monkey_patch_program # noqa: F401 __all__ = [] diff --git a/python/paddle/pir/core.py b/python/paddle/pir/core.py index 2fcf73cd10fa8..1555fbfdec57f 100644 --- a/python/paddle/pir/core.py +++ b/python/paddle/pir/core.py @@ -298,7 +298,7 @@ def _convert_into_opresult(tensor): """ import paddle from paddle.base import core, framework - from paddle.jit.newir_dy2static.parameter_recorder import ( + from paddle.jit.pir_dy2static.parameter_recorder import ( _global_parameter_recorder, ) @@ -308,14 +308,10 @@ def _convert_into_opresult(tensor): is_persistable = True if new_var is not None: assert isinstance(new_var, framework.Variable) - elif isinstance(tensor, framework.EagerParamBase): - # Convert EagerParamBase into Parameter with same attributes in dy2stat. + else: new_var = _global_parameter_recorder.get( paddle.pir.core.default_main_program(), tensor ) - else: - # TODO(xiongkun): add this logic, we should call paddle.data() to create a non-parameter variable. - raise NotImplementedError("Not implemented, for buffers.") # add param into parameter recorder to collect all the params used in this program. return new_var else: diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index 2f52a5f8502c7..add6ea93b96ba 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -13,12 +13,23 @@ # limitations under the License. +import warnings + from paddle.base.libpaddle import DataType from . import OpResult _already_patch_opresult = False +_supported_int_dtype_ = [ + DataType.BOOL, + DataType.UINT8, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, +] + def create_tensor_with_batchsize(ref_var, value, dtype): assert isinstance(ref_var, OpResult) @@ -54,14 +65,143 @@ def safe_get_dtype(var): raise ValueError("Cannot get data type from var") return dtype - _supported_int_dtype_ = [ - DataType.BOOL, - DataType.UINT8, - DataType.INT8, - DataType.INT16, - DataType.INT32, - DataType.INT64, - ] + def place(self): + """ + OpResult don't have 'place' interface in static graph mode + But this interface can greatly facilitate dy2static. + So we give a warnning here and return None. + """ + warnings.warn( + "OpResult do not have 'place' interface for pir graph mode, try not to use it. None will be returned." + ) + + @property + def _ndim(self): + """ + Returns the dimension of current OpResult + + Returns: + the dimension + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> # create a static OpResult + >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) + >>> # print the dimension of the OpResult + >>> print(x.ndim) + 3 + """ + return len(self.shape) + + def ndimension(self): + """ + Returns the dimension of current OpResult + + Returns: + the dimension + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> # create a static OpResult + >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) + >>> # print the dimension of the OpResult + >>> print(x.ndimension()) + 3 + """ + return len(self.shape) + + def dim(self): + """ + Returns the dimension of current OpResult + + Returns: + the dimension + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> # create a static OpResult + >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) + >>> # print the dimension of the OpResult + >>> print(x.dim()) + 3 + """ + return len(self.shape) + + def _item(self): + """ + In order to be compatible with the item interface introduced by the dynamic graph, it does nothing but returns self. + It will check that the shape must be a 1-D tensor + """ + if len(self.shape) > 1: + raise TypeError( + f"Required input var should be 1-D OpResult, but received {self.shape}" + ) + return self + + def astype(self, dtype): + """ + **Notes**: + + Cast a OpResult to a specified data type. + + Args: + + self(OpResult): The source OpResult + + dtype: The target data type + + Returns: + OpResult: OpResult with new dtype + + Examples: + In Static Graph Mode: + + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + >>> startup_prog = paddle.static.Program() + >>> main_prog = paddle.static.Program() + >>> with paddle.static.program_guard(startup_prog, main_prog): + ... original_value = paddle.static.data(name = "new_value", shape=[2,2], dtype='float32') + ... new_value = original_value.astype('int64') + ... print("new value's dtype is: {}".format(new_value.dtype)) + ... + new OpResult's dtype is: paddle.int64 + + """ + from paddle import _C_ops + + if not isinstance(dtype, DataType): + dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype) + return _C_ops.cast(self, dtype) + + def _scalar_add_(var, value): + return paddle.scale(var, 1.0, value) + + def _scalar_sub_(var, value): + return paddle.scale(var, 1.0, -value) + + def _scalar_rsub_(var, value): + return paddle.scale(var, -1.0, value) + + def _scalar_mul_(var, value): + return paddle.scale(var, value, 0.0) def _scalar_div_(var, value): return paddle.scale(var, 1.0 / value, 0.0) @@ -78,7 +218,7 @@ def __impl__(self, other_var): if isinstance(other_var, float): # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float if self.dtype in _supported_int_dtype_: - paddle.cast(self, DataType.FLOAT32) + self = astype(self, DataType.FLOAT32) # here use `scale` replace `elementwise` to get better performance # but only +, -, *, / can use this method if scalar_method is not None: @@ -121,14 +261,16 @@ def __impl__(self, other_var): break else: # when break is not triggered, enter the else branch - other_var_opresult = paddle.fill_constant( - self.shape, - lhs_dtype, - other_var, + other_var_opresult = ( + paddle.tensor.creation.fill_constant( + self.shape, + lhs_dtype, + other_var, + ) ) else: # add fill_op to current_block - other_var_opresult = paddle.fill_constant( + other_var_opresult = paddle.tensor.creation.fill_constant( [], lhs_dtype, other_var, @@ -147,7 +289,9 @@ def __impl__(self, other_var): python_api == paddle.divide ) and self.dtype in _supported_int_dtype_: self = paddle.cast(self, DataType.FLOAT32) - other_var = paddle.cast(other_var_opresult, DataType.FLOAT32) + other_var_opresult = paddle.cast( + other_var_opresult, DataType.FLOAT32 + ) out = python_api(self, other_var_opresult) return out @@ -163,48 +307,51 @@ def __impl__(self, other_var): __impl__.__name__ = method_name return __impl__ - def astype(self, dtype): - """ - **Notes**: - - Cast a OpResult to a specified data type. - - Args: - - self(OpResult): The source OpResult - - dtype: The target data type - - Returns: - OpResult: OpResult with new dtype - - Examples: - In Static Graph Mode: - - .. code-block:: python - - >>> import paddle - >>> paddle.enable_static() - >>> startup_prog = paddle.static.Program() - >>> main_prog = paddle.static.Program() - >>> with paddle.static.program_guard(startup_prog, main_prog): - ... original_value = paddle.static.data(name = "new_value", shape=[2,2], dtype='float32') - ... new_value = original_value.astype('int64') - ... print("new value's dtype is: {}".format(new_value.dtype)) - ... - new OpResult's dtype is: paddle.int64 - - """ - from paddle import _C_ops - - if not isinstance(dtype, DataType): - dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype) - return _C_ops.cast(self, dtype) - import paddle opresult_methods = [ + ('place', place), + ('item', _item), + ('dim', dim), + ('ndimension', ndimension), + ('ndim', _ndim), ('astype', astype), + ( + '__add__', + _binary_creator_('__add__', paddle.tensor.add, False, _scalar_add_), + ), + # a+b == b+a. Do not need to reverse explicitly + ( + '__radd__', + _binary_creator_( + '__radd__', paddle.tensor.add, False, _scalar_add_ + ), + ), + ( + '__sub__', + _binary_creator_( + '__sub__', paddle.tensor.subtract, False, _scalar_sub_ + ), + ), + ( + '__rsub__', + _binary_creator_( + '__rsub__', paddle.tensor.subtract, True, _scalar_rsub_ + ), + ), + ( + '__mul__', + _binary_creator_( + '__mul__', paddle.tensor.multiply, False, _scalar_mul_ + ), + ), + # a*b == b*a. Do not need to reverse explicitly + ( + '__rmul__', + _binary_creator_( + '__rmul__', paddle.tensor.multiply, False, _scalar_mul_ + ), + ), ( '__div__', _binary_creator_( @@ -225,6 +372,56 @@ def astype(self, dtype): '__rtruediv__', _binary_creator_('__rtruediv__', paddle.tensor.divide, True, None), ), + ( + '__pow__', + _binary_creator_('__pow__', paddle.tensor.pow, False, None), + ), + ( + '__rpow__', + _binary_creator_('__rpow__', paddle.tensor.pow, True, None), + ), + ( + '__floordiv__', + _binary_creator_( + '__floordiv__', paddle.tensor.floor_divide, False, None + ), + ), + ( + '__mod__', + _binary_creator_('__mod__', paddle.tensor.remainder, False, None), + ), + ( + '__matmul__', + _binary_creator_('__matmul__', paddle.tensor.matmul, False, None), + ), + # for logical compare + # TODO(gouzil): Open after deleting c++ logic + # ( + # '__eq__', + # _binary_creator_('__eq__', paddle.tensor.equal, False, None), + # ), + ( + '__ne__', + _binary_creator_('__ne__', paddle.tensor.not_equal, False, None), + ), + ( + '__lt__', + _binary_creator_('__lt__', paddle.tensor.less_than, False, None), + ), + ( + '__le__', + _binary_creator_('__le__', paddle.tensor.less_equal, False, None), + ), + ( + '__gt__', + _binary_creator_('__gt__', paddle.tensor.greater_than, False, None), + ), + ( + '__ge__', + _binary_creator_( + '__ge__', paddle.tensor.greater_equal, False, None + ), + ), ] global _already_patch_opresult @@ -244,6 +441,12 @@ def astype(self, dtype): if method_impl: setattr(OpResult, method_name, method_impl) + # Bit operation symbol + for magic_method, origin_method in paddle.tensor.magic_method_func: + impl = getattr(paddle.tensor, origin_method, None) + if impl: + setattr(OpResult, magic_method, impl) + # Handling __getitem__ from ..base.variable_index import _getitem_static diff --git a/python/paddle/pir/program_patch.py b/python/paddle/pir/program_patch.py new file mode 100644 index 0000000000000..4de46a647259a --- /dev/null +++ b/python/paddle/pir/program_patch.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import Program + +_already_patch_program = False + +global_prog_seed = 0 + + +def monkey_patch_program(): + def global_seed(self, seed=0): + global global_prog_seed + global_prog_seed = seed + self._seed = global_prog_seed + + Program.global_seed = global_seed + global global_prog_seed + Program._seed = global_prog_seed + + global _already_patch_program + if not _already_patch_program: + _already_patch_program = True diff --git a/python/paddle/pir_utils.py b/python/paddle/pir_utils.py index a2b5244cad7c5..601b4d27688fa 100644 --- a/python/paddle/pir_utils.py +++ b/python/paddle/pir_utils.py @@ -62,9 +62,7 @@ def _switch_to_pir(self): if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ "FLAGS_enable_pir_api" ]: - paddle.framework.set_flags( - {"FLAGS_enable_new_ir_in_executor": True} - ) + paddle.framework.set_flags({"FLAGS_enable_pir_in_executor": True}) paddle.pir.register_paddle_dialect() paddle.base.Program = paddle.pir.Program @@ -88,9 +86,7 @@ def _switch_to_old_ir(self): if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ "FLAGS_enable_pir_api" ]: - paddle.framework.set_flags( - {"FLAGS_enable_new_ir_in_executor": False} - ) + paddle.framework.set_flags({"FLAGS_enable_pir_in_executor": False}) paddle.base.Program = self.old_Program paddle.base.program_guard = self.old_program_guard diff --git a/python/paddle/quantization/config.py b/python/paddle/quantization/config.py index 28feb8c6b087f..bafc24488f089 100644 --- a/python/paddle/quantization/config.py +++ b/python/paddle/quantization/config.py @@ -127,7 +127,7 @@ def add_layer_config( >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) >>> q_config = QuantConfig(activation=None, weight=None) >>> q_config.add_layer_config([model.fc], activation=quanter, weight=quanter) - >>> # doctest: +SKIP + >>> # doctest: +SKIP('random memory address') >>> print(q_config) Global config: None @@ -176,7 +176,7 @@ def add_name_config( >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) >>> q_config = QuantConfig(activation=None, weight=None) >>> q_config.add_name_config([model.fc.full_name()], activation=quanter, weight=quanter) - >>> # doctest: +SKIP + >>> # doctest: +SKIP('random memory address') >>> print(q_config) Global config: None @@ -226,7 +226,7 @@ def add_type_config( >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) >>> q_config = QuantConfig(activation=None, weight=None) >>> q_config.add_type_config([Linear], activation=quanter, weight=quanter) - >>> # doctest: +SKIP + >>> # doctest: +SKIP('random memory address') >>> print(q_config) Global config: None diff --git a/python/paddle/quantization/factory.py b/python/paddle/quantization/factory.py index b0ef906220186..eb8916460975c 100644 --- a/python/paddle/quantization/factory.py +++ b/python/paddle/quantization/factory.py @@ -83,7 +83,7 @@ def quanter(class_name): Examples: .. code-block:: python - >>> # doctest: +SKIP + >>> # doctest: +SKIP('need 2 file to run example') >>> # Given codes in ./customized_quanter.py >>> from paddle.quantization import quanter >>> from paddle.quantization import BaseQuanter diff --git a/python/paddle/reader/__init__.py b/python/paddle/reader/__init__.py index 4ce0ed643c343..d4ee721541e81 100644 --- a/python/paddle/reader/__init__.py +++ b/python/paddle/reader/__init__.py @@ -63,15 +63,17 @@ """ -from paddle.reader.decorator import map_readers # noqa: F401 -from paddle.reader.decorator import shuffle # noqa: F401 -from paddle.reader.decorator import xmap_readers # noqa: F401 -from paddle.reader.decorator import firstn # noqa: F401 -from paddle.reader.decorator import buffered # noqa: F401 -from paddle.reader.decorator import compose # noqa: F401 -from paddle.reader.decorator import cache # noqa: F401 -from paddle.reader.decorator import ComposeNotAligned # noqa: F401 -from paddle.reader.decorator import chain # noqa: F401 -from paddle.reader.decorator import multiprocess_reader # noqa: F401 +from paddle.reader.decorator import ( # noqa: F401 + ComposeNotAligned, + buffered, + cache, + chain, + compose, + firstn, + map_readers, + multiprocess_reader, + shuffle, + xmap_readers, +) __all__ = [] diff --git a/python/paddle/sparse/nn/functional/__init__.py b/python/paddle/sparse/nn/functional/__init__.py index 0509352a67c6d..5fc68de914bd5 100644 --- a/python/paddle/sparse/nn/functional/__init__.py +++ b/python/paddle/sparse/nn/functional/__init__.py @@ -12,16 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .conv import conv2d # noqa: F401 -from .conv import conv3d # noqa: F401 -from .conv import subm_conv2d # noqa: F401 -from .conv import subm_conv3d # noqa: F401 -from .transformer import attention # noqa: F401 -from .pooling import max_pool3d # noqa: F401 -from .activation import relu # noqa: F401 -from .activation import relu6 # noqa: F401 -from .activation import leaky_relu # noqa: F401 -from .activation import softmax # noqa: F401 +from .activation import leaky_relu, relu, relu6, softmax +from .conv import conv2d, conv3d, subm_conv2d, subm_conv3d +from .pooling import max_pool3d +from .transformer import attention __all__ = [ 'conv2d', diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 57c4abec6d8d0..657959f7ffcaa 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -13,71 +13,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -from . import amp # noqa: F401 -from . import nn # noqa: F401 - -from .nn.common import py_func # noqa: F401 -from .nn.common import ExponentialMovingAverage # noqa: F401 - -from .io import save_inference_model # noqa: F401 -from .io import load_inference_model # noqa: F401 -from .io import deserialize_persistables # noqa: F401 -from .io import serialize_persistables # noqa: F401 -from .io import deserialize_program # noqa: F401 -from .io import serialize_program # noqa: F401 -from .io import load_from_file # noqa: F401 -from .io import save_to_file # noqa: F401 -from .io import normalize_program # noqa: F401 -from .io import is_persistable # noqa: F401 -from .io import save_vars # noqa: F401 -from .io import load_vars # noqa: F401 -from .io import save # noqa: F401 -from .io import load # noqa: F401 -from .io import load_program_state # noqa: F401 -from .io import set_program_state # noqa: F401 from ..base import Scope # noqa: F401 -from .input import data # noqa: F401 -from .input import InputSpec # noqa: F401 -from .input import setitem # noqa: F401 - -from ..tensor.creation import create_parameter # noqa: F401 -from ..tensor.creation import create_global_var # noqa: F401 - -from ..base.executor import Executor # noqa: F401 -from ..base.executor import global_scope # noqa: F401 -from ..base.executor import scope_guard # noqa: F401 -from ..base.backward import append_backward # noqa: F401 -from ..base.backward import gradients # noqa: F401 -from ..base.compiler import BuildStrategy # noqa: F401 -from ..base.compiler import CompiledProgram # noqa: F401 -from ..base.compiler import IpuCompiledProgram # noqa: F401 -from ..base.compiler import IpuStrategy # noqa: F401 -from ..base.compiler import ExecutionStrategy # noqa: F401 -from ..base.framework import default_main_program # noqa: F401 -from ..base.framework import default_startup_program # noqa: F401 -from ..base.framework import device_guard # noqa: F401 - -from ..base.framework import name_scope # noqa: F401 -from ..base.framework import cpu_places # noqa: F401 -from ..base.framework import cuda_places # noqa: F401 -from ..base.framework import xpu_places # noqa: F401 -from ..base.framework import Variable # noqa: F401 -from ..base.framework import Operator # noqa: F401 -from ..base.framework import Parameter # noqa: F401 -from ..base.framework import ipu_shard_guard # noqa: F401 -from ..base.framework import set_ipu_shard # noqa: F401 -from .nn.control_flow import Print # noqa: F401 -from ..base.param_attr import WeightNormParamAttr # noqa: F401 - - -from .nn.metric import auc # noqa: F401 -from .nn.metric import accuracy # noqa: F401 -from .nn.metric import ctr_metric_bundle # noqa: F401 - -from ..base.framework import program_guard # noqa: F401 -from ..base.framework import Program # noqa: F401 +from ..base.backward import append_backward, gradients +from ..base.compiler import ( + BuildStrategy, + CompiledProgram, + ExecutionStrategy, + IpuCompiledProgram, + IpuStrategy, +) +from ..base.executor import Executor, global_scope, scope_guard +from ..base.framework import ( # noqa: F401 + Operator, + Parameter, + Program, + Variable, + cpu_places, + cuda_places, + default_main_program, + default_startup_program, + device_guard, + ipu_shard_guard, + name_scope, + program_guard, + set_ipu_shard, + xpu_places, +) +from ..base.param_attr import WeightNormParamAttr +from ..tensor.creation import create_global_var, create_parameter +from . import amp, nn # noqa: F401 +from .input import InputSpec, data, setitem # noqa: F401 +from .io import ( # noqa: F401 + deserialize_persistables, + deserialize_program, + is_persistable, + load, + load_from_file, + load_inference_model, + load_program_state, + load_vars, + normalize_program, + save, + save_inference_model, + save_to_file, + save_vars, + serialize_persistables, + serialize_program, + set_program_state, +) +from .nn.common import ExponentialMovingAverage, py_func +from .nn.control_flow import Print +from .nn.metric import accuracy, auc, ctr_metric_bundle __all__ = [ 'append_backward', diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index 518fc8d6519cd..745af4b5e5f9b 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -143,7 +143,7 @@ def _reset_data_op_insertion_point(): need_check_feed=True, ) - is_pir_mode = os.environ.get("FLAGS_enable_new_ir_in_executor", None) + is_pir_mode = os.environ.get("FLAGS_enable_pir_in_executor", None) if evaluate_flag(is_pir_mode): helper = LayerHelper('data', **locals()) if not isinstance(dtype, core.VarDesc.VarType): diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 943e8525ba466..4bc32952b958a 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -950,9 +950,7 @@ def load_inference_model(path_prefix, executor, **kwargs): params_path = os.path.join(path_prefix, params_filename) _logger.warning( "The old way to load inference model is deprecated. Please specify path_prefix." - " model path: {}, params path: {}".format( - model_path, params_path - ) + f" model path: {model_path}, params path: {params_path}" ) program_bytes = load_from_file(model_path) diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index f3693e1501c40..4f17de92e7f29 100755 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -12,52 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common import fc # noqa: F401 -from .common import batch_norm # noqa: F401 -from .common import instance_norm # noqa: F401 -from .common import data_norm # noqa: F401 -from .common import continuous_value_model # noqa: F401 -from .common import group_norm # noqa: F401 -from .common import deform_conv2d # noqa: F401 -from .common import conv2d # noqa: F401 -from .common import conv3d # noqa: F401 -from .common import conv2d_transpose # noqa: F401 -from .common import conv3d_transpose # noqa: F401 -from .control_flow import ( - case, - while_loop, - switch_case, -) -from .common import bilinear_tensor_product # noqa: F401 -from .common import py_func # noqa: F401 -from .common import row_conv # noqa: F401 -from .common import spectral_norm # noqa: F401 from ...tensor.creation import create_parameter # noqa: F401 -from .loss import nce # noqa: F401 -from .common import prelu # noqa: F401 -from .common import layer_norm # noqa: F401 - - -from .common import embedding # noqa: F401 -from .common import sparse_embedding # noqa: F401 - -from .sequence_lod import sequence_conv # noqa: F401 -from .sequence_lod import sequence_softmax # noqa: F401 -from .sequence_lod import sequence_pool # noqa: F401 -from .sequence_lod import sequence_concat # noqa: F401 -from .sequence_lod import sequence_first_step # noqa: F401 -from .sequence_lod import sequence_last_step # noqa: F401 -from .sequence_lod import sequence_slice # noqa: F401 -from .sequence_lod import sequence_expand # noqa: F401 -from .sequence_lod import sequence_expand_as # noqa: F401 -from .sequence_lod import sequence_pad # noqa: F401 -from .sequence_lod import sequence_unpad # noqa: F401 -from .sequence_lod import sequence_reshape # noqa: F401 -from .sequence_lod import sequence_scatter # noqa: F401 -from .sequence_lod import sequence_enumerate # noqa: F401 -from .sequence_lod import sequence_reverse # noqa: F401 - -from .control_flow import cond +from .common import ( # noqa: F401 + batch_norm, + bilinear_tensor_product, + continuous_value_model, + conv2d, + conv2d_transpose, + conv3d, + conv3d_transpose, + data_norm, + deform_conv2d, + embedding, + fc, + group_norm, + instance_norm, + layer_norm, + prelu, + py_func, + row_conv, + sparse_embedding, + spectral_norm, +) +from .control_flow import case, cond, switch_case, while_loop +from .loss import nce +from .sequence_lod import ( + sequence_concat, + sequence_conv, + sequence_enumerate, + sequence_expand, + sequence_expand_as, + sequence_first_step, + sequence_last_step, + sequence_pad, + sequence_pool, + sequence_reshape, + sequence_reverse, + sequence_scatter, + sequence_slice, + sequence_softmax, + sequence_unpad, +) from .static_pylayer import static_pylayer __all__ = [ diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index af135c4145de5..6af70651c72a0 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -284,17 +284,17 @@ def instance_norm( epsilon(float, Default 1e-05): A value added to the denominator for numerical stability. Default is 1e-5. param_attr(ParamAttr|None|bool, optional): The parameter attribute for Parameter `scale` - of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm - will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. - If the Initializer of the param_attr is not set, the parameter is initialized - with Xavier. If the param_attr is set to False, instance_norm will not create param_attr. - Default: None. + of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. + If the Initializer of the param_attr is not set, the parameter is initialized + with Xavier. If the param_attr is set to False, instance_norm will not create param_attr. + Default: None. bias_attr(ParamAttr|None|bool, optional): The parameter attribute for the bias of instance_norm. - If it is set to None or one attribute of ParamAttr, instance_norm - will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. - If the Initializer of the bias_attr is not set, the bias is initialized zero. - If the bias_attr is set to False, instance_norm will not create bias_attr. - Default: None. + If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + If the bias_attr is set to False, instance_norm will not create bias_attr. + Default: None. name(string, Default None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2064,9 +2064,7 @@ def _update_padding(padding, data_format): groups = 1 if groups is None else groups if groups <= 0: raise ValueError( - "the groups of conv3d_transpose should be greater than 0. Received groups: {}".format( - groups - ) + f"the groups of conv3d_transpose should be greater than 0. Received groups: {groups}" ) if num_filters % groups != 0: raise ValueError( @@ -2640,19 +2638,19 @@ def batch_norm( Internal Covariate Shift `_ for more details. - :math:input is the input features over a mini-batch. + :math:`input` is the input features over a mini-batch. .. math:: - \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ - \ mini-batch\ mean \\\\ - \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ - \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ - \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ - \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ - y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &// + \ mini-batch\ mean \\ + \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - + \mu_{\\beta})^2 \qquad &//\ mini-batch\ variance \\ + \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{ + \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\ + y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift - moving\_mean = moving\_mean * momentum + mini-batch\_mean * (1. - momentum) \\\\ + moving\_mean = moving\_mean * momentum + mini-batch\_mean * (1. - momentum) \\ moving\_var = moving\_var * momentum + mini-batch\_var * (1. - momentum) @@ -2666,9 +2664,9 @@ def batch_norm( .. math:: - \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ - \\sigma_{\\beta}^{2} + \\epsilon}} \\\\ - y_i &\\gets \\gamma \\hat{x_i} + \\beta + \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{ + \sigma_{\beta}^{2} + \epsilon}} \\ + y_i &\gets \gamma \hat{x_i} + \beta Note: if build_strategy.sync_batch_norm=True, the batch_norm in network will use @@ -2691,14 +2689,14 @@ def batch_norm( numerical stability. Default is 1e-5. param_attr(ParamAttr|None): The parameter attribute for Parameter `scale` of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm - will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. - If the Initializer of the param_attr is not set, the parameter is initialized - with Xavier. Default: None. + will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. + If the Initializer of the param_attr is not set, the parameter is initialized + with Xavier. Default: None. bias_attr(ParamAttr|None): The parameter attribute for the bias of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm - will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. - If the Initializer of the bias_attr is not set, the bias is initialized zero. - Default: None. + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + Default: None. data_layout (str, optional): Specify the data format of the input, and the data format of the output will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 5a9e9eebdce7a..0cbaca8fa6a00 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -1037,9 +1037,7 @@ def _check_args(branch_index, branch_fns, default): if key in keys_of_fns: raise ValueError( - "The key in 'branch_fns' must be unique, but '{}' appears more than once.".format( - key - ) + f"The key in 'branch_fns' must be unique, but '{key}' appears more than once." ) else: keys_of_fns.append(key) @@ -1344,7 +1342,7 @@ def check_ret_none(seq_true, seq_false, seq_names): def merge_every_var_list(false_vars, true_vars, name): return map_structure(partial(merge_func, name), false_vars, true_vars) - merged_output = list( + merged_output_fns = list( map( merge_every_var_list, _to_sequence_except_dict(false_output), @@ -1352,6 +1350,7 @@ def merge_every_var_list(false_vars, true_vars, name): _to_sequence_except_dict(return_names), ) ) + merged_output = map_structure(lambda fn: fn(), merged_output_fns) merged_output = pack_sequence_as(false_output, flatten(merged_output)) return merged_output @@ -1469,13 +1468,7 @@ def select_input_with_buildin_type(inputs, mask, name): false_var, true_var = inputs - if isinstance(false_var, UndefinedVar) and isinstance( - true_var, UndefinedVar - ): - """None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None.""" - return None - - if isinstance(false_var, Variable) and isinstance(true_var, Variable): + def start_select_input(): try: return select_input(inputs, mask) except Exception as e: @@ -1483,11 +1476,20 @@ def select_input_with_buildin_type(inputs, mask, name): f"Exceptions throwed while doing select_input on {name}:\n{e}" ) + if isinstance(false_var, UndefinedVar) and isinstance( + true_var, UndefinedVar + ): + """None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None.""" + return lambda: None + + if isinstance(false_var, Variable) and isinstance(true_var, Variable): + return start_select_input + elif isinstance(false_var, support_ret_buildin_type) and isinstance( false_var, type(true_var) ): if false_var == true_var: - return false_var + return lambda: false_var else: inputs = [ to_static_variable(false_var), @@ -1514,12 +1516,6 @@ def select_input_with_buildin_type(inputs, mask, name): isinstance(true_var, UndefinedVar) and isinstance(false_var, (Variable,) + support_ret_buildin_type) ): - - def create_var_if_not_undefined_var(a): - if isinstance(a, UndefinedVar): - return a - return to_static_variable(a) - true_var, false_var = to_static_variable(true_var), to_static_variable( false_var ) @@ -1531,12 +1527,7 @@ def create_var_if_not_undefined_var(a): type(false_var), type(true_var) ) ) - try: - return select_input(inputs, mask) - except Exception as e: - raise RuntimeError( - f"Exceptions throwed while doing select_input on {name}:\n{e}" - ) + return start_select_input def _is_sequence_except_dict(x): diff --git a/python/paddle/static/nn/metric.py b/python/paddle/static/nn/metric.py index f9941c4744723..672bc80ece926 100644 --- a/python/paddle/static/nn/metric.py +++ b/python/paddle/static/nn/metric.py @@ -17,9 +17,14 @@ import numpy as np import paddle -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from paddle.base.data_feeder import check_variable_and_dtype -from paddle.base.framework import Variable, _create_tensor, in_dygraph_mode +from paddle.base.framework import ( + Variable, + _create_tensor, + in_dygraph_mode, + in_pir_mode, +) from paddle.base.layer_helper import LayerHelper from paddle.nn.initializer import ConstantInitializer @@ -88,6 +93,10 @@ def accuracy(input, label, k=1, correct=None, total=None): topk_out, topk_indices, label, correct, total ) return _acc + elif in_pir_mode(): + topk_out, topk_indices = paddle.topk(input, k=k, sorted=False) + _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label) + return _acc helper = LayerHelper("accuracy", **locals()) check_variable_and_dtype( diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 61005132276d9..84c28ce58dca8 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -24,6 +24,7 @@ from .creation import to_tensor # noqa: F401 from .creation import diag # noqa: F401 from .creation import diagflat # noqa: F401 +from .creation import diag_embed # noqa: F401 from .creation import eye # noqa: F401 from .creation import linspace # noqa: F401 from .creation import fill_constant # noqa: F401 @@ -165,6 +166,10 @@ from .manipulation import view # noqa: F401 from .manipulation import view_as # noqa: F401 from .manipulation import unfold # noqa: F401 +from .manipulation import masked_fill # noqa: F401 +from .manipulation import masked_fill_ # noqa: F401 +from .manipulation import index_fill # noqa: F401 +from .manipulation import index_fill_ # noqa: F401 from .math import abs # noqa: F401 from .math import abs_ # noqa: F401 from .math import acos # noqa: F401 @@ -329,6 +334,8 @@ from .math import polygamma_ # noqa: F401 from .math import renorm # noqa: F401 from .math import renorm_ # noqa: F401 +from .math import hypot # noqa: F401 +from .math import hypot_ # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -464,6 +471,8 @@ 'sum', 'nan_to_num', 'nan_to_num_', + 'hypot', + 'hypot_', 'nansum', 'nanmean', 'count_nonzero', @@ -690,6 +699,9 @@ 'i1e', 'polygamma', 'polygamma_', + 'masked_fill', + 'masked_fill_', + 'diag_embed', 'atan2', 'diagflat', 'multinomial', @@ -715,6 +727,8 @@ 'asinh_', 'diag', 'normal_', + 'index_fill', + 'index_fill_', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/attribute.py b/python/paddle/tensor/attribute.py index f3dcaf06cd9bf..8bc7cff200b34 100644 --- a/python/paddle/tensor/attribute.py +++ b/python/paddle/tensor/attribute.py @@ -20,11 +20,7 @@ from paddle import _C_ops from ..base.data_feeder import check_type, check_variable_and_dtype -from ..base.framework import ( - in_dygraph_mode, - in_dynamic_or_pir_mode, - in_pir_mode, -) +from ..base.framework import in_dynamic_or_pir_mode, in_pir_mode from ..common_ops_import import Variable from ..framework import LayerHelper, core from .creation import _complex_to_real_dtype, assign @@ -300,7 +296,7 @@ def real(x, name=None): [[1., 2., 3.], [4., 5., 6.]]) """ - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.real(x) else: check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'real') @@ -348,7 +344,7 @@ def imag(x, name=None): [[6., 5., 4.], [3., 2., 1.]]) """ - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.imag(x) else: check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'imag') diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index ff778c035e6e2..71c7d3e8866ba 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -36,6 +36,7 @@ from ..framework import ( LayerHelper, _current_expected_place, + _current_expected_place_, _get_paddle_place, convert_np_dtype_to_dtype_, core, @@ -308,20 +309,20 @@ def linspace(start, stop, num, dtype=None, name=None): tensor_num = num tensor_start = start tensor_stop = stop - if not isinstance(num, Variable): + if not isinstance(num, (Variable, paddle.pir.OpResult)): check_type(num, 'num', (int), 'linspace') if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if not isinstance(start, Variable): + if not isinstance(start, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_start = fill_constant([1], dtype, start, force_cpu=True) - if not isinstance(stop, Variable): + if not isinstance(stop, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_stop = fill_constant([1], dtype, stop, force_cpu=True) - if not isinstance(num, Variable): + if not isinstance(num, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_num = fill_constant([1], 'int32', num, force_cpu=True) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.linspace( tensor_start, tensor_stop, @@ -440,23 +441,23 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None): tensor_start = start tensor_stop = stop tensor_base = base - if not isinstance(num, Variable): + if not isinstance(num, (Variable, paddle.pir.OpResult)): check_type(num, 'num', (int), 'logspace') - if not isinstance(dtype, core.VarDesc.VarType): + if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) - if not isinstance(start, Variable): + if not isinstance(start, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_start = fill_constant([1], dtype, start) - if not isinstance(stop, Variable): + if not isinstance(stop, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_stop = fill_constant([1], dtype, stop) - if not isinstance(num, Variable): + if not isinstance(num, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_num = fill_constant([1], 'int32', num) - if not isinstance(base, Variable): + if not isinstance(base, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_base = fill_constant([1], dtype, base) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logspace( tensor_start, tensor_stop, @@ -651,10 +652,11 @@ def _handle_np_dtype(ndarray, dtype): def _to_tensor_static(data, dtype=None, stop_gradient=None): - if isinstance(data, Variable): + if isinstance(data, (Variable, paddle.pir.OpResult)): output = data if dtype is not None and dtype != data.dtype: output = paddle.cast(output, dtype) + else: if isinstance(data, np.number): # Special case for numpy scalars data = np.array(data) @@ -704,6 +706,8 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None): target_dtype = paddle.get_default_dtype() target_dtype = convert_dtype(target_dtype) + if data.dtype == "int16": + data = data.astype("int32") output = assign(data) if convert_dtype(output.dtype) != target_dtype: @@ -785,8 +789,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): """ place = _get_paddle_place(place) if place is None: - place = _current_expected_place() - + place = _current_expected_place_() if in_dynamic_mode(): return _to_tensor_non_static(data, dtype, place, stop_gradient) @@ -794,7 +797,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): else: re_exp = re.compile(r'[(](.+?)[)]', re.S) place_str = re.findall(re_exp, str(place))[0] - with paddle.static.device_guard(place_str): return _to_tensor_static(data, dtype, stop_gradient) @@ -891,6 +893,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) + if in_pir_mode() and isinstance(dtype, core.VarDesc.VarType): + dtype = paddle.pir.core.vartype_to_datatype[dtype] + if in_dynamic_mode(): value = float(value) if isinstance(shape, (list, tuple)): @@ -899,7 +904,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): else: if isinstance(shape, (list, tuple)): if paddle.utils._contain_var(shape): - shape = paddle.utils.get_pir_shape_tensor(shape, place) + shape = paddle.utils.get_int_tensor_list(shape, place) elif isinstance(shape, paddle.pir.OpResult): pass else: @@ -1366,16 +1371,21 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): dtype = 'int64' out_shape = None - if not in_dynamic_or_pir_mode() and ( + is_value_input = ( not isinstance(start, (Variable, paddle.pir.OpResult)) and not isinstance(end, (Variable, paddle.pir.OpResult)) and not isinstance(step, (Variable, paddle.pir.OpResult)) - ): + ) + + if not in_dynamic_mode() and is_value_input: out_shape = [int(math.ceil((end - start) / step))] if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) + if is_value_input and in_pir_mode(): + return _C_ops.arange(start, end, step, dtype, _current_expected_place()) + if not isinstance(start, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): start = fill_constant([1], dtype, start, force_cpu=True) @@ -1641,7 +1651,7 @@ def meshgrid(*args, **kwargs): if len(args) == 1 and isinstance(args[0], (list, tuple)): args = args[0] - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.meshgrid(list(args)) else: name = kwargs.get("name", None) @@ -1672,6 +1682,125 @@ def meshgrid(*args, **kwargs): return out +def diag_embed(input, offset=0, dim1=-2, dim2=-1): + """ + Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) + are filled by ``input``. By default, a 2D plane formed by the last two dimensions + of the returned tensor will be selected. + + The argument ``offset`` determines which diagonal is generated: + + - If offset = 0, it is the main diagonal. + - If offset > 0, it is above the main diagonal. + - If offset < 0, it is below the main diagonal. + + Args: + input(Tensor|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64. + offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal). + dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2. + dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1. + + Returns: + Tensor, the output data type is the same as input data type. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> diag_embed_input = paddle.arange(6) + + >>> diag_embed_output1 = paddle.diag_embed(diag_embed_input) + >>> print(diag_embed_output1) + Tensor(shape=[6, 6], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 2, 0, 0, 0], + [0, 0, 0, 3, 0, 0], + [0, 0, 0, 0, 4, 0], + [0, 0, 0, 0, 0, 5]]) + + >>> diag_embed_output2 = paddle.diag_embed(diag_embed_input, offset=-1, dim1=0,dim2=1 ) + >>> print(diag_embed_output2) + Tensor(shape=[7, 7], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 2, 0, 0, 0, 0], + [0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 0, 4, 0, 0], + [0, 0, 0, 0, 0, 5, 0]]) + + >>> diag_embed_input_2dim = paddle.reshape(diag_embed_input,[2,3]) + >>> print(diag_embed_input_2dim) + Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 1, 2], + [3, 4, 5]]) + >>> diag_embed_output3 = paddle.diag_embed(diag_embed_input_2dim,offset= 0, dim1=0, dim2=2 ) + >>> print(diag_embed_output3) + Tensor(shape=[3, 2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[0, 0, 0], + [3, 0, 0]], + [[0, 1, 0], + [0, 4, 0]], + [[0, 0, 2], + [0, 0, 5]]]) + """ + if not isinstance(input, Variable): + input = assign(input) + + if in_dynamic_mode(): + return _C_ops.diag_embed(input, offset, dim1, dim2) + + inputs = {'Input': [input]} + attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} + + def __check_input(input, offset, dim1, dim2): + check_dtype( + input.dtype, + 'Input', + ['int32', 'int64', 'float16', 'float32', 'float64'], + 'diag_embed', + ) + + input_shape = list(input.shape) + assert len(input_shape) >= 1, ( + "Input must be at least 1-dimensional, " + "But received Input's dimensional: %s.\n" % len(input_shape) + ) + + assert np.abs(dim1) <= len(input_shape), ( + "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" + % (-(len(input_shape) + 1), len(input_shape), dim1) + ) + + assert np.abs(dim2) <= len(input_shape), ( + "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" + % (-(len(input_shape) + 1), len(input_shape), dim2) + ) + + dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 + 1 + dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 + 1 + assert dim1_ != dim2_, ( + "dim1 and dim2 cannot be the same dimension." + "But received dim1 = %d, dim2 = %d\n" % (dim1, dim2) + ) + + __check_input(input, offset, dim1, dim2) + helper = LayerHelper("diag_embed", **locals()) + + out = helper.create_variable_for_type_inference(dtype=input.dtype) + + helper.append_op( + type='diag_embed', + inputs={'Input': [input]}, + attrs={'offset': offset, 'dim1': dim1, 'dim2': dim2}, + outputs={'Out': [out]}, + ) + out.stop_gradient = True + return out + + def diagflat(x, offset=0, name=None): """ If ``x`` is a vector (1-D tensor), a 2-D square tensor with the elements of ``x`` as the diagonal is returned. @@ -1880,7 +2009,7 @@ def diag(x, offset=0, padding_value=0, name=None): Tensor(shape=[1], dtype=int64, place=Place(cpu), stop_gradient=True, [4]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.diag(x, offset, padding_value) else: check_type(x, 'x', (Variable), 'diag_v2') @@ -2482,7 +2611,7 @@ def complex(real, imag, name=None): [[0j , 1j , 2j ], [(1+0j), (1+1j), (1+2j)]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.complex(real, imag) else: check_variable_and_dtype( @@ -2557,7 +2686,7 @@ def tril_indices(row, col, offset=0, dtype='int64'): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if col is None: col = row out = _C_ops.tril_indices( @@ -2636,7 +2765,7 @@ def triu_indices(row, col=None, offset=0, dtype='int64'): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if col is None: col = row out = _C_ops.triu_indices( diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 955d104804bc8..8eb523be364ea 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -23,7 +23,7 @@ from paddle import _C_ops from ..base.data_feeder import check_type, check_variable_and_dtype -from ..base.framework import in_dygraph_mode +from ..base.framework import in_dynamic_or_pir_mode from ..base.layer_helper import LayerHelper from .linalg import matmul, transpose from .manipulation import reshape, squeeze, unsqueeze @@ -832,7 +832,7 @@ def gen_einsum_op(equation, *operands): EinsumOp Python Interface: """ - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.einsum(operands, equation)[0] else: assert len(operands) <= 2, "Only support two operands in EinsumOp." diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 97172e39b5492..0c76fed54a8b1 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -382,7 +382,7 @@ def frobenius_norm(input, dim=None, keepdim=False, name=None): "The dim of frobenius norm op should be None or two elements list!" ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if dim is None: return _C_ops.frobenius_norm(input, [], keepdim, True) return _C_ops.frobenius_norm(input, dim, keepdim, False) @@ -613,9 +613,7 @@ def p_matrix_norm(input, porder=1.0, axis=axis, keepdim=False, name=None): return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name) elif p == 0: raise ValueError( - "just support axis type int or list (length of list <=1) if p = 0, found {}".format( - axis - ) + f"just support axis type int or list (length of list <=1) if p = 0, found {axis}" ) else: return p_matrix_norm( @@ -719,7 +717,7 @@ def dist(x, y, p=2, name=None): Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.dist(x, y, p) check_variable_and_dtype( @@ -859,7 +857,7 @@ def mat_norm(input, porder=1.0, axis=None): Calculate the matrix norm of a square matrix or batches of square matrices, when porder is in (1, -1, inf, -inf) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): abs_out = _C_ops.abs(input) sum_out = _C_ops.sum(abs_out, axis, None, False) @@ -922,7 +920,7 @@ def fro_norm(input, porder=2, axis=[-1]): NOTE: Calculate the frobenius norm of a square matrix or batches of square matrices. """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): pow_out = _C_ops.pow(input, porder) sum_out_1 = _C_ops.sum(pow_out, axis, None, False) sum_out_2 = _C_ops.sum(sum_out_1, axis, None, False) @@ -985,7 +983,7 @@ def svd_norm(input, porder, axis=[-1]): """ u, s, vh = svd(input, full_matrices=False) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if porder == "nuc": return _C_ops.sum(s, axis, None, False) max_out = _C_ops.max(s, axis, False) @@ -1056,7 +1054,7 @@ def svd_norm(input, porder, axis=[-1]): return out def empty_tensor(input, shape): - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return input.reshape(shape) raise ValueError( "only support x is nonempty tensor in static graph mode" @@ -1375,7 +1373,7 @@ def t(input, name=None): "length of Input(input) is %s. Perhaps you can use paddle." "tensor.transpose() instead." % len(input.shape) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if len(input.shape) <= 1: return input # 2-D tensor @@ -1539,7 +1537,7 @@ def cholesky(x, upper=False, name=None): [1.06467664, 0.17859250, 0. ], [1.30602181, 0.08326444, 0.22790681]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cholesky(x, upper) else: check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky') @@ -1907,7 +1905,7 @@ def det(x, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.det(x) else: check_dtype(x.dtype, 'Input', ['float16', 'float32', 'float64'], 'det') @@ -1958,17 +1956,15 @@ def slogdet(x, name=None): >>> import paddle >>> paddle.seed(2023) - >>> x = paddle.randn([3,3,3]) + >>> x = paddle.randn([3, 3, 3]) >>> A = paddle.linalg.slogdet(x) >>> print(A) - >>> # doctest: +SKIP Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, [[-1. , 1. , 1. ], [ 0.25681755, -0.25061053, -0.10809582]]) - >>> # doctest: -SKIP """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.slogdet(x) else: check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'slogdet') @@ -2621,7 +2617,7 @@ def eig(x, name=None): (-0.21026138961315155+0j)]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.eig(x) else: check_variable_and_dtype( @@ -2687,12 +2683,10 @@ def eigvals(x, name=None): if x_shape[-1] != x_shape[-2]: raise ValueError( - "The last two dimensions of Input(x) should be equal, but received x's shape = {}".format( - x_shape - ) + f"The last two dimensions of Input(x) should be equal, but received x's shape = {x_shape}" ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.eigvals(x) else: check_variable_and_dtype( @@ -2801,10 +2795,12 @@ def eigh(x, UPLO='L', name=None): property. For more information, please refer to :ref:`api_guide_Name`. Returns: - - out_value(Tensor): A Tensor with shape [*, N] and data type of float32 and float64. - The eigenvalues of eigh op. - - out_vector(Tensor): A Tensor with shape [*, N, N] and data type of float32,float64, - complex64 and complex128. The eigenvectors of eigh op. + 2-element tuple containing + + - out_value(Tensor): A Tensor with shape :math:`[*, N]` and data type of float32 and float64. + The eigenvalues of eigh op. + - out_vector(Tensor): A Tensor with shape :math:`[*, N, N]` and data type of float32, float64, + complex64 and complex128. The eigenvectors of eigh op. Examples: .. code-block:: python @@ -2822,7 +2818,7 @@ def eigh(x, UPLO='L', name=None): [ 0.3826833963394165j , -0.9238795042037964j ]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.eigh(x, UPLO) else: @@ -2835,9 +2831,7 @@ def __check_input(x, UPLO): ) if x_shape[-1] != x_shape[-2]: raise ValueError( - "The input matrix must be batches of square matrices. But received x's dimention: {}".format( - x_shape - ) + f"The input matrix must be batches of square matrices. But received x's dimention: {x_shape}" ) if UPLO != 'L' and UPLO != 'U': raise ValueError( @@ -3149,7 +3143,7 @@ def solve(x, y, name=None): Tensor(shape=[2], dtype=float64, place=Place(cpu), stop_gradient=True, [2., 3.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.solve(x, y) else: inputs = {"X": [x], "Y": [y]} @@ -3221,7 +3215,7 @@ def triangular_solve( [-2.], [-5.]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.triangular_solve(x, y, upper, transpose, unitriangular) else: inputs = {"X": [x], "Y": [y]} @@ -3283,7 +3277,7 @@ def cholesky_solve(x, y, upper=False, name=None): [-7. ], [ 9.50000000]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cholesky_solve(x, y, upper) else: helper = LayerHelper("cholesky_solve", **locals()) @@ -3330,7 +3324,7 @@ def eigvalsh(x, UPLO='L', name=None): Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, [0.17157286, 5.82842731]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): values, _ = _C_ops.eigvalsh(x, UPLO, x.stop_gradient) return values else: @@ -3344,9 +3338,7 @@ def __check_input(x, UPLO): ) if x_shape[-1] != x_shape[-2]: raise ValueError( - "The input matrix must be batches of square matrices. But received x's dimention: {}".format( - x_shape - ) + f"The input matrix must be batches of square matrices. But received x's dimention: {x_shape}" ) if UPLO != 'L' and UPLO != 'U': raise ValueError( @@ -3444,17 +3436,13 @@ def lstsq(x, y, rcond=None, driver=None, name=None): if device == "cpu": if driver not in (None, "gels", "gelss", "gelsd", "gelsy"): raise ValueError( - "Only support valid driver is 'gels', 'gelss', 'gelsd', 'gelsy' or None for CPU inputs. But got {}".format( - driver - ) + f"Only support valid driver is 'gels', 'gelss', 'gelsd', 'gelsy' or None for CPU inputs. But got {driver}" ) driver = "gelsy" if driver is None else driver elif "gpu" in device: if driver not in (None, "gels"): raise ValueError( - "Only support valid driver is 'gels' or None for CUDA inputs. But got {}".format( - driver - ) + f"Only support valid driver is 'gels' or None for CUDA inputs. But got {driver}" ) driver = "gels" if driver is None else driver else: diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 0deeefcc15c74..7e29f89a9de17 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -139,7 +139,7 @@ def logical_and(x, y, out=None, name=None): [True , False, True , False]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logical_and(x, y) return _logical_op( @@ -413,7 +413,7 @@ def equal_all(x, y, name=None): Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True, False) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.equal_all(x, y) else: helper = LayerHelper("equal_all", **locals()) @@ -535,9 +535,7 @@ def equal(x, y, name=None): """ if not isinstance(y, (int, bool, float, Variable, paddle.pir.OpResult)): raise TypeError( - "Type of input args must be float, bool, int or Tensor, but received type {}".format( - type(y) - ) + f"Type of input args must be float, bool, int or Tensor, but received type {type(y)}" ) if not isinstance(y, (Variable, paddle.pir.OpResult)): y = full(shape=[], dtype=x.dtype, fill_value=y) @@ -718,7 +716,7 @@ def greater_than(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [False, False, True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.greater_than(x, y) else: check_variable_and_dtype( @@ -807,7 +805,7 @@ def less_equal(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [True , True , False]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.less_equal(x, y) else: check_variable_and_dtype( @@ -896,7 +894,7 @@ def less_than(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True , False]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.less_than(x, y) else: check_variable_and_dtype( @@ -985,7 +983,7 @@ def not_equal(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True , True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.not_equal(x, y) else: check_variable_and_dtype( @@ -1213,7 +1211,7 @@ def bitwise_or(x, y, out=None, name=None): Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [-1, -1, -3]) """ - if in_dynamic_mode() and out is None: + if in_dynamic_or_pir_mode() and out is None: return _C_ops.bitwise_or(x, y) return _bitwise_op( @@ -1272,7 +1270,7 @@ def bitwise_xor(x, y, out=None, name=None): Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [-1, -3, -4]) """ - if in_dynamic_mode() and out is None: + if in_dynamic_or_pir_mode() and out is None: return _C_ops.bitwise_xor(x, y) return _bitwise_op( op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True @@ -1328,7 +1326,7 @@ def bitwise_not(x, out=None, name=None): Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [ 4, 0, -2]) """ - if in_dynamic_mode() and out is None: + if in_dynamic_or_pir_mode() and out is None: return _C_ops.bitwise_not(x) return _bitwise_op( @@ -1402,7 +1400,7 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): [True, True]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.isclose(x, y, rtol, atol, equal_nan) else: check_variable_and_dtype( diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 221cf524347d0..203b98c78683b 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -30,6 +30,7 @@ from ..base.framework import Variable from ..framework import ( LayerHelper, + _current_expected_place, convert_np_dtype_to_dtype_, core, dygraph_only, @@ -112,15 +113,15 @@ def tensor_array_to_tensor(input, axis=1, use_stack=False, name=None): Examples: .. code-block:: python - import numpy - import paddle - x0 = paddle.assign(numpy.random.rand(2, 2).astype("float32")) - x1 = paddle.assign(numpy.random.rand(2, 2).astype("float32")) - i = paddle.full(shape=[1], dtype="int64", fill_value=0) - array = paddle.tensor.array.create_array(dtype='float32') - paddle.tensor.array.array_write(x0, i, array) - paddle.tensor.array.array_write(x1, i + 1, array) - output, output_index = paddle.tensor.manipulation.tensor_array_to_tensor(input=array) + >>> import numpy + >>> import paddle + >>> x0 = paddle.assign(numpy.random.rand(2, 2).astype("float32")) + >>> x1 = paddle.assign(numpy.random.rand(2, 2).astype("float32")) + >>> i = paddle.full(shape=[1], dtype="int64", fill_value=0) + >>> array = paddle.tensor.array.create_array(dtype='float32') + >>> paddle.tensor.array.array_write(x0, i, array) + >>> paddle.tensor.array.array_write(x1, i + 1, array) + >>> output, output_index = paddle.tensor.manipulation.tensor_array_to_tensor(input=array) """ if in_dynamic_mode(): assert isinstance( @@ -175,10 +176,10 @@ def cast(x, dtype): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([2, 3, 4], 'float64') - y = paddle.cast(x, 'uint8') + >>> x = paddle.to_tensor([2, 3, 4], 'float64') + >>> y = paddle.cast(x, 'uint8') """ if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) @@ -295,24 +296,24 @@ def slice(input, axes, starts, ends): Examples: .. code-block:: python - import paddle - - input = paddle.rand(shape=[4, 5, 6], dtype='float32') - # example 1: - # attr starts is a list which doesn't contain tensor. - axes = [0, 1, 2] - starts = [-3, 0, 2] - ends = [3, 2, 4] - sliced_1 = paddle.slice(input, axes=axes, starts=starts, ends=ends) - # sliced_1 is input[1:3, 0:2, 2:4]. - - # example 2: - # attr starts is a list which contain tensor. - minus_3 = paddle.full([1], -3, "int32") - sliced_2 = paddle.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends) - # sliced_2 is input[1:3, 0:2, 2:4]. + >>> import paddle + + >>> input = paddle.rand(shape=[4, 5, 6], dtype='float32') + >>> # example 1: + >>> # attr starts is a list which doesn't contain tensor. + >>> axes = [0, 1, 2] + >>> starts = [-3, 0, 2] + >>> ends = [3, 2, 4] + >>> sliced_1 = paddle.slice(input, axes=axes, starts=starts, ends=ends) + >>> # sliced_1 is input[1:3, 0:2, 2:4]. + + >>> # example 2: + >>> # attr starts is a list which contain tensor. + >>> minus_3 = paddle.full([1], -3, "int32") + >>> sliced_2 = paddle.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends) + >>> # sliced_2 is input[1:3, 0:2, 2:4]. """ - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): attrs = () starts_tensor = None ends_tensor = None @@ -357,6 +358,38 @@ def slice(input, axes, starts, ends): infer_flags = [-1 for i in range(len(axes))] return _C_ops.slice(input, axes, starts, ends, infer_flags, []) + elif in_pir_mode(): + if not isinstance(starts, (list, tuple, paddle.pir.OpResult)): + raise ValueError( + "Input starts must be an OpResult, python list or tuple." + ) + if not isinstance(ends, (list, tuple, paddle.pir.OpResult)): + raise ValueError( + "Input ends must be an OpResult, python list or tuple." + ) + infer_flags = [1 for i in range(len(axes))] + # starts + if isinstance(starts, paddle.pir.OpResult): + starts.stop_gradient = True + infer_flags = [-1 for i in range(len(axes))] + elif isinstance(starts, (list, tuple)): + if paddle.utils._contain_var(starts): + for i, dim in enumerate(starts): + if isinstance(dim, paddle.pir.OpResult): + infer_flags[i] = -1 + starts = paddle.utils.get_int_tensor_list(starts) + + # ends + if isinstance(ends, paddle.pir.OpResult): + ends.stop_gradient = True + infer_flags = [-1 for i in range(len(axes))] + elif isinstance(ends, (list, tuple)): + if paddle.utils._contain_var(ends): + for i, dim in enumerate(ends): + if isinstance(dim, paddle.pir.OpResult): + infer_flags[i] = -1 + ends = paddle.utils.get_int_tensor_list(ends) + return _C_ops.slice(input, axes, starts, ends, infer_flags, []) else: if not isinstance(starts, (list, tuple, Variable)): raise ValueError( @@ -467,12 +500,12 @@ def transpose(x, perm, name=None): .. code-block:: python - import paddle + >>> import paddle - x = paddle.randn([2, 3, 4]) - x_transposed = paddle.transpose(x, perm=[1, 0, 2]) - print(x_transposed.shape) - # [3L, 2L, 4L] + >>> x = paddle.randn([2, 3, 4]) + >>> x_transposed = paddle.transpose(x, perm=[1, 0, 2]) + >>> print(x_transposed.shape) + [3, 2, 4] """ if in_dynamic_or_pir_mode(): @@ -544,16 +577,16 @@ def unstack(x, axis=0, num=None): Examples: .. code-block:: python - import paddle - x = paddle.ones(name='x', shape=[2, 3, 5], dtype='float32') # create a tensor with shape=[2, 3, 5] - y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] + >>> import paddle + >>> x = paddle.ones(name='x', shape=[2, 3, 5], dtype='float32') # create a tensor with shape=[2, 3, 5] + >>> y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ if not (-x.ndim <= axis < x.ndim): raise ValueError(f'`axis` must be in the range [-{x.ndim}, {x.ndim})') if num is not None and (num < 0 or num > x.shape[axis]): raise ValueError(f'`num` must be in the range [0, {x.shape[axis]})') - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if num is None: num = x.shape[axis] if num == 0: @@ -617,14 +650,15 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): Examples: .. code-block:: python - import paddle - label = paddle.to_tensor([[16], [1]], "int64") - shard_label = paddle.shard_index(input=label, - index_num=20, - nshards=2, - shard_id=0) - print(shard_label) - # [[-1], [1]] + >>> import paddle + >>> label = paddle.to_tensor([[16], [1]], "int64") + >>> shard_label = paddle.shard_index(input=label, + ... index_num=20, + ... nshards=2, + ... shard_id=0) + >>> print(shard_label.numpy()) + [[-1] + [ 1]] """ if in_dynamic_mode(): return _C_ops.shard_index( @@ -716,29 +750,29 @@ def crop(x, shape=None, offsets=None, name=None): .. code-block:: python - import paddle - x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - # x.shape = [3, 3] - # x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - - # shape can be a 1-D Tensor or list or tuple. - shape = paddle.to_tensor([2, 2], dtype='int32') - # shape = [2, 2] - # shape = (2, 2) - out = paddle.crop(x, shape) - # out.shape = [2, 2] - # out = [[1,2], [4,5]] - - # offsets can be a 1-D Tensor or list or tuple. - offsets = paddle.to_tensor([0, 1], dtype='int32') - # offsets = [1, 0] - # offsets = (1, 1) - out = paddle.crop(x, shape, offsets) - # out.shape = [2, 2] - # if offsets = [0, 0], out = [[1,2], [4,5]] - # if offsets = [0, 1], out = [[2,3], [5,6]] - # if offsets = [1, 0], out = [[4,5], [7,8]] - # if offsets = [1, 1], out = [[5,6], [8,9]] + >>> import paddle + >>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> # x.shape = [3, 3] + >>> # x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + >>> # shape can be a 1-D Tensor or list or tuple. + >>> shape = paddle.to_tensor([2, 2], dtype='int32') + >>> # shape = [2, 2] + >>> # shape = (2, 2) + >>> out = paddle.crop(x, shape) + >>> # out.shape = [2, 2] + >>> # out = [[1,2], [4,5]] + + >>> # offsets can be a 1-D Tensor or list or tuple. + >>> offsets = paddle.to_tensor([0, 1], dtype='int32') + >>> # offsets = [1, 0] + >>> # offsets = (1, 1) + >>> out = paddle.crop(x, shape, offsets) + >>> # out.shape = [2, 2] + >>> # if offsets = [0, 0], out = [[1,2], [4,5]] + >>> # if offsets = [0, 1], out = [[2,3], [5,6]] + >>> # if offsets = [1, 0], out = [[4,5], [7,8]] + >>> # if offsets = [1, 1], out = [[5,6], [8,9]] """ @@ -873,12 +907,13 @@ def fill_(x, value): Examples: .. code-block:: python - import paddle + >>> import paddle - tensor = paddle.to_tensor([0, 1, 2, 3, 4]) + >>> tensor = paddle.to_tensor([0, 1, 2, 3, 4]) - tensor.fill_(0) - print(tensor.tolist()) #[0, 0, 0, 0, 0] + >>> tensor.fill_(0) + >>> print(tensor.tolist()) + [0, 0, 0, 0, 0] """ if not isinstance(value, (float, int)): @@ -906,12 +941,13 @@ def zero_(x): Examples: .. code-block:: python - import paddle + >>> import paddle - tensor = paddle.to_tensor([0, 1, 2, 3, 4]) + >>> tensor = paddle.to_tensor([0, 1, 2, 3, 4]) - tensor.zero_() - print(tensor.tolist()) #[0, 0, 0, 0, 0] + >>> tensor.zero_() + >>> print(tensor.tolist()) + [0, 0, 0, 0, 0] """ return _C_ops.fill_(x, 0.0) @@ -937,10 +973,12 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None): Examples: .. code-block:: python - import paddle - x = paddle.ones((4, 3)) * 2 - x.fill_diagonal_(1.0) - print(x.tolist()) #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] + + >>> import paddle + >>> x = paddle.ones((4, 3)) * 2 + >>> x.fill_diagonal_(1.0) + >>> print(x.tolist()) + [[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] """ if in_dynamic_mode(): if len(x.shape) == 2: @@ -1003,12 +1041,13 @@ def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.ones((4, 3)) * 2 - y = paddle.ones((3,)) - x.fill_diagonal_tensor_(y) - print(x.tolist()) #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] + >>> x = paddle.ones((4, 3)) * 2 + >>> y = paddle.ones((3,)) + >>> x.fill_diagonal_tensor_(y) + >>> print(x.tolist()) + [[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] """ return _fill_diagonal_tensor_impl( @@ -1034,12 +1073,13 @@ def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.ones((4, 3)) * 2 - y = paddle.ones((3,)) - nx = x.fill_diagonal_tensor(y) - print(nx.tolist()) #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] + >>> x = paddle.ones((4, 3)) * 2 + >>> y = paddle.ones((3,)) + >>> nx = x.fill_diagonal_tensor(y) + >>> print(nx.tolist()) + [[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] """ return _fill_diagonal_tensor_impl( @@ -1065,14 +1105,16 @@ def tolist(x): Examples: .. code-block:: python - import paddle + >>> import paddle - t = paddle.to_tensor([0,1,2,3,4]) - expectlist = t.tolist() - print(expectlist) #[0, 1, 2, 3, 4] + >>> t = paddle.to_tensor([0,1,2,3,4]) + >>> expectlist = t.tolist() + >>> print(expectlist) + [0, 1, 2, 3, 4] - expectlist = paddle.tolist(t) - print(expectlist) #[0, 1, 2, 3, 4] + >>> expectlist = paddle.tolist(t) + >>> print(expectlist) + [0, 1, 2, 3, 4] """ # TODO(zhouwei): will remove 0-D Tensor.numpy() hack @@ -1099,28 +1141,36 @@ def concat(x, axis=0, name=None): Examples: .. code-block:: python - import paddle - - x1 = paddle.to_tensor([[1, 2, 3], - [4, 5, 6]]) - x2 = paddle.to_tensor([[11, 12, 13], - [14, 15, 16]]) - x3 = paddle.to_tensor([[21, 22], - [23, 24]]) - zero = paddle.full(shape=[1], dtype='int32', fill_value=0) - # When the axis is negative, the real axis is (axis + Rank(x)) - # As follow, axis is -1, Rank(x) is 2, the real axis is 1 - out1 = paddle.concat(x=[x1, x2, x3], axis=-1) - out2 = paddle.concat(x=[x1, x2], axis=0) - out3 = paddle.concat(x=[x1, x2], axis=zero) - # out1 - # [[ 1 2 3 11 12 13 21 22] - # [ 4 5 6 14 15 16 23 24]] - # out2 out3 - # [[ 1 2 3] - # [ 4 5 6] - # [11 12 13] - # [14 15 16]] + >>> import paddle + + >>> x1 = paddle.to_tensor([[1, 2, 3], + ... [4, 5, 6]]) + >>> x2 = paddle.to_tensor([[11, 12, 13], + ... [14, 15, 16]]) + >>> x3 = paddle.to_tensor([[21, 22], + ... [23, 24]]) + >>> zero = paddle.full(shape=[1], dtype='int32', fill_value=0) + >>> # When the axis is negative, the real axis is (axis + Rank(x)) + >>> # As follow, axis is -1, Rank(x) is 2, the real axis is 1 + >>> out1 = paddle.concat(x=[x1, x2, x3], axis=-1) + >>> out2 = paddle.concat(x=[x1, x2], axis=0) + >>> out3 = paddle.concat(x=[x1, x2], axis=zero) + >>> print(out1) + Tensor(shape=[2, 8], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1 , 2 , 3 , 11, 12, 13, 21, 22], + [4 , 5 , 6 , 14, 15, 16, 23, 24]]) + >>> print(out2) + Tensor(shape=[4, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1 , 2 , 3 ], + [4 , 5 , 6 ], + [11, 12, 13], + [14, 15, 16]]) + >>> print(out3) + Tensor(shape=[4, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1 , 2 , 3 ], + [4 , 5 , 6 ], + [11, 12, 13], + [14, 15, 16]]) """ input = x if in_dynamic_or_pir_mode(): @@ -1227,12 +1277,12 @@ def broadcast_tensors(input, name=None): Examples: .. code-block:: python - import paddle - x1 = paddle.rand([1, 2, 3, 4]).astype('float32') - x2 = paddle.rand([1, 2, 1, 4]).astype('float32') - x3 = paddle.rand([1, 1, 3, 1]).astype('float32') - out1, out2, out3 = paddle.broadcast_tensors(input=[x1, x2, x3]) - # out1, out2, out3: tensors broadcasted from x1, x2, x3 with shape [1,2,3,4] + >>> import paddle + >>> x1 = paddle.rand([1, 2, 3, 4]).astype('float32') + >>> x2 = paddle.rand([1, 2, 1, 4]).astype('float32') + >>> x3 = paddle.rand([1, 1, 3, 1]).astype('float32') + >>> out1, out2, out3 = paddle.broadcast_tensors(input=[x1, x2, x3]) + >>> # out1, out2, out3: tensors broadcasted from x1, x2, x3 with shape [1,2,3,4] """ num_inputs = len(input) @@ -1337,20 +1387,34 @@ def flip(x, axis, name=None): Examples: .. code-block:: python - import paddle - - image_shape=(3, 2, 2) - img = paddle.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape) - tmp = paddle.flip(img, [0,1]) - print(tmp) # [[[10,11],[8, 9]], [[6, 7],[4, 5]], [[2, 3],[0, 1]]] - - out = paddle.flip(tmp,-1) - print(out) # [[[11,10],[9, 8]], [[7, 6],[5, 4]], [[3, 2],[1, 0]]] + >>> import paddle + + >>> image_shape=(3, 2, 2) + >>> img = paddle.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape) + >>> tmp = paddle.flip(img, [0,1]) + >>> print(tmp) + Tensor(shape=[3, 2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[10, 11], + [8 , 9 ]], + [[6 , 7 ], + [4 , 5 ]], + [[2 , 3 ], + [0 , 1 ]]]) + + >>> out = paddle.flip(tmp,-1) + >>> print(out) + Tensor(shape=[3, 2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[11, 10], + [9 , 8 ]], + [[7 , 6 ], + [5 , 4 ]], + [[3 , 2 ], + [1 , 0 ]]]) """ if isinstance(axis, int): axis = [axis] - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.flip(x, axis) else: helper = LayerHelper("flip", **locals()) @@ -1397,38 +1461,38 @@ def rot90(x, k=1, axes=[0, 1], name=None): Examples: .. code-block:: python - import paddle - - data = paddle.arange(4) - data = paddle.reshape(data, (2, 2)) - print(data) - #[[0, 1], - # [2, 3]] - - y = paddle.rot90(data, 1, [0, 1]) - print(y) - #[[1, 3], - # [0, 2]] - - y= paddle.rot90(data, -1, [0, 1]) - print(y) - #[[2, 0], - # [3, 1]] - - data2 = paddle.arange(8) - data2 = paddle.reshape(data2, (2,2,2)) - print(data2) - #[[[0, 1], - # [2, 3]], - # [[4, 5], - # [6, 7]]] - - y = paddle.rot90(data2, 1, [1, 2]) - print(y) - #[[[1, 3], - # [0, 2]], - # [[5, 7], - # [4, 6]]] + >>> import paddle + + >>> data = paddle.arange(4) + >>> data = paddle.reshape(data, (2, 2)) + >>> print(data.numpy()) + [[0 1] + [2 3]] + + >>> y = paddle.rot90(data, 1, [0, 1]) + >>> print(y.numpy()) + [[1 3] + [0 2]] + + >>> y= paddle.rot90(data, -1, [0, 1]) + >>> print(y.numpy()) + [[2 0] + [3 1]] + + >>> data2 = paddle.arange(8) + >>> data2 = paddle.reshape(data2, (2,2,2)) + >>> print(data2.numpy()) + [[[0 1] + [2 3]] + [[4 5] + [6 7]]] + + >>> y = paddle.rot90(data2, 1, [1, 2]) + >>> print(y.numpy()) + [[[1 3] + [0 2]] + [[5 7] + [4 6]]] """ helper = LayerHelper("rot90", **locals()) @@ -1535,19 +1599,22 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): .. code-block:: python - import paddle + >>> import paddle - image_shape=(2, 3, 4, 4) + >>> image_shape=(2, 3, 4, 4) - x = paddle.arange(end=image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3]) - img = paddle.reshape(x, image_shape) + >>> x = paddle.arange(end=image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3]) + >>> img = paddle.reshape(x, image_shape) - out = paddle.flatten(img, start_axis=1, stop_axis=2) - # out shape is [2, 12, 4] + >>> out = paddle.flatten(img, start_axis=1, stop_axis=2) + >>> print(out.shape) + [2, 12, 4] - # out shares data with img in dygraph mode - img[0, 0, 0, 0] = -1 - print(out[0, 0, 0]) # [-1] + >>> # out shares data with img in dygraph mode + >>> img[0, 0, 0, 0] = -1 + >>> print(out[0, 0, 0]) + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + -1) """ if not (isinstance(x, (Variable, paddle.pir.OpResult))): raise ValueError("The input x should be a Tensor") @@ -1676,26 +1743,26 @@ def roll(x, shifts, axis=None, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0]]) - out_z1 = paddle.roll(x, shifts=1) - print(out_z1) - #[[9. 1. 2.] - # [3. 4. 5.] - # [6. 7. 8.]] - out_z2 = paddle.roll(x, shifts=1, axis=0) - print(out_z2) - #[[7. 8. 9.] - # [1. 2. 3.] - # [4. 5. 6.]] - out_z3 = paddle.roll(x, shifts=1, axis=1) - print(out_z3) - #[[3. 1. 2.] - # [6. 4. 5.] - # [9. 7. 8.]] + >>> import paddle + + >>> x = paddle.to_tensor([[1.0, 2.0, 3.0], + ... [4.0, 5.0, 6.0], + ... [7.0, 8.0, 9.0]]) + >>> out_z1 = paddle.roll(x, shifts=1) + >>> print(out_z1.numpy()) + [[9. 1. 2.] + [3. 4. 5.] + [6. 7. 8.]] + >>> out_z2 = paddle.roll(x, shifts=1, axis=0) + >>> print(out_z2.numpy()) + [[7. 8. 9.] + [1. 2. 3.] + [4. 5. 6.]] + >>> out_z3 = paddle.roll(x, shifts=1, axis=1) + >>> print(out_z3.numpy()) + [[3. 1. 2.] + [6. 4. 5.] + [9. 7. 8.]] """ origin_shape = x.shape if type(shifts) == int: @@ -1819,28 +1886,32 @@ def stack(x, axis=0, name=None): Returns: Tensor, The stacked tensor with same data type as input. - Example: + Examples: .. code-block:: python - import paddle - - x1 = paddle.to_tensor([[1.0, 2.0]]) - x2 = paddle.to_tensor([[3.0, 4.0]]) - x3 = paddle.to_tensor([[5.0, 6.0]]) - - out = paddle.stack([x1, x2, x3], axis=0) - print(out.shape) # [3, 1, 2] - print(out) - # [[[1., 2.]], - # [[3., 4.]], - # [[5., 6.]]] - - out = paddle.stack([x1, x2, x3], axis=-2) - print(out.shape) # [1, 3, 2] - print(out) - # [[[1., 2.], - # [3., 4.], - # [5., 6.]]] + >>> import paddle + + >>> x1 = paddle.to_tensor([[1.0, 2.0]]) + >>> x2 = paddle.to_tensor([[3.0, 4.0]]) + >>> x3 = paddle.to_tensor([[5.0, 6.0]]) + + >>> out = paddle.stack([x1, x2, x3], axis=0) + >>> print(out.shape) + [3, 1, 2] + >>> print(out) + Tensor(shape=[3, 1, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[1., 2.]], + [[3., 4.]], + [[5., 6.]]]) + + >>> out = paddle.stack([x1, x2, x3], axis=-2) + >>> print(out.shape) + [1, 3, 2] + >>> print(out) + Tensor(shape=[1, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[1., 2.], + [3., 4.], + [5., 6.]]]) """ axis = 0 if axis is None else axis @@ -1926,35 +1997,48 @@ def split(x, num_or_sections, axis=0, name=None): Returns: list(Tensor), The list of segmented Tensors. - Example: + Examples: .. code-block:: python - import paddle - - # x is a Tensor of shape [3, 9, 5] - x = paddle.rand([3, 9, 5]) - - out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=1) - print(out0.shape) # [3, 3, 5] - print(out1.shape) # [3, 3, 5] - print(out2.shape) # [3, 3, 5] - - out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, 4], axis=1) - print(out0.shape) # [3, 2, 5] - print(out1.shape) # [3, 3, 5] - print(out2.shape) # [3, 4, 5] - - out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, -1], axis=1) - print(out0.shape) # [3, 2, 5] - print(out1.shape) # [3, 3, 5] - print(out2.shape) # [3, 4, 5] - - # axis is negative, the real axis is (rank(x) + axis)=1 - out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=-2) - print(out0.shape) # [3, 3, 5] - print(out1.shape) # [3, 3, 5] - print(out2.shape) # [3, 3, 5] + >>> import paddle + + >>> # x is a Tensor of shape [3, 9, 5] + >>> x = paddle.rand([3, 9, 5]) + + >>> out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=1) + >>> print(out0.shape) + [3, 3, 5] + >>> print(out1.shape) + [3, 3, 5] + >>> print(out2.shape) + [3, 3, 5] + + >>> out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, 4], axis=1) + >>> print(out0.shape) + [3, 2, 5] + >>> print(out1.shape) + [3, 3, 5] + >>> print(out2.shape) + [3, 4, 5] + + >>> out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, -1], axis=1) + >>> print(out0.shape) + [3, 2, 5] + >>> print(out1.shape) + [3, 3, 5] + >>> print(out2.shape) + [3, 4, 5] + + >>> # axis is negative, the real axis is (rank(x) + axis)=1 + >>> out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=-2) + >>> print(out0.shape) + [3, 3, 5] + >>> print(out1.shape) + [3, 3, 5] + >>> print(out2.shape) + [3, 3, 5] """ + input = x dim = axis if in_dynamic_mode(): @@ -1979,15 +2063,32 @@ def split(x, num_or_sections, axis=0, name=None): else: return _C_ops.split(input, num_or_sections, dim) elif in_pir_mode(): + if isinstance(dim, paddle.pir.OpResult): + dim.stop_gradient = True if isinstance(dim, int): assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0" dim = (len(input.shape) + dim) if dim < 0 else dim + input_shape = input.shape if isinstance(num_or_sections, int): - dim = dim if dim >= 0 else dim + len(input.shape) + assert num_or_sections > 0, 'num_or_sections must be than 0.' + if isinstance(dim, int) and input_shape[dim] > 0: + assert input_shape[dim] % num_or_sections == 0, ( + "The input's size along the split dimension " + "must be evenly divisible by Attr(num_or_sections). " + "But %d is not evenly divisible by %d. " + % (num_or_sections, input_shape[dim]) + ) return _C_ops.split_with_num(input, num_or_sections, dim) else: - dim = dim if dim >= 0 else dim + len(input.shape) + if isinstance(dim, int) and input_shape[dim] > 0: + assert ( + len(num_or_sections) <= input_shape[dim] + ), 'len(num_or_sections) must not be more than input.shape[dim].' + if paddle.utils._contain_var(num_or_sections): + num_or_sections = paddle.utils.get_int_tensor_list( + num_or_sections + ) return _C_ops.split(input, num_or_sections, dim) else: @@ -2109,24 +2210,32 @@ def vsplit(x, num_or_sections, name=None): Returns: list[Tensor], The list of segmented Tensors. - Example: + Examples: .. code-block:: python - import paddle - - # x is a Tensor of shape [8, 6, 7] - x = paddle.rand([8, 6, 7]) - out0, out1 = paddle.vsplit(x, num_or_sections=2) - print(out0.shape) # [4, 6, 7] - print(out1.shape) # [4, 6, 7] - out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4]) - print(out0.shape) # [1, 6, 7] - print(out1.shape) # [3, 6, 7] - print(out2.shape) # [4, 6, 7] - out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1]) - print(out0.shape) # [2, 6, 7] - print(out1.shape) # [3, 6, 7] - print(out2.shape) # [3, 6, 7] + >>> import paddle + + >>> # x is a Tensor of shape [8, 6, 7] + >>> x = paddle.rand([8, 6, 7]) + >>> out0, out1 = paddle.vsplit(x, num_or_sections=2) + >>> print(out0.shape) + [4, 6, 7] + >>> print(out1.shape) + [4, 6, 7] + >>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4]) + >>> print(out0.shape) + [1, 6, 7] + >>> print(out1.shape) + [3, 6, 7] + >>> print(out2.shape) + [4, 6, 7] + >>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1]) + >>> print(out0.shape) + [2, 6, 7] + >>> print(out1.shape) + [3, 6, 7] + >>> print(out2.shape) + [3, 6, 7] """ if x.ndim < 2: raise ValueError( @@ -2195,17 +2304,21 @@ def squeeze(x, axis=None, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.rand([5, 1, 10]) - output = paddle.squeeze(x, axis=1) + >>> x = paddle.rand([5, 1, 10]) + >>> output = paddle.squeeze(x, axis=1) - print(x.shape) # [5, 1, 10] - print(output.shape) # [5, 10] + >>> print(x.shape) + [5, 1, 10] + >>> print(output.shape) + [5, 10] - # output shares data with x in dygraph mode - x[0, 0, 0] = 10. - print(output[0, 0]) # [10.] + >>> # output shares data with x in dygraph mode + >>> x[0, 0, 0] = 10. + >>> print(output[0, 0]) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 10.) """ if axis is None: @@ -2217,7 +2330,18 @@ def squeeze(x, axis=None, name=None): input = x axes = axis - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + return _C_ops.squeeze(input, axes) + elif in_pir_mode(): + if isinstance(axes, int): + axes = [axes] + if isinstance(axes, paddle.pir.OpResult): + axes.stop_gradient = True + elif isinstance(axes, (list, tuple)): + if paddle.utils._contain_var(axes): + axes = paddle.utils.get_int_tensor_list( + axes, default_dtype='int64' + ) return _C_ops.squeeze(input, axes) else: helper = LayerHelper("squeeze", **locals()) @@ -2317,40 +2441,40 @@ def unique_consecutive( - counts (Tensor), the counts of the every unique consecutive element in the input tensor. counts is provided only if return_counts is True. - Example: + Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([1, 1, 2, 2, 3, 1, 1, 2]) - output = paddle.unique_consecutive(x) # - print(output) - # Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [1, 2, 3, 1, 2]) - - _, inverse, counts = paddle.unique_consecutive(x, return_inverse=True, return_counts=True) - print(inverse) - # Tensor(shape=[8], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [0, 0, 1, 1, 2, 3, 3, 4]) - print(counts) - # Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [2, 2, 1, 2, 1]) - - x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3], [2, 1, 3]]) - output = paddle.unique_consecutive(x, axis=0) # - print(output) - # Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [[2, 1, 3], - # [3, 0, 1], - # [2, 1, 3]]) - - x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3], [2, 1, 3]]) - output = paddle.unique_consecutive(x, axis=0) # - print(output) - # Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [[2, 1, 3], - # [3, 0, 1], - # [2, 1, 3]]) + >>> import paddle + + >>> x = paddle.to_tensor([1, 1, 2, 2, 3, 1, 1, 2]) + >>> output = paddle.unique_consecutive(x) # + >>> print(output) + Tensor(shape=[5], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 2, 3, 1, 2]) + + >>> _, inverse, counts = paddle.unique_consecutive(x, return_inverse=True, return_counts=True) + >>> print(inverse) + Tensor(shape=[8], dtype=int64, place=Place(cpu), stop_gradient=True, + [0, 0, 1, 1, 2, 3, 3, 4]) + >>> print(counts) + Tensor(shape=[5], dtype=int64, place=Place(cpu), stop_gradient=True, + [2, 2, 1, 2, 1]) + + >>> x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3], [2, 1, 3]]) + >>> output = paddle.unique_consecutive(x, axis=0) # + >>> print(output) + Tensor(shape=[3, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2, 1, 3], + [3, 0, 1], + [2, 1, 3]]) + + >>> x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3], [2, 1, 3]]) + >>> output = paddle.unique_consecutive(x, axis=0) # + >>> print(output) + Tensor(shape=[3, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2, 1, 3], + [3, 0, 1], + [2, 1, 3]]) """ if axis is None: @@ -2358,7 +2482,7 @@ def unique_consecutive( else: axis = [axis] attr_dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out, inverse, counts = _C_ops.unique_consecutive( x, return_inverse, return_counts, axis, attr_dtype ) @@ -2449,43 +2573,43 @@ def unique( Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([2, 3, 3, 1, 5, 3]) - unique = paddle.unique(x) - print(unique) - # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [1, 2, 3, 5]) - - _, indices, inverse, counts = paddle.unique(x, return_index=True, return_inverse=True, return_counts=True) - print(indices) - # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [3, 0, 1, 4]) - print(inverse) - # Tensor(shape=[6], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [1, 2, 2, 0, 3, 2]) - print(counts) - # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [1, 1, 3, 1]) - - x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3]]) - unique = paddle.unique(x) - print(unique) - # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [0, 1, 2, 3]) - - unique = paddle.unique(x, axis=0) - print(unique) - # Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [[2, 1, 3], - # [3, 0, 1]]) + >>> import paddle + + >>> x = paddle.to_tensor([2, 3, 3, 1, 5, 3]) + >>> unique = paddle.unique(x) + >>> print(unique) + Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 2, 3, 5]) + + >>> _, indices, inverse, counts = paddle.unique(x, return_index=True, return_inverse=True, return_counts=True) + >>> print(indices) + Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, + [3, 0, 1, 4]) + >>> print(inverse) + Tensor(shape=[6], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 2, 2, 0, 3, 2]) + >>> print(counts) + Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 1, 3, 1]) + + >>> x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3]]) + >>> unique = paddle.unique(x) + >>> print(unique) + Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, + [0, 1, 2, 3]) + + >>> unique = paddle.unique(x, axis=0) + >>> print(unique) + Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2, 1, 3], + [3, 0, 1]]) """ if axis is None: axis = [] else: axis = [axis] attr_dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out, indices, inverse, counts = _C_ops.unique( x, return_index, return_inverse, return_counts, axis, attr_dtype ) @@ -2584,31 +2708,41 @@ def unsqueeze(x, axis, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.rand([5, 10]) - print(x.shape) # [5, 10] - - out1 = paddle.unsqueeze(x, axis=0) - print(out1.shape) # [1, 5, 10] - - out2 = paddle.unsqueeze(x, axis=[0, 2]) - print(out2.shape) # [1, 5, 1, 10] - - axis = paddle.to_tensor([0, 1, 2]) - out3 = paddle.unsqueeze(x, axis=axis) - print(out3.shape) # [1, 1, 1, 5, 10] - - # out1, out2, out3 share data with x in dygraph mode - x[0, 0] = 10. - print(out1[0, 0, 0]) # [10.] - print(out2[0, 0, 0, 0]) # [10.] - print(out3[0, 0, 0, 0, 0]) # [10.] + >>> import paddle + + >>> x = paddle.rand([5, 10]) + >>> print(x.shape) + [5, 10] + + >>> out1 = paddle.unsqueeze(x, axis=0) + >>> print(out1.shape) + [1, 5, 10] + + >>> out2 = paddle.unsqueeze(x, axis=[0, 2]) + >>> print(out2.shape) + [1, 5, 1, 10] + + >>> axis = paddle.to_tensor([0, 1, 2]) + >>> out3 = paddle.unsqueeze(x, axis=axis) + >>> print(out3.shape) + [1, 1, 1, 5, 10] + + >>> # out1, out2, out3 share data with x in dygraph mode + >>> x[0, 0] = 10. + >>> print(out1[0, 0, 0]) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 10.) + >>> print(out2[0, 0, 0, 0]) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 10.) + >>> print(out3[0, 0, 0, 0, 0]) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 10.) """ input = x axes = axis - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): if isinstance(axes, int): axes = [axes] elif isinstance(axes, Variable): @@ -2619,6 +2753,17 @@ def unsqueeze(x, axis, name=None): for item in axes ] return _C_ops.unsqueeze(input, axes) + elif in_pir_mode(): + if isinstance(axes, int): + axes = [axes] + if isinstance(axes, paddle.pir.OpResult): + axes.stop_gradient = True + elif isinstance(axes, (list, tuple)): + if paddle.utils._contain_var(axes): + axes = paddle.utils.get_int_tensor_list( + axes, default_dtype='int64' + ) + return _C_ops.unsqueeze(input, axes) else: check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze') check_variable_and_dtype( @@ -2727,17 +2872,20 @@ def gather(x, index, axis=None, name=None): .. code-block:: python - import paddle + >>> import paddle - input = paddle.to_tensor([[1,2],[3,4],[5,6]]) - index = paddle.to_tensor([0,1]) - output = paddle.gather(input, index, axis=0) - # expected output: [[1,2],[3,4]] + >>> input = paddle.to_tensor([[1,2],[3,4],[5,6]]) + >>> index = paddle.to_tensor([0,1]) + >>> output = paddle.gather(input, index, axis=0) + >>> print(output) + Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 2], + [3, 4]]) """ if axis is None: axis = 0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.gather(x, index, axis) else: check_variable_and_dtype( @@ -2793,24 +2941,24 @@ def unbind(input, axis=0): Returns: list(Tensor), The list of segmented Tensor variables. - Example: + Examples: .. code-block:: python - import paddle + >>> import paddle - # input is a Tensor which shape is [3, 4, 5] - input = paddle.rand([3, 4, 5]) + >>> # input is a Tensor which shape is [3, 4, 5] + >>> input = paddle.rand([3, 4, 5]) - [x0, x1, x2] = paddle.unbind(input, axis=0) - # x0.shape [4, 5] - # x1.shape [4, 5] - # x2.shape [4, 5] + >>> [x0, x1, x2] = paddle.unbind(input, axis=0) + >>> # x0.shape [4, 5] + >>> # x1.shape [4, 5] + >>> # x2.shape [4, 5] - [x0, x1, x2, x3] = paddle.unbind(input, axis=1) - # x0.shape [3, 5] - # x1.shape [3, 5] - # x2.shape [3, 5] - # x3.shape [3, 5] + >>> [x0, x1, x2, x3] = paddle.unbind(input, axis=1) + >>> # x0.shape [3, 5] + >>> # x1.shape [3, 5] + >>> # x2.shape [3, 5] + >>> # x3.shape [3, 5] """ if not isinstance(axis, (int)): raise TypeError( @@ -2870,26 +3018,27 @@ def scatter(x, index, updates, overwrite=True, name=None): .. code-block:: python :name: code-example1 - import paddle - #input: - x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') - index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') - # shape of updates should be the same as x - # shape of updates with dim > 1 should be the same as input - updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') - overwrite = False - # calculation: - if not overwrite: - for i in range(len(index)): - x[index[i]] = paddle.zeros([2]) - for i in range(len(index)): - if (overwrite): - x[index[i]] = updates[i] - else: - x[index[i]] += updates[i] - # output: - out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]]) - out.shape # [3, 2] + >>> import paddle + >>> #input: + >>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') + >>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') + >>> # shape of updates should be the same as x + >>> # shape of updates with dim > 1 should be the same as input + >>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') + >>> overwrite = False + >>> # calculation: + >>> if not overwrite: + ... for i in range(len(index)): + ... x[index[i]] = paddle.zeros([2]) + >>> for i in range(len(index)): + ... if (overwrite): + ... x[index[i]] = updates[i] + ... else: + ... x[index[i]] += updates[i] + >>> # output: + >>> out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]]) + >>> print(out.shape) + [3, 2] **NOTICE**: The order in which updates are applied is nondeterministic, so the output will be nondeterministic if index contains duplicates. @@ -2907,33 +3056,35 @@ def scatter(x, index, updates, overwrite=True, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') - index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') - updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') - - output1 = paddle.scatter(x, index, updates, overwrite=False) - # [[3., 3.], - # [6., 6.], - # [1., 1.]] - - output2 = paddle.scatter(x, index, updates, overwrite=True) - # CPU device: - # [[3., 3.], - # [4., 4.], - # [1., 1.]] - # GPU device maybe have two results because of the repeated numbers in index - # result 1: - # [[3., 3.], - # [4., 4.], - # [1., 1.]] - # result 2: - # [[3., 3.], - # [2., 2.], - # [1., 1.]] + >>> import paddle + + >>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') + >>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') + >>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') + + >>> output1 = paddle.scatter(x, index, updates, overwrite=False) + >>> print(output1) + Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[3., 3.], + [6., 6.], + [1., 1.]]) + + >>> output2 = paddle.scatter(x, index, updates, overwrite=True) + >>> # CPU device: + >>> # [[3., 3.], + >>> # [4., 4.], + >>> # [1., 1.]] + >>> # GPU device maybe have two results because of the repeated numbers in index + >>> # result 1: + >>> # [[3., 3.], + >>> # [4., 4.], + >>> # [1., 1.]] + >>> # result 2: + >>> # [[3., 3.], + >>> # [2., 2.], + >>> # [1., 1.]] """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.scatter(x, index, updates, overwrite) else: check_variable_and_dtype( @@ -3020,19 +3171,19 @@ def scatter_nd_add(x, index, updates, name=None): .. code-block:: python - import paddle + >>> import paddle - x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32') - updates = paddle.rand(shape=[3, 9, 10], dtype='float32') - index = paddle.to_tensor([[1, 1], - [0, 1], - [1, 3]], dtype='int64') + >>> x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32') + >>> updates = paddle.rand(shape=[3, 9, 10], dtype='float32') + >>> index = paddle.to_tensor([[1, 1], + ... [0, 1], + ... [1, 3]], dtype='int64') - output = paddle.scatter_nd_add(x, index, updates) - print(output.shape) - # [3, 5, 9, 10] + >>> output = paddle.scatter_nd_add(x, index, updates) + >>> print(output.shape) + [3, 5, 9, 10] """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.scatter_nd_add(x, index, updates) else: if x.dtype != updates.dtype: @@ -3077,15 +3228,15 @@ def scatter_nd(index, updates, shape, name=None): .. code-block:: python - import paddle + >>> import paddle - index = paddle.to_tensor([[1, 1], - [0, 1], - [1, 3]], dtype="int64") - updates = paddle.rand(shape=[3, 9, 10], dtype='float32') - shape = [3, 5, 9, 10] + >>> index = paddle.to_tensor([[1, 1], + ... [0, 1], + ... [1, 3]], dtype="int64") + >>> updates = paddle.rand(shape=[3, 9, 10], dtype='float32') + >>> shape = [3, 5, 9, 10] - output = paddle.scatter_nd(index, updates, shape) + >>> output = paddle.scatter_nd(index, updates, shape) """ return scatter_nd_add(zeros(shape, updates.dtype), index, updates, name) @@ -3108,22 +3259,22 @@ def chunk(x, chunks, axis=0, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.rand([3, 9, 5]) + >>> x = paddle.rand([3, 9, 5]) - out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1) - # out0.shape [3, 3, 5] - # out1.shape [3, 3, 5] - # out2.shape [3, 3, 5] + >>> out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1) + >>> # out0.shape [3, 3, 5] + >>> # out1.shape [3, 3, 5] + >>> # out2.shape [3, 3, 5] - # axis is negative, the real axis is (rank(x) + axis) which real - # value is 1. - out0, out1, out2 = paddle.chunk(x, chunks=3, axis=-2) - # out0.shape [3, 3, 5] - # out1.shape [3, 3, 5] - # out2.shape [3, 3, 5] + >>> # axis is negative, the real axis is (rank(x) + axis) which real + >>> # value is 1. + >>> out0, out1, out2 = paddle.chunk(x, chunks=3, axis=-2) + >>> # out0.shape [3, 3, 5] + >>> # out1.shape [3, 3, 5] + >>> # out2.shape [3, 3, 5] """ check_type(chunks, 'chunks', (int), 'chunk') return split(x, num_or_sections=chunks, axis=axis, name=name) @@ -3149,26 +3300,26 @@ def tile(x, repeat_times, name=None): Examples: .. code-block:: python - import paddle - - data = paddle.to_tensor([1, 2, 3], dtype='int32') - out = paddle.tile(data, repeat_times=[2, 1]) - print(out) - # Tensor(shape=[2, 3], dtype=int32, place=Place(gpu:0), stop_gradient=True, - # [[1, 2, 3], - # [1, 2, 3]]) - - out = paddle.tile(data, repeat_times=(2, 2)) - print(out) - # Tensor(shape=[2, 6], dtype=int32, place=Place(gpu:0), stop_gradient=True, - # [[1, 2, 3, 1, 2, 3], - # [1, 2, 3, 1, 2, 3]]) - - repeat_times = paddle.to_tensor([1, 2], dtype='int32') - out = paddle.tile(data, repeat_times=repeat_times) - print(out) - # Tensor(shape=[1, 6], dtype=int32, place=Place(gpu:0), stop_gradient=True, - # [[1, 2, 3, 1, 2, 3]]) + >>> import paddle + + >>> data = paddle.to_tensor([1, 2, 3], dtype='int32') + >>> out = paddle.tile(data, repeat_times=[2, 1]) + >>> print(out) + Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3], + [1, 2, 3]]) + + >>> out = paddle.tile(data, repeat_times=(2, 2)) + >>> print(out) + Tensor(shape=[2, 6], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3, 1, 2, 3], + [1, 2, 3, 1, 2, 3]]) + + >>> repeat_times = paddle.to_tensor([1, 2], dtype='int32') + >>> out = paddle.tile(data, repeat_times=repeat_times) + >>> print(out) + Tensor(shape=[1, 6], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3, 1, 2, 3]]) """ def check_input(x, repeat_times): @@ -3237,7 +3388,7 @@ def check_input(x, repeat_times): def get_attr_repeat_times(list_repeat_times): attrs_repeat_times = [] for idx, times in enumerate(list_repeat_times): - if isinstance(times, (Variable, paddle.pir.OpResult)): + if isinstance(times, Variable): attrs_repeat_times.append(-1) else: attrs_repeat_times.append(times) @@ -3288,17 +3439,17 @@ def expand_as(x, y, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - data_x = paddle.to_tensor([1, 2, 3], 'int32') - data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32') - out = paddle.expand_as(data_x, data_y) - print(out) - # Tensor(shape=[2, 3], dtype=int32, place=Place(gpu:0), stop_gradient=True, - # [[1, 2, 3], - # [1, 2, 3]]) + >>> data_x = paddle.to_tensor([1, 2, 3], 'int32') + >>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32') + >>> out = paddle.expand_as(data_x, data_y) + >>> print(out) + Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3], + [1, 2, 3]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.expand_as(x, None, y.shape) else: check_variable_and_dtype( @@ -3358,15 +3509,27 @@ def broadcast_to(x, shape, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - data = paddle.to_tensor([1, 2, 3], dtype='int32') - out = paddle.broadcast_to(data, shape=[2, 3]) - print(out) - # [[1, 2, 3], [1, 2, 3]] + >>> data = paddle.to_tensor([1, 2, 3], dtype='int32') + >>> out = paddle.broadcast_to(data, shape=[2, 3]) + >>> print(out) + Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3], + [1, 2, 3]]) """ if in_dynamic_mode(): return _C_ops.expand(x, shape) + elif in_pir_mode(): + place = _current_expected_place() + if isinstance(shape, (list, tuple)): + if paddle.utils._contain_var(shape): + shape = paddle.utils.get_int_tensor_list(shape, place) + elif isinstance(shape, paddle.pir.OpResult): + shape.stop_gradient = True + else: + TypeError("Shape only supports OpReslut, or list, or tuple.") + return _C_ops.expand(x, shape) else: if isinstance(shape, Variable): assert len(shape.shape) == 1, 'shape must be an 1-D Tensor.' @@ -3460,14 +3623,32 @@ def expand(x, shape, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - data = paddle.to_tensor([1, 2, 3], dtype='int32') - out = paddle.expand(data, shape=[2, 3]) - print(out) - # [[1, 2, 3], [1, 2, 3]] + >>> data = paddle.to_tensor([1, 2, 3], dtype='int32') + >>> out = paddle.expand(data, shape=[2, 3]) + >>> print(out) + Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3], + [1, 2, 3]]) """ - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + return _C_ops.expand(x, shape) + elif in_pir_mode(): + if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: + raise ValueError( + "When the data type of input 'x' for expand is bool, " + "you must set its stop_gradient to be False by " + "some_var.stop_gradient = True, supporting " + "some_var as the input." + ) + if isinstance(shape, paddle.pir.OpResult): + shape.stop_gradient = True + elif isinstance(shape, (list, tuple)): + if paddle.utils._contain_var(shape): + shape = paddle.utils._convert_to_tensor_list(shape) + else: + TypeError("Shape only supports OpReslut, or list, or tuple.") return _C_ops.expand(x, shape) else: if isinstance(shape, Variable): @@ -3578,27 +3759,28 @@ def reshape(x, shape, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.rand([2, 4, 6], dtype="float32") - positive_four = paddle.full([1], 4, "int32") + >>> x = paddle.rand([2, 4, 6], dtype="float32") + >>> positive_four = paddle.full([1], 4, "int32") - out = paddle.reshape(x, [-1, 0, 3, 2]) - print(out) - # the shape is [2,4,3,2]. + >>> out = paddle.reshape(x, [-1, 0, 3, 2]) + >>> print(out.shape) + [2, 4, 3, 2] - out = paddle.reshape(x, shape=[positive_four, 12]) - print(out) - # the shape of out_2 is [4, 12]. + >>> out = paddle.reshape(x, shape=[positive_four, 12]) + >>> print(out.shape) + [4, 12] - shape_tensor = paddle.to_tensor([8, 6], dtype=paddle.int32) - out = paddle.reshape(x, shape=shape_tensor) - print(out.shape) - # the shape is [8, 6]. - # out shares data with x in dygraph mode - x[0, 0, 0] = 10. - print(out[0, 0]) - # the value is [10.] + >>> shape_tensor = paddle.to_tensor([8, 6], dtype=paddle.int32) + >>> out = paddle.reshape(x, shape=shape_tensor) + >>> print(out.shape) + [8, 6] + >>> # out shares data with x in dygraph mode + >>> x[0, 0, 0] = 10. + >>> print(out[0, 0]) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 10.) """ @@ -3661,6 +3843,26 @@ def get_attr_shape(list_shape): ) return out elif in_pir_mode(): + check_variable_and_dtype( + x, + 'x', + [ + 'float16', + 'float32', + 'float64', + 'int8', + 'uint8', + 'int16', + 'int32', + 'int64', + 'bool', + 'uint16', + ], + 'reshape', + ) + check_type( + shape, 'shape', (list, tuple, paddle.pir.OpResult), 'reshape' + ) if isinstance(shape, (list, tuple)): if paddle.utils._contain_var(shape): new_shape = paddle.utils._convert_to_tensor_list(shape) @@ -3812,13 +4014,16 @@ def gather_nd(x, index, name=None): .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([[[1, 2], [3, 4], [5, 6]], - [[7, 8], [9, 10], [11, 12]]]) - index = paddle.to_tensor([[0, 1]]) + >>> x = paddle.to_tensor([[[1, 2], [3, 4], [5, 6]], + ... [[7, 8], [9, 10], [11, 12]]]) + >>> index = paddle.to_tensor([[0, 1]]) - output = paddle.gather_nd(x, index) #[[3, 4]] + >>> output = paddle.gather_nd(x, index) + >>> print(output) + Tensor(shape=[1, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[3, 4]]) """ if in_dynamic_or_pir_mode(): @@ -3922,22 +4127,22 @@ def strided_slice(x, axes, starts, ends, strides, name=None): Examples: .. code-block:: python - import paddle - x = paddle.zeros(shape=[3,4,5,6], dtype="float32") - # example 1: - # attr starts is a list which doesn't contain Tensor. - axes = [1, 2, 3] - starts = [-3, 0, 2] - ends = [3, 2, 4] - strides_1 = [1, 1, 1] - strides_2 = [1, 1, 2] - sliced_1 = paddle.strided_slice(x, axes=axes, starts=starts, ends=ends, strides=strides_1) - # sliced_1 is x[:, 1:3:1, 0:2:1, 2:4:1]. - # example 2: - # attr starts is a list which contain tensor Tensor. - minus_3 = paddle.full(shape=[1], fill_value=-3, dtype='int32') - sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2) - # sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2]. + >>> import paddle + >>> x = paddle.zeros(shape=[3,4,5,6], dtype="float32") + >>> # example 1: + >>> # attr starts is a list which doesn't contain Tensor. + >>> axes = [1, 2, 3] + >>> starts = [-3, 0, 2] + >>> ends = [3, 2, 4] + >>> strides_1 = [1, 1, 1] + >>> strides_2 = [1, 1, 2] + >>> sliced_1 = paddle.strided_slice(x, axes=axes, starts=starts, ends=ends, strides=strides_1) + >>> # sliced_1 is x[:, 1:3:1, 0:2:1, 2:4:1]. + >>> # example 2: + >>> # attr starts is a list which contain tensor Tensor. + >>> minus_3 = paddle.full(shape=[1], fill_value=-3, dtype='int32') + >>> sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2) + >>> # sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2]. """ if in_dynamic_mode(): return _C_ops.strided_slice(x, axes, starts, ends, strides) @@ -4111,74 +4316,85 @@ def tensordot(x, y, axes=2, name=None): Examples: .. code-block:: python - import paddle - - data_type = 'float64' - - # For two 2-d tensor x and y, the case axes=0 is equivalent to outer product. - # Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases. - x = paddle.arange(4, dtype=data_type).reshape([2, 2]) - y = paddle.arange(4, dtype=data_type).reshape([2, 2]) - z = paddle.tensordot(x, y, axes=0) - # z = [[[[0., 0.], - # [0., 0.]], - # - # [[0., 1.], - # [2., 3.]]], - # - # - # [[[0., 2.], - # [4., 6.]], - # - # [[0., 3.], - # [6., 9.]]]] - - - # For two 1-d tensor x and y, the case axes=1 is equivalent to inner product. - x = paddle.arange(10, dtype=data_type) - y = paddle.arange(10, dtype=data_type) - z1 = paddle.tensordot(x, y, axes=1) - z2 = paddle.dot(x, y) - # z1 = z2 = 285. - - - # For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication. - x = paddle.arange(6, dtype=data_type).reshape([2, 3]) - y = paddle.arange(12, dtype=data_type).reshape([3, 4]) - z1 = paddle.tensordot(x, y, axes=1) - z2 = paddle.matmul(x, y) - # z1 = z2 = [[20., 23., 26., 29.], - # [56., 68., 80., 92.]] - - - # When axes is a 1-d int list, x and y will be contracted along the same given axes. - # Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]]. - x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4]) - y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4]) - z = paddle.tensordot(x, y, axes=[1, 2]) - # z = [[506. , 1298., 2090.], - # [1298., 3818., 6338.]] - - - # When axes is a list containing two 1-d int list, the first will be applied to x and the second to y. - x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5]) - y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2]) - z = paddle.tensordot(x, y, axes=([1, 0], [0, 1])) - # z = [[4400., 4730.], - # [4532., 4874.], - # [4664., 5018.], - # [4796., 5162.], - # [4928., 5306.]] - - - # Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]]. - x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6]) - y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6]) - z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]]) - # z = [[23217330., 24915630., 26613930., 28312230.], - # [24915630., 26775930., 28636230., 30496530.], - # [26613930., 28636230., 30658530., 32680830.], - # [28312230., 30496530., 32680830., 34865130.]] + >>> import paddle + + >>> data_type = 'float64' + + >>> # For two 2-d tensor x and y, the case axes=0 is equivalent to outer product. + >>> # Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases. + >>> x = paddle.arange(4, dtype=data_type).reshape([2, 2]) + >>> y = paddle.arange(4, dtype=data_type).reshape([2, 2]) + >>> z = paddle.tensordot(x, y, axes=0) + >>> print(z) + Tensor(shape=[2, 2, 2, 2], dtype=float64, place=Place(cpu), stop_gradient=True, + [[[[0., 0.], + [0., 0.]], + [[0., 1.], + [2., 3.]]], + [[[0., 2.], + [4., 6.]], + [[0., 3.], + [6., 9.]]]]) + + >>> # For two 1-d tensor x and y, the case axes=1 is equivalent to inner product. + >>> x = paddle.arange(10, dtype=data_type) + >>> y = paddle.arange(10, dtype=data_type) + >>> z1 = paddle.tensordot(x, y, axes=1) + >>> z2 = paddle.dot(x, y) + >>> print(z1) + Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True, + 285.) + >>> print(z2) + Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True, + 285.) + + + >>> # For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication. + >>> x = paddle.arange(6, dtype=data_type).reshape([2, 3]) + >>> y = paddle.arange(12, dtype=data_type).reshape([3, 4]) + >>> z1 = paddle.tensordot(x, y, axes=1) + >>> z2 = paddle.matmul(x, y) + >>> print(z1) + Tensor(shape=[2, 4], dtype=float64, place=Place(cpu), stop_gradient=True, + [[20., 23., 26., 29.], + [56., 68., 80., 92.]]) + >>> print(z2) + Tensor(shape=[2, 4], dtype=float64, place=Place(cpu), stop_gradient=True, + [[20., 23., 26., 29.], + [56., 68., 80., 92.]]) + + >>> # When axes is a 1-d int list, x and y will be contracted along the same given axes. + >>> # Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]]. + >>> x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4]) + >>> y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4]) + >>> z = paddle.tensordot(x, y, axes=[1, 2]) + >>> print(z) + Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True, + [[506. , 1298., 2090.], + [1298., 3818., 6338.]]) + + >>> # When axes is a list containing two 1-d int list, the first will be applied to x and the second to y. + >>> x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5]) + >>> y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2]) + >>> z = paddle.tensordot(x, y, axes=([1, 0], [0, 1])) + >>> print(z) + Tensor(shape=[5, 2], dtype=float64, place=Place(cpu), stop_gradient=True, + [[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + + >>> # Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]]. + >>> x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6]) + >>> y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6]) + >>> z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]]) + >>> print(z) + Tensor(shape=[4, 4], dtype=float64, place=Place(cpu), stop_gradient=True, + [[23217330., 24915630., 26613930., 28312230.], + [24915630., 26775930., 28636230., 30496530.], + [26613930., 28636230., 30658530., 32680830.], + [28312230., 30496530., 32680830., 34865130.]]) """ op_type = 'tensordot' input_dtype = ['float16', 'float32', 'float64'] @@ -4302,16 +4518,15 @@ def as_complex(x, name=None): Examples: .. code-block:: python - import paddle - x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2]) - y = paddle.as_complex(x) - print(y) - - # Tensor(shape=[2, 3], dtype=complex64, place=Place(gpu:0), stop_gradient=True, - # [[1j , (2+3j) , (4+5j) ], - # [(6+7j) , (8+9j) , (10+11j)]]) + >>> import paddle + >>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2]) + >>> y = paddle.as_complex(x) + >>> print(y) + Tensor(shape=[2, 3], dtype=complex64, place=Place(cpu), stop_gradient=True, + [[1j , (2+3j) , (4+5j) ], + [(6+7j) , (8+9j) , (10+11j)]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.as_complex(x) else: check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'as_complex') @@ -4349,22 +4564,20 @@ def as_real(x, name=None): Examples: .. code-block:: python - import paddle - x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2]) - y = paddle.as_complex(x) - z = paddle.as_real(y) - print(z) - - # Tensor(shape=[2, 3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[[0. , 1. ], - # [2. , 3. ], - # [4. , 5. ]], - - # [[6. , 7. ], - # [8. , 9. ], - # [10., 11.]]]) + >>> import paddle + >>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2]) + >>> y = paddle.as_complex(x) + >>> z = paddle.as_real(y) + >>> print(z) + Tensor(shape=[2, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[0. , 1. ], + [2. , 3. ], + [4. , 5. ]], + [[6. , 7. ], + [8. , 9. ], + [10., 11.]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.as_real(x) else: check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'as_real') @@ -4399,22 +4612,33 @@ def repeat_interleave(x, repeats, axis=None, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) - repeats = paddle.to_tensor([3, 2, 1], dtype='int32') - - paddle.repeat_interleave(x, repeats, 1) - # [[1, 1, 1, 2, 2, 3], - # [4, 4, 4, 5, 5, 6]] - - paddle.repeat_interleave(x, 2, 0) - # [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]] - - paddle.repeat_interleave(x, 2, None) - # [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] + >>> import paddle + + >>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + >>> repeats = paddle.to_tensor([3, 2, 1], dtype='int32') + + >>> out = paddle.repeat_interleave(x, repeats, 1) + >>> print(out) + Tensor(shape=[2, 6], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 1, 1, 2, 2, 3], + [4, 4, 4, 5, 5, 6]]) + + >>> out = paddle.repeat_interleave(x, 2, 0) + >>> print(out) + Tensor(shape=[4, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 2, 3], + [1, 2, 3], + [4, 5, 6], + [4, 5, 6]]) + + >>> out = paddle.repeat_interleave(x, 2, None) + >>> print(out) + Tensor(shape=[12], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]) """ + if isinstance(repeats, Variable) and not repeats.shape: + repeats = paddle.reshape(repeats, [1]) if axis is None: x = paddle.flatten(x) axis = 0 @@ -4467,15 +4691,17 @@ def moveaxis(x, source, destination, name=None): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.ones([3, 2, 4]) - paddle.moveaxis(x, [0, 1], [1, 2]).shape - # [4, 3, 2] + >>> x = paddle.ones([3, 2, 4]) + >>> outshape = paddle.moveaxis(x, [0, 1], [1, 2]).shape + >>> print(outshape) + [4, 3, 2] - x = paddle.ones([2, 3]) - paddle.moveaxis(x, 0, 1).shape # equivalent to paddle.t(x) - # [3, 2] + >>> x = paddle.ones([2, 3]) + >>> outshape = paddle.moveaxis(x, 0, 1).shape # equivalent to paddle.t(x) + >>> print(outshape) + [3, 2] """ src = [source] if isinstance(source, int) else source dst = [destination] if isinstance(destination, int) else destination @@ -4529,7 +4755,7 @@ def moveaxis(x, source, destination, name=None): for i in range(len(src_dims)): perm[dst_dims[i]] = src_dims[i] - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.transpose(x, perm) return out else: @@ -4561,6 +4787,76 @@ def moveaxis(x, source, destination, name=None): return out +def masked_fill(x, mask, value, name=None): + """ + Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. + + Args: + x (Tensor) : The Destination Tensor. Supported data types are float, + double, int, int64_t,float16 and bfloat16. + mask (Tensor): The boolean tensor indicate the position to be filled. + The data type of mask must be bool. + value (Scalar or 0-D Tensor): The value used to fill the target tensor. + Supported data types are float, double, int, int64_t,float16 and bfloat16. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Tensor, same dimention and dtype with x. + Examples: + .. code-block:: python + + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> x = paddle.ones((3, 3), dtype="float32") + >>> mask = paddle.to_tensor([[True, True, False]]) + >>> print(mask) + Tensor(shape=[1, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True, + [[True , True , False]]) + >>> out = paddle.masked_fill(x, mask, 2) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 2., 1.], + [2., 2., 1.], + [2., 2., 1.]]) + """ + if np.isscalar(value): + value = paddle.full([], value, x.dtype) + + mask = paddle.logical_not(mask) + out = paddle.where(mask, x, value) + return out + + +@inplace_apis_in_dygraph_only +def masked_fill_(x, mask, value, name=None): + """ + Inplace version of ``masked_fill`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_masked_fill`. + + Examples: + .. code-block:: python + + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> x = paddle.ones((3, 3), dtype="float32") + >>> mask = paddle.to_tensor([[True, False, False]]) + >>> out = paddle.masked_fill_(x, mask, 2) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 1., 1.], + [2., 1., 1.], + [2., 1., 1.]]) + """ + if np.isscalar(value): + value = paddle.full([], value, x.dtype) + + mask = paddle.logical_not(mask) + out = paddle.where_(mask, x, value) + return out + + def non_negative_axis(arr, axis): ndim = len(arr.shape) if axis >= 0: @@ -4602,14 +4898,15 @@ def take_along_axis(arr, indices, axis): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7,8,9]]) - index = paddle.to_tensor([[0]]) - axis = 0 - result = paddle.take_along_axis(x, index, axis) - print(result) - # [[1, 2, 3]] + >>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7,8,9]]) + >>> index = paddle.to_tensor([[0]]) + >>> axis = 0 + >>> result = paddle.take_along_axis(x, index, axis) + >>> print(result) + Tensor(shape=[1, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 2, 3]]) """ if len(arr.shape) != len(indices.shape): raise ValueError( @@ -4620,7 +4917,7 @@ def take_along_axis(arr, indices, axis): if not broadcast_shape: # if indices matrix have larger size than arr, arr should broadcast into indices shape. broadcast_shape = indices.shape - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): indices = paddle.broadcast_to(indices, broadcast_shape) broadcast_shape_list = list(broadcast_shape) broadcast_shape_list[axis] = list(arr.shape)[axis] @@ -4679,16 +4976,17 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): Examples: .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([[10, 30, 20], [60, 40, 50]]) - index = paddle.to_tensor([[0]]) - value = 99 - axis = 0 - result = paddle.put_along_axis(x, index, value, axis) - print(result) - # [[99, 99, 99], - # [60, 40, 50]] + >>> x = paddle.to_tensor([[10, 30, 20], [60, 40, 50]]) + >>> index = paddle.to_tensor([[0]]) + >>> value = 99 + >>> axis = 0 + >>> result = paddle.put_along_axis(x, index, value, axis) + >>> print(result) + Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[99, 99, 99], + [60, 40, 50]]) """ if len(arr.shape) != len(indices.shape): @@ -4781,20 +5079,21 @@ def index_add(x, index, axis, value, name=None): Examples: .. code-block:: python - # required: gpu - import paddle - - input_tensor = paddle.to_tensor(paddle.ones((3, 3)), dtype="float32") - index = paddle.to_tensor([0, 2], dtype="int32") - value = paddle.to_tensor([[1, 1, 1], [1, 1, 1]], dtype="float32") - outplace_res = paddle.index_add(input_tensor, index, 0, value) - print(outplace_res) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[2., 2., 2.], - # [1., 1., 1.], - # [2., 2., 2.]]) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + + >>> input_tensor = paddle.to_tensor(paddle.ones((3, 3)), dtype="float32") + >>> index = paddle.to_tensor([0, 2], dtype="int32") + >>> value = paddle.to_tensor([[1, 1, 1], [1, 1, 1]], dtype="float32") + >>> outplace_res = paddle.index_add(input_tensor, index, 0, value) + >>> print(outplace_res) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 2., 2.], + [1., 1., 1.], + [2., 2., 2.]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.index_add(x, index, value, axis) helper = LayerHelper("index_add", **locals()) @@ -4841,18 +5140,19 @@ def index_add_(x, index, axis, value, name=None): Examples: .. code-block:: python - # required: gpu - import paddle - - input_tensor = paddle.to_tensor(paddle.ones((3, 3)), dtype="float32") - index = paddle.to_tensor([0, 2], dtype="int32") - value = paddle.to_tensor([[1, 1], [1, 1], [1, 1]], dtype="float32") - inplace_res = paddle.index_add_(input_tensor, index, 1, value) - print(inplace_res) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[2., 1., 2.], - # [2., 1., 2.], - # [2., 1., 2.]]) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + + >>> input_tensor = paddle.to_tensor(paddle.ones((3, 3)), dtype="float32") + >>> index = paddle.to_tensor([0, 2], dtype="int32") + >>> value = paddle.to_tensor([[1, 1], [1, 1], [1, 1]], dtype="float32") + >>> inplace_res = paddle.index_add_(input_tensor, index, 1, value) + >>> print(inplace_res) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 1., 2.], + [2., 1., 2.], + [2., 1., 2.]]) """ return _C_ops.index_add_(x, index, value, axis) @@ -4878,25 +5178,25 @@ def index_put_(x, indices, value, accumulate=False, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.zeros([3, 3]) - value = paddle.ones([3]) - ix1 = paddle.to_tensor([0,1,2]) - ix2 = paddle.to_tensor([1,2,1]) - indices=(ix1,ix2) - - out = paddle.index_put_(x,indices,value) - print(x) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[0., 1., 0.], - # [0., 0., 1.], - # [0., 1., 0.]]) - print(out) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[0., 1., 0.], - # [0., 0., 1.], - # [0., 1., 0.]]) + >>> import paddle + + >>> x = paddle.zeros([3, 3]) + >>> value = paddle.ones([3]) + >>> ix1 = paddle.to_tensor([0,1,2]) + >>> ix2 = paddle.to_tensor([1,2,1]) + >>> indices=(ix1,ix2) + + >>> out = paddle.index_put_(x,indices,value) + >>> print(x) + Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0., 1., 0.], + [0., 0., 1.], + [0., 1., 0.]]) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0., 1., 0.], + [0., 0., 1.], + [0., 1., 0.]]) """ return _C_ops.index_put_(x, indices, value, accumulate) @@ -4909,27 +5209,27 @@ def index_put(x, indices, value, accumulate=False, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.zeros([3, 3]) - value = paddle.ones([3]) - ix1 = paddle.to_tensor([0,1,2]) - ix2 = paddle.to_tensor([1,2,1]) - indices=(ix1,ix2) - - out = paddle.index_put(x,indices,value) - print(x) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[0., 0., 0.], - # [0., 0., 0.], - # [0., 0., 0.]]) - print(out) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[0., 1., 0.], - # [0., 0., 1.], - # [0., 1., 0.]]) + >>> import paddle + + >>> x = paddle.zeros([3, 3]) + >>> value = paddle.ones([3]) + >>> ix1 = paddle.to_tensor([0,1,2]) + >>> ix2 = paddle.to_tensor([1,2,1]) + >>> indices=(ix1,ix2) + + >>> out = paddle.index_put(x,indices,value) + >>> print(x) + Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0., 1., 0.], + [0., 0., 1.], + [0., 1., 0.]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.index_put(x, indices, value, accumulate) helper = LayerHelper("index_put", **locals()) @@ -4980,28 +5280,28 @@ def unflatten(x, axis, shape, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.randn(shape=[4, 6, 8]) - shape = [2, 3] - axis = 1 - res = paddle.unflatten(x, axis, shape) - print(res.shape) - # [4, 2, 3, 8] - - x = paddle.randn(shape=[4, 6, 8]) - shape = (-1, 2) - axis = -1 - res = paddle.unflatten(x, axis, shape) - print(res.shape) - # [4, 6, 4, 2] - - x = paddle.randn(shape=[4, 6, 8]) - shape = paddle.to_tensor([2, 2]) - axis = 0 - res = paddle.unflatten(x, axis, shape) - print(res.shape) - # [2, 2, 6, 8] + >>> import paddle + + >>> x = paddle.randn(shape=[4, 6, 8]) + >>> shape = [2, 3] + >>> axis = 1 + >>> res = paddle.unflatten(x, axis, shape) + >>> print(res.shape) + [4, 2, 3, 8] + + >>> x = paddle.randn(shape=[4, 6, 8]) + >>> shape = (-1, 2) + >>> axis = -1 + >>> res = paddle.unflatten(x, axis, shape) + >>> print(res.shape) + [4, 6, 4, 2] + + >>> x = paddle.randn(shape=[4, 6, 8]) + >>> shape = paddle.to_tensor([2, 2]) + >>> axis = 0 + >>> res = paddle.unflatten(x, axis, shape) + >>> print(res.shape) + [2, 2, 6, 8] """ # determine whether the input axis is valid. @@ -5050,15 +5350,15 @@ def as_strided(x, shape, stride, offset=0, name=None): Examples: .. code-block:: python - import paddle - paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) + >>> import paddle + >>> paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) - x = paddle.rand([2, 4, 6], dtype="float32") + >>> x = paddle.rand([2, 4, 6], dtype="float32") - out = paddle.as_strided(x, [8, 6], [6, 1]) - print(out) - # the shape is [8, 6]. - # the stride is [6, 1]. + >>> out = paddle.as_strided(x, [8, 6], [6, 1]) + >>> print(out.shape) + [8, 6] + >>> # the stride is [6, 1]. """ return _C_ops.as_strided(x, shape, stride, offset) @@ -5082,22 +5382,24 @@ def view(x, shape_or_dtype, name=None): Examples: .. code-block:: python - import paddle - paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) + >>> import paddle + >>> paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) - x = paddle.rand([2, 4, 6], dtype="float32") + >>> x = paddle.rand([2, 4, 6], dtype="float32") - out = paddle.view(x, [8, 6]) - print(out) + >>> out = paddle.view(x, [8, 6]) + >>> print(out.shape) + [8, 6] + >>> import paddle + >>> paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) - import paddle - paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) + >>> x = paddle.rand([2, 4, 6], dtype="float32") - x = paddle.rand([2, 4, 6], dtype="float32") + >>> out = paddle.view(x, "uint8") + >>> print(out.shape) + [2, 4, 24] - out = paddle.view(x, "uint8") - print(out) """ if isinstance(shape_or_dtype, (list, tuple)): return _C_ops.view_shape(x, shape_or_dtype) @@ -5126,14 +5428,15 @@ def view_as(x, other, name=None): Examples: .. code-block:: python - import paddle - paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) + >>> import paddle + >>> paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) - x = paddle.rand([2, 4, 6], dtype="float32") - y = paddle.rand([8, 6], dtype="float32") + >>> x = paddle.rand([2, 4, 6], dtype="float32") + >>> y = paddle.rand([8, 6], dtype="float32") - out = paddle.view_as(x, y) - print(out) + >>> out = paddle.view_as(x, y) + >>> print(out.shape) + [8, 6] """ return _C_ops.view_shape(x, other.shape) @@ -5159,13 +5462,16 @@ def unfold(x, axis, size, step, name=None): Examples: .. code-block:: python - import paddle - paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) + >>> import paddle + >>> paddle.base.set_flags({"FLAGS_use_stride_kernel": True}) - x = paddle.arange(9, dtype="float64") + >>> x = paddle.arange(9, dtype="float64") - out = paddle.unfold(x, 0, 2, 4) - print(out) # [[0, 1], [4, 5]] + >>> out = paddle.unfold(x, 0, 2, 4) + >>> print(out) + Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=True, + [[0., 1.], + [4., 5.]]) """ return _C_ops.tensor_unfold(x, axis, size, step) @@ -5181,3 +5487,104 @@ def unfold(x, axis, size, step, name=None): } for name, func in __METHODS.items(): setattr(core.eager.Tensor, name, func) + + +def _index_fill_impl(x, index, axis, value, inplace): + if not isinstance(index, Variable): + raise ValueError("index must be Tensor") + + if not isinstance(value, Variable): + value = paddle.to_tensor(value, dtype=x.dtype) + else: + if len(value.shape) > 0: + raise ValueError("value must be scalar or 0-D tensor") + + x_dim = len(x.shape) + if not (isinstance(axis, int)) or (axis > x_dim - 1) or axis < -x_dim: + raise ValueError( + "The axis should be int, and in range [-rank(x), rank(x))" + ) + + if axis < 0: + axis = axis + x_dim + + perm = list(range(len(x.shape))) + perm[0] = axis + perm[axis] = 0 + + if inplace: + paddle.transpose(x, perm) + paddle.index_put_(x, (index,), value) + return x + else: + out = paddle.transpose(x, perm) + out = paddle.index_put(out, (index,), value) + out = paddle.transpose(out, perm) + return out + + +def index_fill(x, index, axis, value, name=None): + """ + Outplace version of ``index_fill_`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_index_fill_`. + + Examples: + .. code-block:: python + + >>> import paddle + >>> input_tensor = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='int64') + >>> index = paddle.to_tensor([0, 2], dtype="int32") + >>> value = -1 + >>> res = paddle.index_fill(input_tensor, index, 0, value) + >>> print(input_tensor) + Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[1, 2, 3], + [4, 5, 6], + [7, 8, 9]]) + >>> print(res) + Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[-1, -1, -1], + [ 4, 5, 6], + [-1, -1, -1]]) + + """ + return _index_fill_impl(x, index, axis, value, False) + + +@inplace_apis_in_dygraph_only +def index_fill_(x, index, axis, value, name=None): + """ + Fill the elements of the input tensor with value by the spcific axis and index. + + Args: + x (Tensor) : The Destination Tensor. Supported data types are int32, int64, float16, float32, float64. + index (Tensor): The 1-D Tensor containing the indices to index. + The data type of ``index`` must be int32 or int64. + axis (int): The dimension along which to index. + value (float): The tensor used to fill with. + name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + + Returns: + Tensor, same dimention and dtype with x. + + Examples: + .. code-block:: python + + >>> import paddle + >>> input_tensor = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='int64') + >>> index = paddle.to_tensor([0, 2], dtype="int32") + >>> value = -1 + >>> res = paddle.index_fill_(input_tensor, index, 0, value) + >>> print(input_tensor) + Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[-1, -1, -1], + [ 4, 5, 6], + [-1, -1, -1]]) + >>> print(res) + Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[-1, -1, -1], + [ 4, 5, 6], + [-1, -1, -1]]) + + """ + return _index_fill_impl(x, index, axis, value, True) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index c21d18d845562..01e01d6b449cb 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -20,7 +20,8 @@ import paddle from paddle import _C_ops, _legacy_C_ops -from paddle.common_ops_import import VarDesc, dygraph_only, dygraph_utils +from paddle.base.libpaddle import DataType +from paddle.common_ops_import import VarDesc, dygraph_utils from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ..base.data_feeder import ( @@ -130,7 +131,7 @@ def _get_reduce_axis(axis, x): def _get_reduce_axis_with_tensor(axis, x): - if isinstance(axis, Variable): + if isinstance(axis, (Variable, paddle.pir.OpResult)): if axis.shape[0] == len(x.shape): reduce_all = True else: @@ -340,7 +341,7 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.stanh(x, scale_a, scale_b) else: check_variable_and_dtype( @@ -941,7 +942,7 @@ def floor_divide(x, y, name=None): [2, 0, 2, 2]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.floor_divide(x, y) else: return _elementwise_op(LayerHelper('elementwise_floordiv', **locals())) @@ -976,6 +977,8 @@ def remainder(x, y, name=None): .. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor + And `mod`, `floor_mod` are all functions with the same name + Args: x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64. y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64. @@ -997,8 +1000,18 @@ def remainder(x, y, name=None): Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, [0, 3, 2, 1]) + >>> z = paddle.floor_mod(x, y) + >>> print(z) + Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, + [0, 3, 2, 1]) + + >>> z = paddle.mod(x, y) + >>> print(z) + Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True, + [0, 3, 2, 1]) + """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.remainder(x, y) else: return _elementwise_op(LayerHelper('elementwise_mod', **locals())) @@ -1020,14 +1033,14 @@ def remainder_(x, y, name=None): return _C_ops.remainder_(x, y) -mod = remainder # noqa: F841 -floor_mod = remainder # noqa: F841 -mod_ = remainder_ # noqa: F841 +mod = remainder +floor_mod = remainder +mod_ = remainder_ mod_.__doc__ = r""" Inplace version of ``mod`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_mod`. """ -floor_mod_ = remainder_ # noqa: F841 +floor_mod_ = remainder_ floor_mod_.__doc__ = r""" Inplace version of ``floor_mod_`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_floor_mod_`. @@ -1105,13 +1118,10 @@ def multiply_(x, y, name=None): return _C_ops.multiply_(x, y) -@dygraph_only -def _elementwise_op_with_axis_in_dygraph( - x, y, axis=-1, name=None, op_type="Undifined" -): +def _elementwise_op_with_axis(x, y, axis=-1, name=None, op_type="Undifined"): assert ( - in_dynamic_mode() - ), "You can only call `_elementwise_op_with_axis_in_dygraph` function within in_dynamic_mode" + in_dynamic_or_pir_mode() + ), "You can only call `_elementwise_op_with_axis` function within in_dynamic_or_pir_mode" assert op_type in ["add", "subtract", "multiply", "divide"], ( "op_name input error! _elementwise_op_with_axis is an inner function to replace elementwise_add/sub/mul/div. Input op_name=%s, Expect op_name=[add|subtract|multiply|divide]\n" % op_type @@ -1132,8 +1142,8 @@ def _elementwise_op_with_axis_in_dygraph( def _add_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape - if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "add") + if in_dynamic_or_pir_mode(): + return _elementwise_op_with_axis(x, y, axis, name, "add") else: op_type = 'elementwise_add' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1142,9 +1152,7 @@ def _add_with_axis(x, y, axis=-1, name=None): def _subtract_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph( - x, y, axis, name, "subtract" - ) + return _elementwise_op_with_axis(x, y, axis, name, "subtract") else: op_type = 'elementwise_sub' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1153,9 +1161,7 @@ def _subtract_with_axis(x, y, axis=-1, name=None): def _multiply_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph( - x, y, axis, name, "multiply" - ) + return _elementwise_op_with_axis(x, y, axis, name, "multiply") else: op_type = 'elementwise_mul' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1164,7 +1170,7 @@ def _multiply_with_axis(x, y, axis=-1, name=None): def _divide_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "divide") + return _elementwise_op_with_axis(x, y, axis, name, "divide") else: op_type = 'elementwise_div' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1352,7 +1358,7 @@ def fmax(x, y, name=None): Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [5. , 3. , inf.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.fmax(x, y) else: return _elementwise_op(LayerHelper('elementwise_fmax', **locals())) @@ -1416,7 +1422,7 @@ def fmin(x, y, name=None): Tensor(shape=[3], dtype=float64, place=Place(cpu), stop_gradient=True, [ 1. , -inf., 5. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.fmin(x, y) else: return _elementwise_op(LayerHelper('elementwise_fmin', **locals())) @@ -1955,7 +1961,7 @@ def add_n(inputs, name=None): [14., 16., 18.]]) """ if in_dynamic_or_pir_mode(): - if isinstance(inputs, Variable): + if isinstance(inputs, (Variable, paddle.pir.OpResult)): inputs = [inputs] return _C_ops.add_n(inputs) else: @@ -2021,7 +2027,7 @@ def trunc(input, name=None): [[ 0., 1.], [-0., -2.]]) ''' - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.trunc(input) else: inputs = {"X": input} @@ -2666,7 +2672,7 @@ def inverse(x, name=None): [0. , 0.50000000]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.inverse(x) else: @@ -2937,7 +2943,7 @@ def min(x, axis=None, keepdim=False, name=None): [0., 0.]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.min(x, axis, keepdim) else: reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) @@ -3086,7 +3092,7 @@ def amax(x, axis=None, keepdim=False, name=None): [[0.50000000, 0.33333333], [0. , 0. ]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.amax(x, axis, keepdim) else: @@ -3234,7 +3240,7 @@ def amin(x, axis=None, keepdim=False, name=None): [[0.50000000, 0.33333333], [0. , 0. ]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.amin(x, axis, keepdim) else: @@ -3355,7 +3361,7 @@ def log2(x, name=None): Tensor(shape=[1], dtype=float64, place=Place(cpu), stop_gradient=True, [1.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.log2(x) else: check_variable_and_dtype( @@ -3429,7 +3435,7 @@ def log10(x, name=None): Tensor(shape=[1], dtype=float64, place=Place(cpu), stop_gradient=True, [1.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.log10(x) else: check_variable_and_dtype( @@ -3670,7 +3676,7 @@ def __check_input(x, offset, axis1, axis2): "But received axis1 = %d, axis2 = %d\n" % (axis1, axis2) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.trace(x, offset, axis1, axis2) else: __check_input(x, offset, axis1, axis2) @@ -3925,7 +3931,7 @@ def cumsum(x, axis=None, dtype=None, name=None): if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if axis is None: axis = -1 return _C_ops.cumsum(x, axis, flatten, False, False) @@ -4029,7 +4035,7 @@ def cummax(x, axis=None, dtype='int64', name=None): check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummax') dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cummax(x, axis, dtype) else: check_variable_and_dtype( @@ -4114,7 +4120,7 @@ def cummin(x, axis=None, dtype='int64', name=None): check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummin') dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cummin(x, axis, dtype) else: check_variable_and_dtype( @@ -4197,7 +4203,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None): if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if axis is None: axis = -1 return _C_ops.logcumsumexp(x, axis, flatten, False, False) @@ -4278,7 +4284,7 @@ def cumprod(x, dim=None, dtype=None, name=None): if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cumprod(x, dim) else: check_variable_and_dtype( @@ -4437,7 +4443,7 @@ def isnan(x, name=None): Tensor(shape=[7], dtype=bool, place=Place(cpu), stop_gradient=True, [False, False, False, False, False, True , True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.isnan(x) else: helper = LayerHelper("isnan_v2", **locals()) @@ -4538,7 +4544,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): x = cast(x, dtype) reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.prod(x, axis, keepdim, reduce_all) else: helper = LayerHelper('reduce_prod', **locals()) @@ -4565,7 +4571,7 @@ def sign(x, name=None): Returns sign of every element in `x`: 1 for positive, -1 for negative and 0 for zero. Args: - x (Tensor): The input tensor. The data type can be float16, float32 or float64. + x (Tensor): The input tensor. The data type can be int8, int16, int32, int64, float16, float32 or float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -4582,11 +4588,23 @@ def sign(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [ 1., 0., -1., 1.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.sign(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'sign' + x, + 'x', + [ + 'int8', + 'int16', + 'int32', + 'int64', + 'float16', + 'float32', + 'float64', + 'uint16', + ], + 'sign', ) helper = LayerHelper("sign", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -4669,7 +4687,7 @@ def increment(x, value=1.0, name=None): [1.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.increment_(x, value) else: check_variable_and_dtype( @@ -4833,7 +4851,7 @@ def any(x, axis=None, keepdim=False, name=None): [True]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.any(x, axis, keepdim) else: reduce_all, axis = _get_reduce_axis(axis, x) @@ -4976,7 +4994,7 @@ def digamma(x, name=None): [ nan , 5.32286835]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.digamma(x) else: check_variable_and_dtype( @@ -5197,7 +5215,7 @@ def logit(x, eps=None, name=None): """ if eps is None: eps = 0.0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logit(x, eps) else: check_variable_and_dtype( @@ -5332,7 +5350,7 @@ def erfinv(x, name=None): [ 0. , 0.47693631, -inf. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.erfinv(x) else: check_variable_and_dtype( @@ -5938,7 +5956,7 @@ def angle(x, name=None): [-1.10714877, -0.78539819, 0. , 0.78539819]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.angle(x) else: check_variable_and_dtype( @@ -6044,13 +6062,15 @@ def frac(x, name=None): paddle.int64, paddle.float32, paddle.float64, + DataType.INT32, + DataType.INT64, + DataType.FLOAT32, + DataType.FLOAT64, ]: raise TypeError( - "The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {}".format( - x.dtype - ) + f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}" ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): y = _C_ops.trunc(x) return _C_ops.subtract(x, y) else: @@ -6082,9 +6102,7 @@ def frac_(x, name=None): paddle.float64, ]: raise TypeError( - "The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {}".format( - x.dtype - ) + f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}" ) if in_dynamic_mode(): y = _C_ops.trunc(x) @@ -6131,9 +6149,7 @@ def sgn(x, name=None): paddle.complex128, ]: raise TypeError( - "The data type of input must be one of ['float16', 'float32', 'float64', 'complex64', 'complex128'], but got {}".format( - x.dtype - ) + f"The data type of input must be one of ['float16', 'float32', 'float64', 'complex64', 'complex128'], but got {x.dtype}" ) if paddle.is_complex(x): expand_x = paddle.as_real(x) @@ -6287,9 +6303,7 @@ def frexp(x, name=None): """ if x.dtype not in [paddle.float32, paddle.float64]: raise TypeError( - "The data type of input must be one of ['float32', 'float64'], but got {}".format( - x.dtype - ) + f"The data type of input must be one of ['float32', 'float64'], but got {x.dtype}" ) input_x = paddle.abs(x) exponent = paddle.floor(paddle.log2(input_x)) @@ -6341,9 +6355,7 @@ def _trapezoid(y, x=None, dx=None, axis=-1, mode='sum'): raise ValueError("Not permitted to specify both x and dx input args.") if y.dtype not in [paddle.float16, paddle.float32, paddle.float64]: raise TypeError( - "The data type of input must be Tensor, and dtype should be one of ['paddle.float16', 'paddle.float32', 'paddle.float64'], but got {}".format( - y.dtype - ) + f"The data type of input must be Tensor, and dtype should be one of ['paddle.float16', 'paddle.float32', 'paddle.float64'], but got {y.dtype}" ) y_shape = y.shape @@ -6359,9 +6371,7 @@ def _trapezoid(y, x=None, dx=None, axis=-1, mode='sum'): else: if x.dtype not in [paddle.float16, paddle.float32, paddle.float64]: raise TypeError( - "The data type of input must be Tensor, and dtype should be one of ['paddle.float16', 'paddle.float32', 'paddle.float64'], but got {}".format( - x.dtype - ) + f"The data type of input must be Tensor, and dtype should be one of ['paddle.float16', 'paddle.float32', 'paddle.float64'], but got {x.dtype}" ) # Reshape to correct shape if x.dim() == 1: @@ -6654,7 +6664,7 @@ def i0(x, name=None): Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, [0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.i0(x) else: check_variable_and_dtype(x, "x", ["float32", "float64"], "i0") @@ -6703,7 +6713,7 @@ def i0e(x, name=None): Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, [0.99999994, 0.46575963, 0.30850831, 0.24300036, 0.20700191]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.i0e(x) else: check_variable_and_dtype(x, "x", ["float32", "float64"], "i0e") @@ -6735,7 +6745,7 @@ def i1(x, name=None): Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, [0. , 0.56515908, 1.59063685, 3.95337057, 9.75946712]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.i1(x) else: check_variable_and_dtype(x, "x", ["float32", "float64"], "i1") @@ -6770,7 +6780,7 @@ def i1e(x, name=None): Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, [0. , 0.20791042, 0.21526928, 0.19682673, 0.17875087]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.i1e(x) else: check_variable_and_dtype(x, "x", ["float32", "float64"], "i1e") @@ -6823,7 +6833,7 @@ def polygamma(x, n, name=None): if n == 0: return digamma(x) else: - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.polygamma(x, n) else: check_variable_and_dtype( @@ -6932,3 +6942,56 @@ def ldexp_(x, y, name=None): y = paddle.cast(y, dtype=out_dtype) two = paddle.to_tensor(2, dtype=out_dtype) return paddle.multiply_(x, paddle.pow(two, y)) + + +def hypot(x, y, name=None): + """ + Calculate the length of the hypotenuse of a right-angle triangle. The equation is: + + .. math:: + out = {\\sqrt{x^2 + y^2}} + + Args: + x (Tensor): The input Tensor, the data type is float32, float64, int32 or int64. + y (Tensor): The input Tensor, the data type is float32, float64, int32 or int64. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. And the data type is float32 or float64. + + Examples: + + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([3], dtype='float32') + >>> y = paddle.to_tensor([4], dtype='float32') + >>> res = paddle.hypot(x, y) + >>> print(res) + Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + [5.]) + + """ + if not isinstance(x, (paddle.Tensor, Variable)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + if not isinstance(y, (paddle.Tensor, Variable)): + raise TypeError(f"y must be tensor type, but got {type(y)}") + + out = (paddle.pow(x, 2) + paddle.pow(y, 2)).sqrt() + return out + + +@inplace_apis_in_dygraph_only +def hypot_(x, y, name=None): + r""" + Inplace version of ``hypot`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_hypot`. + """ + if not isinstance(x, (paddle.Tensor, Variable)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + if not isinstance(y, (paddle.Tensor, Variable)): + raise TypeError(f"y must be tensor type, but got {type(y)}") + + out = x.pow_(2).add_(y.pow(2)).sqrt_() + return out diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py index 54505f952bdfc..ebff56db96dfa 100644 --- a/python/paddle/tensor/ops.py +++ b/python/paddle/tensor/ops.py @@ -242,7 +242,7 @@ def acos(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [1.98231316, 1.77215421, 1.47062886, 1.26610363]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.acos(x) else: check_variable_and_dtype( @@ -289,7 +289,7 @@ def acosh(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0. , 1.76274717, 2.06343699, 2.29243159]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.acosh(x) else: check_variable_and_dtype( @@ -336,7 +336,7 @@ def asin(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.41151685, -0.20135793, 0.10016742, 0.30469266]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.asin(x) else: check_variable_and_dtype( @@ -383,7 +383,7 @@ def asinh(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.39003533, -0.19869010, 0.09983408, 0.29567307]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.asinh(x) else: check_variable_and_dtype( @@ -430,7 +430,7 @@ def atan(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.38050640, -0.19739556, 0.09966865, 0.29145682]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.atan(x) else: check_variable_and_dtype( @@ -477,7 +477,7 @@ def atanh(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.42364895, -0.20273255, 0.10033534, 0.30951962]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.atanh(x) else: check_variable_and_dtype( @@ -564,7 +564,7 @@ def cos(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.92106098, 0.98006660, 0.99500418, 0.95533651]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cos(x) else: check_variable_and_dtype( @@ -754,7 +754,7 @@ def floor(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-1., -1., 0., 0.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.floor(x) else: check_variable_and_dtype( @@ -839,7 +839,7 @@ def round(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-1., -0., 1., 2.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.round(x) else: check_variable_and_dtype( @@ -916,7 +916,7 @@ def sigmoid(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.40131235, 0.45016602, 0.52497917, 0.57444251]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.sigmoid(x) else: check_variable_and_dtype( @@ -963,7 +963,7 @@ def sin(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.38941833, -0.19866933, 0.09983342, 0.29552022]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.sin(x) else: check_variable_and_dtype( @@ -1010,7 +1010,7 @@ def sinh(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.41075233, -0.20133601, 0.10016675, 0.30452031]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.sinh(x) else: check_variable_and_dtype( @@ -1097,7 +1097,7 @@ def square(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.16000001, 0.04000000, 0.01000000, 0.09000000]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.square(x) else: check_variable_and_dtype( @@ -1147,7 +1147,7 @@ def tan(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.42279324, -0.20271003, 0.10033467, 0.30933627]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.tan(x) else: check_variable_and_dtype( diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index f87e669cf198e..c9d4024904178 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -794,7 +794,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): shape = paddle.utils.convert_shape_to_list(shape) return _C_ops.uniform( shape, @@ -804,6 +804,29 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): seed, _current_expected_place(), ) + elif in_pir_mode(): + check_type( + shape, 'shape', (list, tuple, paddle.pir.OpResult), 'uniform/rand' + ) + check_dtype(dtype, 'dtype', supported_dtypes, 'uniform/rand') + check_type( + min, 'min', (float, int, paddle.pir.OpResult), 'uniform/rand' + ) + check_type( + max, 'max', (float, int, paddle.pir.OpResult), 'uniform/rand' + ) + if paddle.utils._contain_var(shape): + shape = paddle.utils.get_int_tensor_list( + shape, _current_expected_place() + ) + return _C_ops.uniform( + shape, + dtype, + float(min), + float(max), + seed, + _current_expected_place(), + ) else: check_type(shape, 'shape', (list, tuple, Variable), 'uniform/rand') check_dtype(dtype, 'dtype', supported_dtypes, 'uniform/rand') @@ -957,9 +980,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): if high is None: if low <= 0: raise ValueError( - "If high is None, low must be greater than 0, but received low = {}.".format( - low - ) + f"If high is None, low must be greater than 0, but received low = {low}." ) high = low low = 0 @@ -1145,9 +1166,7 @@ def randint_like(x, low=0, high=None, dtype=None, name=None): if high is None: if low <= 0: raise ValueError( - "If high is None, low must be greater than 0, but received low = {}.".format( - low - ) + f"If high is None, low must be greater than 0, but received low = {low}." ) high = low low = 0 diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index c33bd0cd4f415..51f09119ef2e4 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -59,42 +59,45 @@ def argsort(x, axis=-1, descending=False, name=None): .. code-block:: python - import paddle - - x = paddle.to_tensor([[[5,8,9,5], - [0,0,1,7], - [6,9,2,4]], - [[5,2,4,2], - [4,7,7,9], - [1,7,0,6]]], - dtype='float32') - out1 = paddle.argsort(x, axis=-1) - out2 = paddle.argsort(x, axis=0) - out3 = paddle.argsort(x, axis=1) - - print(out1) - #[[[0 3 1 2] - # [0 1 2 3] - # [2 3 0 1]] - # [[1 3 2 0] - # [0 1 2 3] - # [2 0 3 1]]] - - print(out2) - #[[[0 1 1 1] - # [0 0 0 0] - # [1 1 1 0]] - # [[1 0 0 0] - # [1 1 1 1] - # [0 0 0 1]]] - - print(out3) - #[[[1 1 1 2] - # [0 0 2 0] - # [2 2 0 1]] - # [[2 0 2 0] - # [1 1 0 2] - # [0 2 1 1]]] + >>> import paddle + + >>> x = paddle.to_tensor([[[5,8,9,5], + ... [0,0,1,7], + ... [6,9,2,4]], + ... [[5,2,4,2], + ... [4,7,7,9], + ... [1,7,0,6]]], + ... dtype='float32') + >>> out1 = paddle.argsort(x, axis=-1) + >>> out2 = paddle.argsort(x, axis=0) + >>> out3 = paddle.argsort(x, axis=1) + + >>> print(out1) + Tensor(shape=[2, 3, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[0, 3, 1, 2], + [0, 1, 2, 3], + [2, 3, 0, 1]], + [[1, 3, 2, 0], + [0, 1, 2, 3], + [2, 0, 3, 1]]]) + + >>> print(out2) + Tensor(shape=[2, 3, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[0, 1, 1, 1], + [0, 0, 0, 0], + [1, 1, 1, 0]], + [[1, 0, 0, 0], + [1, 1, 1, 1], + [0, 0, 0, 1]]]) + + >>> print(out3) + Tensor(shape=[2, 3, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[1, 1, 1, 2], + [0, 0, 2, 0], + [2, 2, 0, 1]], + [[2, 0, 2, 0], + [1, 1, 0, 2], + [0, 2, 1, 1]]]) """ if in_dynamic_mode(): _, ids = _C_ops.argsort(x, axis, descending) @@ -154,22 +157,23 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([[5,8,9,5], - [0,0,1,7], - [6,9,2,4]]) - out1 = paddle.argmax(x) - print(out1) # 2 - out2 = paddle.argmax(x, axis=0) - print(out2) - # [2, 2, 0, 1] - out3 = paddle.argmax(x, axis=-1) - print(out3) - # [2, 3, 1] - out4 = paddle.argmax(x, axis=0, keepdim=True) - print(out4) - # [[2, 2, 0, 1]] + >>> import paddle + + >>> x = paddle.to_tensor([[5,8,9,5], + ... [0,0,1,7], + ... [6,9,2,4]]) + >>> out1 = paddle.argmax(x) + >>> print(out1.numpy()) + 2 + >>> out2 = paddle.argmax(x, axis=0) + >>> print(out2.numpy()) + [2 2 0 1] + >>> out3 = paddle.argmax(x, axis=-1) + >>> print(out3.numpy()) + [2 3 1] + >>> out4 = paddle.argmax(x, axis=0, keepdim=True) + >>> print(out4.numpy()) + [[2 2 0 1]] """ if axis is not None and not isinstance( axis, (int, Variable, paddle.pir.OpResult) @@ -246,22 +250,23 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([[5,8,9,5], - [0,0,1,7], - [6,9,2,4]]) - out1 = paddle.argmin(x) - print(out1) # 4 - out2 = paddle.argmin(x, axis=0) - print(out2) - # [1, 1, 1, 2] - out3 = paddle.argmin(x, axis=-1) - print(out3) - # [0, 0, 2] - out4 = paddle.argmin(x, axis=0, keepdim=True) - print(out4) - # [[1, 1, 1, 2]] + >>> import paddle + + >>> x = paddle.to_tensor([[5,8,9,5], + ... [0,0,1,7], + ... [6,9,2,4]]) + >>> out1 = paddle.argmin(x) + >>> print(out1.numpy()) + 4 + >>> out2 = paddle.argmin(x, axis=0) + >>> print(out2.numpy()) + [1 1 1 2] + >>> out3 = paddle.argmin(x, axis=-1) + >>> print(out3.numpy()) + [0 0 2] + >>> out4 = paddle.argmin(x, axis=0, keepdim=True) + >>> print(out4.numpy()) + [[1 1 1 2]] """ if axis is not None and not isinstance( axis, (int, Variable, paddle.pir.OpResult) @@ -335,20 +340,22 @@ def index_select(x, index, axis=0, name=None): Examples: .. code-block:: python - import paddle - - x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0]]) - index = paddle.to_tensor([0, 1, 1], dtype='int32') - out_z1 = paddle.index_select(x=x, index=index) - #[[1. 2. 3. 4.] - # [5. 6. 7. 8.] - # [5. 6. 7. 8.]] - out_z2 = paddle.index_select(x=x, index=index, axis=1) - #[[ 1. 2. 2.] - # [ 5. 6. 6.] - # [ 9. 10. 10.]] + >>> import paddle + + >>> x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], + ... [5.0, 6.0, 7.0, 8.0], + ... [9.0, 10.0, 11.0, 12.0]]) + >>> index = paddle.to_tensor([0, 1, 1], dtype='int32') + >>> out_z1 = paddle.index_select(x=x, index=index) + >>> print(out_z1.numpy()) + [[1. 2. 3. 4.] + [5. 6. 7. 8.] + [5. 6. 7. 8.]] + >>> out_z2 = paddle.index_select(x=x, index=index, axis=1) + >>> print(out_z2.numpy()) + [[ 1. 2. 2.] + [ 5. 6. 6.] + [ 9. 10. 10.]] """ if in_dynamic_mode(): @@ -409,42 +416,50 @@ def nonzero(x, as_tuple=False): .. code-block:: python - import paddle - - x1 = paddle.to_tensor([[1.0, 0.0, 0.0], - [0.0, 2.0, 0.0], - [0.0, 0.0, 3.0]]) - x2 = paddle.to_tensor([0.0, 1.0, 0.0, 3.0]) - out_z1 = paddle.nonzero(x1) - print(out_z1) - #[[0 0] - # [1 1] - # [2 2]] - out_z1_tuple = paddle.nonzero(x1, as_tuple=True) - for out in out_z1_tuple: - print(out) - #[[0] - # [1] - # [2]] - #[[0] - # [1] - # [2]] - out_z2 = paddle.nonzero(x2) - print(out_z2) - #[[1] - # [3]] - out_z2_tuple = paddle.nonzero(x2, as_tuple=True) - for out in out_z2_tuple: - print(out) - #[[1] - # [3]] + >>> import paddle + + >>> x1 = paddle.to_tensor([[1.0, 0.0, 0.0], + ... [0.0, 2.0, 0.0], + ... [0.0, 0.0, 3.0]]) + >>> x2 = paddle.to_tensor([0.0, 1.0, 0.0, 3.0]) + >>> out_z1 = paddle.nonzero(x1) + >>> print(out_z1) + Tensor(shape=[3, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0], + [1, 1], + [2, 2]]) + + >>> out_z1_tuple = paddle.nonzero(x1, as_tuple=True) + >>> for out in out_z1_tuple: + ... print(out) + Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0], + [1], + [2]]) + Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0], + [1], + [2]]) + + >>> out_z2 = paddle.nonzero(x2) + >>> print(out_z2) + Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1], + [3]]) + + >>> out_z2_tuple = paddle.nonzero(x2, as_tuple=True) + >>> for out in out_z2_tuple: + ... print(out) + Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1], + [3]]) """ list_out = [] shape = x.shape rank = len(shape) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): outs = _C_ops.nonzero(x) else: check_variable_and_dtype( @@ -507,39 +522,39 @@ def sort(x, axis=-1, descending=False, name=None): .. code-block:: python - import paddle - - x = paddle.to_tensor([[[5,8,9,5], - [0,0,1,7], - [6,9,2,4]], - [[5,2,4,2], - [4,7,7,9], - [1,7,0,6]]], - dtype='float32') - out1 = paddle.sort(x=x, axis=-1) - out2 = paddle.sort(x=x, axis=0) - out3 = paddle.sort(x=x, axis=1) - print(out1) - #[[[5. 5. 8. 9.] - # [0. 0. 1. 7.] - # [2. 4. 6. 9.]] - # [[2. 2. 4. 5.] - # [4. 7. 7. 9.] - # [0. 1. 6. 7.]]] - print(out2) - #[[[5. 2. 4. 2.] - # [0. 0. 1. 7.] - # [1. 7. 0. 4.]] - # [[5. 8. 9. 5.] - # [4. 7. 7. 9.] - # [6. 9. 2. 6.]]] - print(out3) - #[[[0. 0. 1. 4.] - # [5. 8. 2. 5.] - # [6. 9. 9. 7.]] - # [[1. 2. 0. 2.] - # [4. 7. 4. 6.] - # [5. 7. 7. 9.]]] + >>> import paddle + + >>> x = paddle.to_tensor([[[5,8,9,5], + ... [0,0,1,7], + ... [6,9,2,4]], + ... [[5,2,4,2], + ... [4,7,7,9], + ... [1,7,0,6]]], + ... dtype='float32') + >>> out1 = paddle.sort(x=x, axis=-1) + >>> out2 = paddle.sort(x=x, axis=0) + >>> out3 = paddle.sort(x=x, axis=1) + >>> print(out1.numpy()) + [[[5. 5. 8. 9.] + [0. 0. 1. 7.] + [2. 4. 6. 9.]] + [[2. 2. 4. 5.] + [4. 7. 7. 9.] + [0. 1. 6. 7.]]] + >>> print(out2.numpy()) + [[[5. 2. 4. 2.] + [0. 0. 1. 7.] + [1. 7. 0. 4.]] + [[5. 8. 9. 5.] + [4. 7. 7. 9.] + [6. 9. 2. 6.]]] + >>> print(out3.numpy()) + [[[0. 0. 1. 4.] + [5. 8. 2. 5.] + [6. 9. 9. 7.]] + [[1. 2. 0. 2.] + [4. 7. 4. 6.] + [5. 7. 7. 9.]]] """ if in_dynamic_or_pir_mode(): outs, _ = _C_ops.argsort(x, axis, descending) @@ -580,16 +595,16 @@ def mode(x, axis=-1, keepdim=False, name=None): .. code-block:: python - import paddle + >>> import paddle - tensor = paddle.to_tensor([[[1,2,2],[2,3,3]],[[0,5,5],[9,9,0]]], dtype=paddle.float32) - res = paddle.mode(tensor, 2) - print(res) - # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [[2., 3.], - # [5., 9.]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[2, 2], - # [2, 1]])) + >>> tensor = paddle.to_tensor([[[1,2,2],[2,3,3]],[[0,5,5],[9,9,0]]], dtype=paddle.float32) + >>> res = paddle.mode(tensor, 2) + >>> print(res) + (Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[2., 3.], + [5., 9.]]), Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2, 2], + [2, 1]])) """ if in_dynamic_mode(): @@ -642,20 +657,21 @@ def where(condition, x=None, y=None, name=None): .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2]) - y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0]) + >>> x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2]) + >>> y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0]) - out = paddle.where(x>1, x, y) - print(out) - #out: [1.0, 1.0, 3.2, 1.2] + >>> out = paddle.where(x>1, x, y) + >>> print(out) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [1. , 1. , 3.20000005, 1.20000005]) - out = paddle.where(x>1) - print(out) - #out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True, - # [[2], - # [3]]),) + >>> out = paddle.where(x>1) + >>> print(out) + (Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2], + [3]]),) """ if np.isscalar(x): x = paddle.full([1], x, np.array([x]).dtype.name) @@ -796,41 +812,41 @@ def index_sample(x, index): .. code-block:: python - import paddle - - x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0]], dtype='float32') - index = paddle.to_tensor([[0, 1, 2], - [1, 2, 3], - [0, 0, 0]], dtype='int32') - target = paddle.to_tensor([[100, 200, 300, 400], - [500, 600, 700, 800], - [900, 1000, 1100, 1200]], dtype='int32') - out_z1 = paddle.index_sample(x, index) - print(out_z1) - #[[1. 2. 3.] - # [6. 7. 8.] - # [9. 9. 9.]] - - # Use the index of the maximum value by topk op - # get the value of the element of the corresponding index in other tensors - top_value, top_index = paddle.topk(x, k=2) - out_z2 = paddle.index_sample(target, top_index) - print(top_value) - #[[ 4. 3.] - # [ 8. 7.] - # [12. 11.]] - - print(top_index) - #[[3 2] - # [3 2] - # [3 2]] - - print(out_z2) - #[[ 400 300] - # [ 800 700] - # [1200 1100]] + >>> import paddle + + >>> x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], + ... [5.0, 6.0, 7.0, 8.0], + ... [9.0, 10.0, 11.0, 12.0]], dtype='float32') + >>> index = paddle.to_tensor([[0, 1, 2], + ... [1, 2, 3], + ... [0, 0, 0]], dtype='int32') + >>> target = paddle.to_tensor([[100, 200, 300, 400], + ... [500, 600, 700, 800], + ... [900, 1000, 1100, 1200]], dtype='int32') + >>> out_z1 = paddle.index_sample(x, index) + >>> print(out_z1.numpy()) + [[1. 2. 3.] + [6. 7. 8.] + [9. 9. 9.]] + + >>> # Use the index of the maximum value by topk op + >>> # get the value of the element of the corresponding index in other tensors + >>> top_value, top_index = paddle.topk(x, k=2) + >>> out_z2 = paddle.index_sample(target, top_index) + >>> print(top_value.numpy()) + [[ 4. 3.] + [ 8. 7.] + [12. 11.]] + + >>> print(top_index.numpy()) + [[3 2] + [3 2] + [3 2]] + + >>> print(out_z2.numpy()) + [[ 400 300] + [ 800 700] + [1200 1100]] """ if in_dynamic_mode(): @@ -885,16 +901,17 @@ def masked_select(x, mask, name=None): .. code-block:: python - import paddle - - x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0]]) - mask = paddle.to_tensor([[True, False, False, False], - [True, True, False, False], - [True, False, False, False]]) - out = paddle.masked_select(x, mask) - #[1.0 5.0 6.0 9.0] + >>> import paddle + + >>> x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], + ... [5.0, 6.0, 7.0, 8.0], + ... [9.0, 10.0, 11.0, 12.0]]) + >>> mask = paddle.to_tensor([[True, False, False, False], + ... [True, True, False, False], + ... [True, False, False, False]]) + >>> out = paddle.masked_select(x, mask) + >>> print(out.numpy()) + [1. 5. 6. 9.] """ if in_dynamic_mode(): @@ -945,30 +962,50 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None): .. code-block:: python - import paddle - - data_1 = paddle.to_tensor([1, 4, 5, 7]) - value_1, indices_1 = paddle.topk(data_1, k=1) - print(value_1) # [7] - print(indices_1) # [3] - - data_2 = paddle.to_tensor([[1, 4, 5, 7], [2, 6, 2, 5]]) - value_2, indices_2 = paddle.topk(data_2, k=1) - print(value_2) # [[7], [6]] - print(indices_2) # [[3], [1]] - - value_3, indices_3 = paddle.topk(data_2, k=1, axis=-1) - print(value_3) # [[7], [6]] - print(indices_3) # [[3], [1]] - - value_4, indices_4 = paddle.topk(data_2, k=1, axis=0) - print(value_4) # [[2, 6, 5, 7]] - print(indices_4) # [[1, 1, 0, 0]] + >>> import paddle + + >>> data_1 = paddle.to_tensor([1, 4, 5, 7]) + >>> value_1, indices_1 = paddle.topk(data_1, k=1) + >>> print(value_1) + Tensor(shape=[1], dtype=int64, place=Place(cpu), stop_gradient=True, + [7]) + >>> print(indices_1) + Tensor(shape=[1], dtype=int64, place=Place(cpu), stop_gradient=True, + [3]) + + >>> data_2 = paddle.to_tensor([[1, 4, 5, 7], [2, 6, 2, 5]]) + >>> value_2, indices_2 = paddle.topk(data_2, k=1) + >>> print(value_2) + Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[7], + [6]]) + >>> print(indices_2) + Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[3], + [1]]) + + >>> value_3, indices_3 = paddle.topk(data_2, k=1, axis=-1) + >>> print(value_3) + Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[7], + [6]]) + >>> print(indices_3) + Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[3], + [1]]) + + >>> value_4, indices_4 = paddle.topk(data_2, k=1, axis=0) + >>> print(value_4) + Tensor(shape=[1, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2, 6, 5, 7]]) + >>> print(indices_4) + Tensor(shape=[1, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 1, 0, 0]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if axis is None: axis = -1 out, indices = _C_ops.topk(x, k, axis, largest, sorted) @@ -1018,30 +1055,30 @@ def bucketize(x, sorted_sequence, out_int32=False, right=False, name=None): .. code-block:: python - import paddle - - sorted_sequence = paddle.to_tensor([2, 4, 8, 16], dtype='int32') - x = paddle.to_tensor([[0, 8, 4, 16], [-1, 2, 8, 4]], dtype='int32') - out1 = paddle.bucketize(x, sorted_sequence) - print(out1) - # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True, - # [[0, 2, 1, 3], - # [0, 0, 2, 1]]) - out2 = paddle.bucketize(x, sorted_sequence, right=True) - print(out2) - # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True, - # [[0, 3, 2, 4], - # [0, 1, 3, 2]]) - out3 = x.bucketize(sorted_sequence) - print(out3) - # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True, - # [[0, 2, 1, 3], - # [0, 0, 2, 1]]) - out4 = x.bucketize(sorted_sequence, right=True) - print(out4) - # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True, - # [[0, 3, 2, 4], - # [0, 1, 3, 2]]) + >>> import paddle + + >>> sorted_sequence = paddle.to_tensor([2, 4, 8, 16], dtype='int32') + >>> x = paddle.to_tensor([[0, 8, 4, 16], [-1, 2, 8, 4]], dtype='int32') + >>> out1 = paddle.bucketize(x, sorted_sequence) + >>> print(out1) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 2, 1, 3], + [0, 0, 2, 1]]) + >>> out2 = paddle.bucketize(x, sorted_sequence, right=True) + >>> print(out2) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 3, 2, 4], + [0, 1, 3, 2]]) + >>> out3 = x.bucketize(sorted_sequence) + >>> print(out3) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 2, 1, 3], + [0, 0, 2, 1]]) + >>> out4 = x.bucketize(sorted_sequence, right=True) + >>> print(out4) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 3, 2, 4], + [0, 1, 3, 2]]) """ check_variable_and_dtype( @@ -1078,27 +1115,27 @@ def searchsorted( .. code-block:: python - import paddle - - sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11], - [2, 4, 6, 8, 10, 12]], dtype='int32') - values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32') - out1 = paddle.searchsorted(sorted_sequence, values) - print(out1) - # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4, 5], - # [1, 2, 4, 4]]) - out2 = paddle.searchsorted(sorted_sequence, values, right=True) - print(out2) - # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[2, 3, 5, 5], - # [1, 3, 4, 5]]) - sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13]) - out3 = paddle.searchsorted(sorted_sequence_1d, values) - print(out3) - # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4, 5], - # [1, 3, 4, 5]]) + >>> import paddle + + >>> sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11], + ... [2, 4, 6, 8, 10, 12]], dtype='int32') + >>> values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32') + >>> out1 = paddle.searchsorted(sorted_sequence, values) + >>> print(out1) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 3, 4, 5], + [1, 2, 4, 4]]) + >>> out2 = paddle.searchsorted(sorted_sequence, values, right=True) + >>> print(out2) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[2, 3, 5, 5], + [1, 3, 4, 5]]) + >>> sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13]) + >>> out3 = paddle.searchsorted(sorted_sequence_1d, values) + >>> print(out3) + Tensor(shape=[2, 4], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 3, 4, 5], + [1, 3, 4, 5]]) """ if in_dynamic_mode(): @@ -1150,23 +1187,28 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None): .. code-block:: python - import paddle - - x = paddle.randn((2,3,2)) - # Tensor(shape=[2, 3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [[[ 0.22954939, -0.01296274], - # [ 1.17135799, -0.34493217], - # [-0.19550551, -0.17573971]], - # - # [[ 0.15104349, -0.93965352], - # [ 0.14745511, 0.98209465], - # [ 0.10732264, -0.55859774]]]) - y = paddle.kthvalue(x, 2, 1) - # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [[ 0.22954939, -0.17573971], - # [ 0.14745511, -0.55859774]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[0, 2], - # [1, 2]])) + >>> import paddle + + >>> x = paddle.randn((2,3,2)) + >>> print(x) + >>> # doctest: +SKIP('Different environments yield different output.') + Tensor(shape=[2, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[ 0.11855337, -0.30557564], + [-0.09968963, 0.41220093], + [ 1.24004936, 1.50014710]], + [[ 0.08612321, -0.92485696], + [-0.09276631, 1.15149164], + [-1.46587241, 1.22873247]]]) + >>> # doctest: -SKIP + >>> y = paddle.kthvalue(x, 2, 1) + >>> print(y) + >>> # doctest: +SKIP('Different environments yield different output.') + (Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[ 0.11855337, 0.41220093], + [-0.09276631, 1.15149164]]), Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 1], + [1, 1]])) + >>> # doctest: -SKIP """ if in_dynamic_mode(): if axis is not None: diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index a08dcd6ddb259..d7bcc48c8fa45 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -159,8 +159,8 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None): out = paddle.sum(paddle.pow((x - u), 2), axis, keepdim=keepdim, name=name) dtype = x.dtype - n = paddle.cast(paddle.numel(x), paddle.int64) / paddle.cast( - paddle.numel(out), paddle.int64 + n = paddle.cast(paddle.numel(x), "int64") / paddle.cast( + paddle.numel(out), "int64" ) n = n.astype(dtype) if unbiased: @@ -221,7 +221,7 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None): [1. 2.081666] """ - if not in_dynamic_mode(): + if not in_dynamic_or_pir_mode(): check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64'], 'std' ) diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py index 6e173545a2767..97b8268fb6fe5 100644 --- a/python/paddle/tensor/to_string.py +++ b/python/paddle/tensor/to_string.py @@ -126,15 +126,11 @@ def _format_item(np_var, max_width=0, signed=False): or np_var.dtype == np.float16 ): if DEFAULT_PRINT_OPTIONS.sci_mode: - item_str = f'{{:.{DEFAULT_PRINT_OPTIONS.precision}e}}'.format( - np_var - ) + item_str = f'{np_var:.{DEFAULT_PRINT_OPTIONS.precision}e}' elif np.ceil(np_var) == np_var: item_str = f'{np_var:.0f}.' else: - item_str = f'{{:.{DEFAULT_PRINT_OPTIONS.precision}f}}'.format( - np_var - ) + item_str = f'{np_var:.{DEFAULT_PRINT_OPTIONS.precision}f}' else: item_str = f'{np_var}' diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index 1c58242e877ce..18697fdc25bfe 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -37,7 +37,7 @@ from .layers_utils import padding_to_same_structure # noqa: F401 from .layers_utils import assert_same_structure # noqa: F401 from .layers_utils import get_shape_tensor_inputs # noqa: F401 -from .layers_utils import get_pir_shape_tensor # noqa: F401 +from .layers_utils import get_int_tensor_list # noqa: F401 from .layers_utils import convert_shape_to_list # noqa: F401 from .layers_utils import check_shape # noqa: F401 from .layers_utils import try_set_static_shape_tensor # noqa: F401 diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index cb50f73d8d9b5..471b17c6d38fd 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -943,9 +943,7 @@ def parse_op_info(op_name): """ if op_name not in OpProtoHolder.instance().op_proto_map: raise ValueError( - "Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`".format( - op_name - ) + f"Please load {op_name} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`" ) op_proto = OpProtoHolder.instance().get_op_proto(op_name) diff --git a/python/paddle/utils/environments.py b/python/paddle/utils/environments.py new file mode 100644 index 0000000000000..84e7c293eafc6 --- /dev/null +++ b/python/paddle/utils/environments.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class EnvironmentVariable(Generic[T]): + name: str + default: T + + def __init__(self, name: str, default: T): + self.name = name + self.default = default + + def get(self) -> T: + raise NotImplementedError() + + def set(self, value: T) -> None: + raise NotImplementedError() + + def delete(self) -> None: + del os.environ[self.name] + + +class StringEnvironmentVariable(EnvironmentVariable[str]): + def __init__(self, name: str, default: str): + super().__init__(name, default) + assert isinstance(default, str), "default must be a string" + + def get(self) -> str: + return os.getenv(self.name, self.default) + + def set(self, value: str) -> None: + assert isinstance(value, str), "value must be a string" + os.environ[self.name] = value + + +class BooleanEnvironmentVariable(EnvironmentVariable[bool]): + BOOLEAN_IS_SET = ("y", "yes", "t", "true", "on", "1") + + def __init__(self, name: str, default: bool): + super().__init__(name, default) + assert isinstance(default, bool), "default must be a boolean" + + def get(self) -> bool: + default = str(self.default).lower() + env_str = os.getenv(self.name, default).lower() + return env_str in BooleanEnvironmentVariable.BOOLEAN_IS_SET + + def set(self, value: bool) -> None: + assert isinstance(value, bool), "value must be a boolean" + os.environ[self.name] = str(value).lower() + + +class IntegerEnvironmentVariable(EnvironmentVariable[int]): + def __init__(self, name: str, default: int): + super().__init__(name, default) + assert isinstance(default, int) and not isinstance( + default, bool + ), "default must be an integer" + + def get(self) -> int: + try: + return int(os.getenv(self.name, str(self.default))) + except ValueError: + return self.default + + def set(self, value: int) -> None: + assert isinstance(value, int) and not isinstance( + value, bool + ), "value must be an integer" + os.environ[self.name] = str(value) + + +class EnvironmentVariableGuard(Generic[T]): + variable: EnvironmentVariable[T] + original_value: T + + def __init__(self, variable: EnvironmentVariable[T], value: T): + self.variable = variable + self.original_value = variable.get() + self.variable.set(value) + + def __enter__(self) -> EnvironmentVariableGuard: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.variable.set(self.original_value) diff --git a/python/paddle/utils/install_check.py b/python/paddle/utils/install_check.py index 4974eddbfa26c..b444b71834233 100644 --- a/python/paddle/utils/install_check.py +++ b/python/paddle/utils/install_check.py @@ -287,13 +287,11 @@ def run_check(): ) except Exception as e: logging.warning( - "PaddlePaddle meets some problem with {} {}s. This may be caused by:" + f"PaddlePaddle meets some problem with {device_count} {device_str}s. This may be caused by:" "\n 1. There is not enough GPUs visible on your system" "\n 2. Some GPUs are occupied by other process now" "\n 3. NVIDIA-NCCL2 is not installed correctly on your system. Please follow instruction on https://github.com/NVIDIA/nccl-tests " - "\n to test your NCCL, or reinstall it following https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html".format( - device_count, device_str - ) + "\n to test your NCCL, or reinstall it following https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html" ) logging.warning(f"\n Original Error is: {e}") diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index 37f32640b9b52..8a02e40c6aca9 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -18,14 +18,16 @@ from uuid import uuid4 from weakref import WeakKeyDictionary +import numpy as np + import paddle +from paddle.pir.core import convert_np_dtype_to_dtype_ from ..base.data_feeder import check_dtype, convert_dtype from ..base.framework import ( Block, Variable, _current_expected_place, - core, in_dygraph_mode, ) @@ -74,7 +76,9 @@ def convert_to_list(value, n, name, dtype=int): + str(value) ) for single_value in value_list: - assert not isinstance(single_value, Variable), ( + assert not isinstance( + single_value, (Variable, paddle.pir.OpResult) + ), ( "Required numerical type with '%s', but received Tensor." % dtype ) @@ -384,18 +388,28 @@ def _contain_var(list_or_tuple): return False -def get_pir_shape_tensor(list_shape, place=_current_expected_place()): - shape_tensor_list = [] - for dim in list_shape: - if isinstance(dim, paddle.pir.OpResult): - dim.stop_gradient = True - if convert_dtype(dim.dtype) != 'int32': - dim = paddle.cast(x=dim, dtype='int32') - shape_tensor_list.append(dim) +def get_int_tensor_list(ele_list, place=None, default_dtype='int64'): + if place is None: + place = _current_expected_place() + + int_tensor_list = [] + for ele in ele_list: + if isinstance(ele, paddle.pir.OpResult): + ele.stop_gradient = True + if convert_dtype(ele.dtype) != default_dtype: + ele = paddle.cast(x=ele, dtype=default_dtype) + if ele.shape == []: + ele = paddle.reshape(ele, [-1]) + int_tensor_list.append(ele) else: - temp_out = paddle.full([1], dim, core.DataType.INT32, place) - shape_tensor_list.append(temp_out) - return shape_tensor_list + temp_out = paddle.full( + [1], + ele, + convert_np_dtype_to_dtype_(np.dtype(default_dtype)), + place, + ) + int_tensor_list.append(temp_out) + return int_tensor_list def get_shape_tensor_inputs(inputs, attrs, shape, op_type): diff --git a/python/paddle/vision/models/alexnet.py b/python/paddle/vision/models/alexnet.py index 26282bf88f7a1..48af3b6b30d9b 100644 --- a/python/paddle/vision/models/alexnet.py +++ b/python/paddle/vision/models/alexnet.py @@ -175,9 +175,7 @@ def _alexnet(arch, pretrained, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/densenet.py b/python/paddle/vision/models/densenet.py index 5da45f5040c79..b181389a387c5 100644 --- a/python/paddle/vision/models/densenet.py +++ b/python/paddle/vision/models/densenet.py @@ -339,9 +339,7 @@ def _densenet(arch, layers, pretrained, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/googlenet.py b/python/paddle/vision/models/googlenet.py index e5945bbcb0589..e9dbbbc240a99 100644 --- a/python/paddle/vision/models/googlenet.py +++ b/python/paddle/vision/models/googlenet.py @@ -265,9 +265,7 @@ def googlenet(pretrained=False, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/inceptionv3.py b/python/paddle/vision/models/inceptionv3.py index 88eb4b1d90ff6..5281e875d48cd 100644 --- a/python/paddle/vision/models/inceptionv3.py +++ b/python/paddle/vision/models/inceptionv3.py @@ -620,9 +620,7 @@ def inception_v3(pretrained=False, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/mobilenetv1.py b/python/paddle/vision/models/mobilenetv1.py index f431d6a779551..907ef4b0a09ff 100644 --- a/python/paddle/vision/models/mobilenetv1.py +++ b/python/paddle/vision/models/mobilenetv1.py @@ -249,9 +249,7 @@ def _mobilenet(arch, pretrained=False, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/mobilenetv2.py b/python/paddle/vision/models/mobilenetv2.py index 235714c433f3e..60914b48f008f 100644 --- a/python/paddle/vision/models/mobilenetv2.py +++ b/python/paddle/vision/models/mobilenetv2.py @@ -188,9 +188,7 @@ def _mobilenet(arch, pretrained=False, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/mobilenetv3.py b/python/paddle/vision/models/mobilenetv3.py index eea53f3b43c85..c8a4184385d11 100644 --- a/python/paddle/vision/models/mobilenetv3.py +++ b/python/paddle/vision/models/mobilenetv3.py @@ -409,9 +409,7 @@ def _mobilenet_v3(arch, pretrained=False, scale=1.0, **kwargs): arch = f"{arch}_x{scale}" assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/resnet.py b/python/paddle/vision/models/resnet.py index 390f9c2a12d30..c801814b3e86c 100644 --- a/python/paddle/vision/models/resnet.py +++ b/python/paddle/vision/models/resnet.py @@ -352,9 +352,7 @@ def _resnet(arch, Block, depth, pretrained, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/shufflenetv2.py b/python/paddle/vision/models/shufflenetv2.py index ab78764cc274a..e68f0c67439ef 100644 --- a/python/paddle/vision/models/shufflenetv2.py +++ b/python/paddle/vision/models/shufflenetv2.py @@ -318,9 +318,7 @@ def _shufflenet_v2(arch, pretrained=False, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/squeezenet.py b/python/paddle/vision/models/squeezenet.py index 7d120d5df5449..0c9d90fe7e02d 100644 --- a/python/paddle/vision/models/squeezenet.py +++ b/python/paddle/vision/models/squeezenet.py @@ -203,9 +203,7 @@ def _squeezenet(arch, version, pretrained, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index ceec4b8bf7728..8f0d8f475ee25 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -183,9 +183,7 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): if pretrained: assert ( arch in model_urls - ), "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch - ) + ), f"{arch} model do not have a pretrained model now, you should set pretrained=False" weight_path = get_weights_path_from_url( model_urls[arch][0], model_urls[arch][1] ) diff --git a/python/setup.py.in b/python/setup.py.in index 10cbd7d54a86d..d904877af43ce 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -429,7 +429,7 @@ packages=['paddle', 'paddle.framework', 'paddle.jit', 'paddle.jit.dy2static', - 'paddle.jit.newir_dy2static', + 'paddle.jit.pir_dy2static', 'paddle.jit.sot', 'paddle.jit.sot.opcode_translator', 'paddle.jit.sot.opcode_translator.executor', diff --git a/setup.py b/setup.py index e12d676cb8a5f..7cd1d0247e54c 100644 --- a/setup.py +++ b/setup.py @@ -41,8 +41,8 @@ ) else: if os.getenv("PY_VERSION") is None: - print("export PY_VERSION = %s" % platform.python_version()) python_version = platform.python_version() + print(f"export PY_VERSION = {python_version}") os.environ["PY_VERSION"] = python_version else: if os.getenv("PY_VERSION") != str(sys.version_info.major) + '.' + str( @@ -137,10 +137,8 @@ def get_header_install_dir(header): install_dir = re.sub( env_dict.get("PADDLE_SOURCE_DIR") + '/', '', header ) - print('install_dir: ', install_dir) if 'fluid/jit' in install_dir: install_dir = re.sub('fluid/jit', 'jit', install_dir) - print('fluid/jit install_dir: ', install_dir) else: # third_party install_dir = re.sub( @@ -206,12 +204,10 @@ class InstallCommand(InstallCommandBase): def finalize_options(self): ret = InstallCommandBase.finalize_options(self) self.install_lib = self.install_platlib - print("install_lib:", self.install_platlib) self.install_headers = os.path.join( self.install_platlib, 'paddle', 'include' ) - print("install_headers:", self.install_headers) return ret @@ -776,10 +772,8 @@ def cmake_run(build_path): or option_key == 'PYTHON_LIBRARIES' ): key = option_key + ":FILEPATH" - print(key) elif option_key == 'PYTHON_INCLUDE_DIR': key = option_key + ':PATH' - print(key) elif option_key == 'GENERATOR': key = 'CMAKE_' + option_key else: @@ -788,14 +782,12 @@ def cmake_run(build_path): paddle_build_options[key] = option_value options_process(args, paddle_build_options) - print("args:", args) with cd(build_path): cmake_args = [] cmake_args.append(CMAKE) cmake_args += args cmake_args.append('-DWITH_SETUP_INSTALL=ON') cmake_args.append(TOP_DIR) - print("cmake_args:", cmake_args) subprocess.check_call(cmake_args) @@ -804,7 +796,6 @@ def build_run(args, build_path, envrion_var): build_args = [] build_args.append(CMAKE) build_args += args - print(" ".join(build_args)) try: subprocess.check_call(build_args, cwd=build_path, env=envrion_var) except (CalledProcessError, KeyboardInterrupt) as e: @@ -1424,7 +1415,7 @@ def get_setup_parameters(): 'paddle.framework', 'paddle.jit', 'paddle.jit.dy2static', - 'paddle.jit.newir_dy2static', + 'paddle.jit.pir_dy2static', 'paddle.jit.sot', 'paddle.jit.sot.opcode_translator', 'paddle.jit.sot.opcode_translator.executor', diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 612cddc6b1b0d..381b29804e284 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -99,6 +99,10 @@ if(WITH_TESTING) if(CINN_ONLY) return() endif() + # The following unittests only run in PR-CI-CINN + if(WITH_CINN AND NOT CINN_ONMLY) + add_subdirectory(ir/pir/cinn) + endif() add_subdirectory(amp) add_subdirectory(asp) @@ -138,6 +142,7 @@ if(WITH_TESTING) add_subdirectory(rnn) add_subdirectory(rpc) add_subdirectory(sequence) + add_subdirectory(sot) add_subdirectory(standalone_executor) add_subdirectory(tokenizer) # add_subdirectory(white_list) diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 3f9f13d3b420b..9f0d31e86310e 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -289,7 +289,7 @@ def test_op_called_as_expected(self): func = SimpleModelIncludeSetValue() func = paddle.amp.decorate(func, level='O2') - func = paddle.jit.to_static(func) + func = paddle.jit.to_static(func, full_graph=True) input = paddle.randn((2, 3)) with paddle.amp.auto_cast(level='O2'): diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 8700ab2e07074..a1862a17a581b 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -10,6 +10,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_auto_parallel_relaunch) set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_mp_allreduce_matmul_grad_overlapping MODULES + test_mp_allreduce_matmul_grad_overlapping) + set_tests_properties(test_mp_allreduce_matmul_grad_overlapping + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner) set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) @@ -96,7 +100,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s) set_tests_properties(test_reshard_r_to_s - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 160) py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p) set_tests_properties(test_reshard_r_to_p PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) @@ -114,18 +118,26 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_semi_auto_parallel_basic MODULES test_semi_auto_parallel_basic) set_tests_properties(test_semi_auto_parallel_basic - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 200) + py_test_modules(test_semi_auto_parallel_pylayer MODULES + test_semi_auto_parallel_pylayer) + set_tests_properties(test_semi_auto_parallel_pylayer + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_semi_auto_parallel_single_strategy MODULES test_semi_auto_parallel_single_strategy) set_tests_properties(test_semi_auto_parallel_single_strategy - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300) py_test_modules(test_semi_auto_parallel_hybrid_strategy MODULES test_semi_auto_parallel_hybrid_strategy) set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) - py_test_modules(test_gpt_with_newir MODULES test_gpt_with_newir) - set_tests_properties(test_gpt_with_newir + py_test_modules(test_semi_auto_parallel_dygraph_inplace MODULES + test_semi_auto_parallel_dygraph_inplace) + set_tests_properties(test_semi_auto_parallel_dygraph_inplace PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) + py_test_modules(test_gpt_with_pir MODULES test_gpt_with_pir) + set_tests_properties(test_gpt_with_pir PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 100) # End of unittests WITH multi cards and timeout # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout @@ -156,6 +168,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_fuse_adamw_pass PROPERTIES TIMEOUT 20) py_test_modules(test_rule_based_tuner_o2 MODULES test_rule_based_tuner_o2) set_tests_properties(test_rule_based_tuner_o2 PROPERTIES TIMEOUT 50) + py_test_modules(test_semi_auto_parallel_functional_in_single_card MODULES + test_semi_auto_parallel_functional_in_single_card) # End of unittests WITH single card and timeout # NOTE(zyl): unittests WITH single card and WITHOUT timeout diff --git a/test/auto_parallel/gpt_with_newir.py b/test/auto_parallel/gpt_with_pir.py similarity index 55% rename from test/auto_parallel/gpt_with_newir.py rename to test/auto_parallel/gpt_with_pir.py index 1be3202a23777..459539d23273a 100644 --- a/test/auto_parallel/gpt_with_newir.py +++ b/test/auto_parallel/gpt_with_pir.py @@ -18,6 +18,7 @@ import numpy as np from get_gpt_model import FakeDataset, generate_model +from test_sparse_addmm_op import get_cuda_version import paddle from paddle.distributed import ParallelEnv @@ -26,7 +27,7 @@ paddle.enable_static() -def apply_pass(use_sharding=False): +def apply_pass(use_sharding=False, pipeline_mode=None, fuse_passes_list=None): strategy = auto.Strategy() strategy.auto_mode = "semi" strategy.reinit = True @@ -51,6 +52,17 @@ def apply_pass(use_sharding=False): sharding.degree = 2 sharding.stage = 2 + if pipeline_mode: + pipeline = strategy.pipeline + pipeline.enable = True + pipeline.schedule_mode = pipeline_mode + pipeline.accumulate_steps = 2 + + if fuse_passes_list: + fused_passes = strategy.fused_passes + fused_passes.enable = True + fused_passes.fused_passes_list = fuse_passes_list + return strategy @@ -60,7 +72,7 @@ def reset_prog(): paddle.utils.unique_name.switch() -class TestNewIR(unittest.TestCase): +class TestPir(unittest.TestCase): def setUp(self): self.batch_size = 2 self.batch_num = 5 @@ -81,10 +93,19 @@ def init(self, engine, name): place = paddle.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) - def get_engine(self, mode, name, use_sharding=False): + def get_engine( + self, + mode, + name, + use_sharding=False, + pipeline_mode=None, + fuse_passes_list=None, + ): reset_prog() - strategy = apply_pass(use_sharding) + paddle.set_default_dtype('float32') + + strategy = apply_pass(use_sharding, pipeline_mode, fuse_passes_list) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) model, loss = generate_model(mode, dropout_prob=0.1) @@ -102,12 +123,12 @@ def check_results(self, ref_losses, check_losses): ), ) - def enable_new_ir(self, flag): - paddle.set_flags({'FLAGS_enable_new_ir_in_executor': flag}) # for c++ - os.environ['FLAGS_enable_new_ir_in_executor'] = str(flag) # for python + def enable_pir(self, flag): + paddle.set_flags({'FLAGS_enable_pir_in_executor': flag}) # for c++ + os.environ['FLAGS_enable_pir_in_executor'] = str(flag) # for python def test_dp(self): - self.enable_new_ir(False) + self.enable_pir(False) engine_dp_prog = self.get_engine( "dp", name="dp_prog", use_sharding=True ) @@ -115,8 +136,8 @@ def test_dp(self): self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) - self.enable_new_ir(True) - engine_dp_ir = self.get_engine("dp", name="dp_newir", use_sharding=True) + self.enable_pir(True) + engine_dp_ir = self.get_engine("dp", name="dp_pir", use_sharding=True) out_dp_ir = engine_dp_ir.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) @@ -125,15 +146,55 @@ def test_dp(self): out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0] ) + def test_dp_with_fused_linear(self): + if not get_cuda_version() >= 11060: + return + + self.enable_pir(False) + engine_dp_prog = self.get_engine( + "dp", + name="dp_prog_fuse_linear", + fuse_passes_list=['fuse_gemm_epilogue'], + ) + out_dp_prog = engine_dp_prog.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.enable_pir(True) + engine_dp_ir = self.get_engine( + "dp", + name="dp_pir_fuse_linear", + use_sharding=True, + fuse_passes_list=['fused_gemm_epilogue_pass'], + ) + out_dp_ir = engine_dp_ir.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + # TODO(zhiqiu): fix accuracy problem and use array_equal to check it + np.testing.assert_allclose( + out_dp_prog.history["loss"][0], + out_dp_ir.history["loss"][0], + rtol=1e-5, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, + out_dp_prog.history["loss"][0], + out_dp_ir.history["loss"][0], + out_dp_prog.history["loss"][0] - out_dp_ir.history["loss"][0], + ), + ) + # self.check_results( + # out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0] + # ) + def test_mp(self): - self.enable_new_ir(False) + self.enable_pir(False) engine_mp_prog = self.get_engine("mp", name="mp_prog") out_mp_prog = engine_mp_prog.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) - self.enable_new_ir(True) - engine_mp_ir = self.get_engine("mp", name="mp_newir") + self.enable_pir(True) + engine_mp_ir = self.get_engine("mp", name="mp_pir") out_mp_ir = engine_mp_ir.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) @@ -144,15 +205,15 @@ def test_mp(self): def test_pp(self): # navie pipeline parallel without schedule - self.enable_new_ir(False) + self.enable_pir(False) engine_pp_prog = self.get_engine("pp", name="pp_prog0") out_pp_prog = engine_pp_prog.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) - self.enable_new_ir(True) + self.enable_pir(True) # send_v2/recv_v2 dynamic_shape is True - engine_pp_ir = self.get_engine("pp", name="pp_newir") + engine_pp_ir = self.get_engine("pp", name="pp_pir") out_pp_ir = engine_pp_ir.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) @@ -182,6 +243,54 @@ def test_pp(self): out_pp_prog1["loss"], out_pp_ir.history["loss"][0] ) + def test_pp_1f1b(self): + self.enable_pir(False) + engine_1f1b_prog = self.get_engine( + "pp", name="1f1b_prog", use_sharding=False, pipeline_mode="1F1B" + ) + out_1f1b_prog = engine_1f1b_prog.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.enable_pir(True) + engine_1f1b_ir = self.get_engine( + "pp", name="1f1b_pir", use_sharding=False, pipeline_mode="1F1B" + ) + out_1f1b_ir = engine_1f1b_ir.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + if paddle.distributed.get_rank() == 1: + self.check_results( + out_1f1b_prog.history["loss"][0], + out_1f1b_ir.history["loss"][0], + ) + + def test_pp_fthenb(self): + self.enable_pir(False) + engine_fthenb_prog = self.get_engine( + "pp", name="fthenb_prog", use_sharding=False, pipeline_mode="FThenB" + ) + out_fthenb_prog = engine_fthenb_prog.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.enable_pir(True) + engine_fthenb_ir = self.get_engine( + "pp", + name="fthenb_pir", + use_sharding=False, + pipeline_mode="FThenB", + ) + out_fthenb_ir = engine_fthenb_ir.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + if paddle.distributed.get_rank() == 1: + self.check_results( + out_fthenb_prog.history["loss"][0], + out_fthenb_ir.history["loss"][0], + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/mp_allreduce_matmul_grad_overlapping_unittest.py b/test/auto_parallel/mp_allreduce_matmul_grad_overlapping_unittest.py new file mode 100644 index 0000000000000..2945dd1b31151 --- /dev/null +++ b/test/auto_parallel/mp_allreduce_matmul_grad_overlapping_unittest.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed.fleet import auto + +paddle.enable_static() + + +def reset_prog(): + paddle.base.framework.switch_main_program(paddle.static.Program()) + paddle.base.framework.switch_startup_program(paddle.static.Program()) + + +class TestMPAllreduceMatmulGradOverlapping(unittest.TestCase): + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2023) + np.random.seed(2023) + random.seed(2023) + place = paddle.base.CUDAPlace(paddle.distributed.ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_mp_engine(self, allreduce_matmul_grad_overlapping): + reset_prog() + + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + strategy.mp_optimization.allreduce_matmul_grad_overlapping = ( + allreduce_matmul_grad_overlapping + ) + + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def run_mp(self, allreduce_matmul_grad_overlapping): + mp_engine = self.get_mp_engine(allreduce_matmul_grad_overlapping) + history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + return np.array(history.history["loss"]) + + def check_results(self, ref_losses, check_losses, rtol=None, atol=None): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=rtol or self.rtol, + atol=atol or self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) + + def test_mp_allreduce_matmul_grad_overlapping(self): + losses_with_allreduce_matmul_grad_overlapping = self.run_mp(True) + losses_without_allreduce_matmul_grad_overlapping = self.run_mp(False) + + np.testing.assert_equal( + losses_with_allreduce_matmul_grad_overlapping, + losses_without_allreduce_matmul_grad_overlapping, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/reshard_api.py b/test/auto_parallel/reshard_api.py index c77cb9b773cac..5ad046080fa8f 100644 --- a/test/auto_parallel/reshard_api.py +++ b/test/auto_parallel/reshard_api.py @@ -18,6 +18,7 @@ import paddle import paddle.distributed as dist +from paddle import nn class TestReshardAPI: @@ -33,6 +34,8 @@ def run_test_cases(self): if self._backend == "cpu": paddle.set_device("cpu") self.test_case_p_to_r() + self.test_case_r_to_s() + self.test_case_forward_and_backward() def test_case_p_to_r(self): a = paddle.ones(self._shape) @@ -82,6 +85,68 @@ def test_case_r_to_s(self): assert np.equal(output_tensor.shape, input_tensor.shape).all() assert np.equal(output_tensor._local_shape, out_shape).all() + def test_case_forward_and_backward(self): + if self._backend == "cpu": + return + + np.random.seed(1901) + input_numpy = np.random.random(self._shape).astype("float32") + label_numpy = np.random.random(self._shape).astype('float32') + + in_shard_specs = [None for i in range(len(self._shape))] + out_shard_specs = [None for i in range(len(self._shape))] + out_shard_specs[self._shard] = "x" + + in_dist_attr = dist.DistAttr( + mesh=dist.ProcessMesh([0, 1], dim_names=["x"]), + sharding_specs=in_shard_specs, + ) + + out_dist_attr = dist.DistAttr( + mesh=dist.ProcessMesh([0, 1], dim_names=["x"]), + sharding_specs=out_shard_specs, + ) + + local_input = paddle.to_tensor(input_numpy) + dist_input = dist.shard_tensor( + paddle.to_tensor(input_numpy), dist_attr=in_dist_attr + ) + + local_input.stop_gradient = False + dist_input.stop_gradient = False + + local_output = local_input + paddle.ones(self._shape) + dist_output = dist_input + dist.shard_tensor( + paddle.ones(self._shape), dist_attr=in_dist_attr + ) + dist_output.stop_gradient = False + + dist_output = dist.reshard(dist_output, dist_attr=out_dist_attr) + + local_label = paddle.to_tensor(label_numpy) + dist_label = dist.shard_tensor( + paddle.to_tensor(label_numpy), dist_attr=out_dist_attr + ) + + local_loss_fn = nn.MSELoss() + dist_loss_fn = nn.MSELoss() + + local_loss = local_loss_fn(local_output, local_label) + dist_loss = dist_loss_fn(dist_output, dist_label) + + np.testing.assert_allclose( + local_loss.numpy(), dist_loss.numpy(), rtol=1e-5, atol=1e-5 + ) + + local_loss.backward() + dist_loss.backward() + np.testing.assert_allclose( + local_input.grad.numpy(), + dist_input.grad.numpy(), + rtol=1e-5, + atol=1e-5, + ) + if __name__ == '__main__': TestReshardAPI().run_test_cases() diff --git a/test/auto_parallel/reshard_r_to_s_cross_mesh.py b/test/auto_parallel/reshard_r_to_s_cross_mesh.py new file mode 100644 index 0000000000000..68db1bcd7ef0c --- /dev/null +++ b/test/auto_parallel/reshard_r_to_s_cross_mesh.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.base import core + + +class TestReshardRToSCrossMesh: + def __init__(self): + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + self._shard = eval(os.getenv("shard")) + self._backend = os.getenv("backend") + self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"]) + + def run_test_case(self): + # cpu does not support send/recv + if self._backend == "cpu": + return + elif self._backend == "gpu": + place = paddle.CUDAPlace(dist.get_rank()) + + dev_ctx = core.DeviceContext.create(place) + + paddle.seed(self._seeds) + value = paddle.uniform(self._shape, self._dtype) + + in_shard_specs = [None for i in range(len(self._shape))] + out_shard_specs = [None for i in range(len(self._shape))] + out_shard_specs[self._shard] = "x" + + dist_attr = dist.DistAttr( + mesh=self._in_mesh, sharding_specs=in_shard_specs + ) + out_dist_attr = dist.DistAttr( + mesh=self._out_mesh, sharding_specs=out_shard_specs + ) + + input_tensor = dist.shard_tensor(value, dist_attr=dist_attr) + + reshard_func = core.RToSReshardFunctionCrossMesh() + assert reshard_func.is_suitable(input_tensor, out_dist_attr) + + out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr) + out_shape = list(self._shape) + + if out_shape[self._shard] % 2 == 0: + out_shape[self._shard] = out_shape[self._shard] // 2 + split_shape = self._in_mesh.shape[0] + else: + split_shape = [ + out_shape[self._shard] // 2 + 1, + out_shape[self._shard] // 2, + ] + out_shape[self._shard] = ( + split_shape[0] if dist.get_rank() == 1 else split_shape[1] + ) + + out_expected_local_tensor_list = paddle.split( + value, num_or_sections=split_shape, axis=self._shard + ) + + np.testing.assert_equal( + out._local_value().numpy(), + out_expected_local_tensor_list[0].numpy() + if dist.get_rank() == 1 + else out_expected_local_tensor_list[1].numpy(), + ) + + assert np.equal(out.shape, input_tensor.shape).all() + assert np.equal(out._local_shape, out_shape).all() + + +if __name__ == '__main__': + TestReshardRToSCrossMesh().run_test_case() diff --git a/test/auto_parallel/reshard_s_to_r_cross_mesh.py b/test/auto_parallel/reshard_s_to_r_cross_mesh.py new file mode 100644 index 0000000000000..e1ea23f7a95d6 --- /dev/null +++ b/test/auto_parallel/reshard_s_to_r_cross_mesh.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.base import core + + +class TestReshardSToRCrossMesh: + def __init__(self): + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + self._shard = eval(os.getenv("shard")) + self._backend = os.getenv("backend") + + self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"]) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + place = paddle.CPUPlace() + elif self._backend == "gpu": + place = paddle.CUDAPlace(dist.get_rank()) + + dev_ctx = core.DeviceContext.create(place) + a = paddle.randn(self._shape) + + in_shard_specs = [None for i in range(len(self._shape))] + in_shard_specs[self._shard] = "x" + + out_shard_specs = [None for i in range(len(self._shape))] + dist_attr = dist.DistAttr( + mesh=self._in_mesh, sharding_specs=in_shard_specs + ) + out_dist_attr = dist.DistAttr( + mesh=self._out_mesh, sharding_specs=out_shard_specs + ) + + input_tensor = dist.shard_tensor(a, dist_attr=dist_attr) + + reshard_func = core.SToRReshardFunctionCrossMesh() + assert reshard_func.is_suitable(input_tensor, out_dist_attr) + + out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr) + + out_shape = list(self._shape) + if out_shape[self._shard] % 2 == 0: + split_shape = self._in_mesh.shape[0] + else: + split_shape = [ + out_shape[self._shard] // 2 + 1, + out_shape[self._shard] // 2, + ] + + in_expected_local_tensor_list = paddle.split( + out._local_value(), num_or_sections=split_shape, axis=self._shard + ) + + np.testing.assert_equal( + input_tensor._local_value().numpy(), + in_expected_local_tensor_list[dist.get_rank()].numpy(), + ) + + assert np.equal(out.shape, out_shape).all() + + +if __name__ == '__main__': + TestReshardSToRCrossMesh().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_dygraph_inplace.py b/test/auto_parallel/semi_auto_parallel_dygraph_inplace.py new file mode 100644 index 0000000000000..d94677e3b61f1 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_dygraph_inplace.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestInplaceForSemiAutoParallel(unittest.TestCase): + def run_test_case(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + x_np = np.random.random(size=[64, 32]).astype(np.float32) + y_np = np.random.random(size=[32, 48]).astype(np.float32) + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + x.stop_gradient = False + y.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', None]) + y_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None]) + + dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y_np, dist_attr=y_dist_attr) + dist_x.stop_gradient = False + dist_y.stop_gradient = False + dist_x = dist_x.add(dist_x) + dist_y = dist_y.add(dist_y) + dist_out = paddle.matmul( + dist_x, dist_y, transpose_x=False, transpose_y=False + ) + dist_x.add_(dist_x) + dist_y.add_(dist_y) + + with self.assertRaisesRegex( + RuntimeError, + "received tensor_version:1 != wrapper_version_snapshot:0", + ): + dist_out.backward() + + +if __name__ == '__main__': + TestInplaceForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_add_n.py b/test/auto_parallel/semi_auto_parallel_for_add_n.py new file mode 100644 index 0000000000000..9d7786eeaaf08 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_add_n.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestAddNApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body( + self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False + ): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x_np = np.random.random(size=x_shape).astype(self._dtype) + y_np = np.random.random(size=y_shape).astype(self._dtype) + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + x.stop_gradient = False + y.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) + + dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y_np, dist_attr=y_dist_attr) + dist_x.stop_gradient = False + dist_y.stop_gradient = False + + out = paddle.add_n([x, y]) + dist_out = paddle.add_n([dist_x, dist_y]) + self.check_tensor_eq(out, dist_out) + + out.backward() + dist_out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + self.check_tensor_eq(y.grad, dist_y.grad) + + return dist_out, dist_x.grad, dist_y.grad + + def test_add_n(self): + self.test_body( + x_shape=[64, 32], + y_shape=[64, 32], + x_specs=[None, None], + y_specs=[None, None], + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_add_n() + + +if __name__ == '__main__': + TestAddNApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_bitwise.py b/test/auto_parallel/semi_auto_parallel_for_bitwise.py new file mode 100644 index 0000000000000..1cbc6654b53b5 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_bitwise.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestBitwiseApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._check_grad = False + self._rtol = 1e-6 + self._atol = 0.0 + paddle.seed(self._seed) + np.random.seed(self._seed) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose( + np1, np2, rtol=self._rtol, atol=self._atol, verbose=True + ) + + def test_unary_body(self, x_shape, out_shape, x_specs, unary_func): + x = paddle.randint(0, 100, x_shape, self._dtype) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + dist_out = unary_func(dist_x) + out = unary_func(x) + self.check_tensor_eq(out, dist_out) + if self._check_grad: + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + + def test_binary_body( + self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func + ): + x = paddle.randint(0, 100, x_shape, self._dtype) + y = paddle.randint(0, 100, y_shape, self._dtype) + x.stop_gradient = False + y.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr) + dist_x.stop_gradient = False + dist_y.stop_gradient = False + + dist_out = binary_func(dist_x, dist_y) + out = binary_func(x, y) + self.check_tensor_eq(out, dist_out) + + if self._check_grad: + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + self.check_tensor_eq(y.grad, dist_y.grad) + + def test_bitwise_and_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_and_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_and_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_and_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_not_x_shard(self): + self.test_unary_body( + x_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + unary_func=paddle.bitwise_not, + ) + + def test_bitwise_not_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.bitwise_not, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_bitwise_and_x_shard() + self.test_bitwise_and_x_shard_broadcast() + self.test_bitwise_and_x_y_shard() + self.test_bitwise_and_x_y_shard_broadcast() + self.test_bitwise_not_x_shard() + + +if __name__ == '__main__': + TestBitwiseApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_compare.py b/test/auto_parallel/semi_auto_parallel_for_compare.py new file mode 100644 index 0000000000000..a174a9c9180e1 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_compare.py @@ -0,0 +1,172 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestCompareApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._check_grad = False + self._rtol = 1e-6 + self._atol = 0.0 + paddle.seed(self._seed) + np.random.seed(self._seed) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose( + np1, np2, rtol=self._rtol, atol=self._atol, verbose=True + ) + + def test_binary_body( + self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func + ): + x = paddle.randn(x_shape, self._dtype) + y = paddle.randn(y_shape, self._dtype) + x.stop_gradient = False + y.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr) + dist_x.stop_gradient = False + dist_y.stop_gradient = False + + dist_out = binary_func(dist_x, dist_y) + out = binary_func(x, y) + self.check_tensor_eq(out, dist_out) + + if self._check_grad: + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + self.check_tensor_eq(y.grad, dist_y.grad) + + def test_equal_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.equal, + ) + + def test_equal_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.equal, + ) + + def test_equal_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.equal, + ) + + def test_equal_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.equal, + ) + + def test_not_equal_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.not_equal, + ) + + def test_not_equal_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.not_equal, + ) + + def test_not_equal_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.not_equal, + ) + + def test_not_equal_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.not_equal, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_equal_x_shard() + self.test_equal_x_shard_broadcast() + self.test_equal_x_y_shard() + self.test_equal_x_y_shard_broadcast() + + self.test_not_equal_x_shard() + self.test_not_equal_x_shard_broadcast() + self.test_not_equal_x_y_shard() + self.test_not_equal_x_y_shard_broadcast() + + +if __name__ == '__main__': + TestCompareApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_concat.py b/test/auto_parallel/semi_auto_parallel_for_concat.py new file mode 100644 index 0000000000000..24605825d5f15 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_concat.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def test_concat_forward(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [[None, None, 'x'], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + + def test_concat_forward_reshard(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [['x', None, None], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_concat_forward() + # all to all is not supported yet for cpu + if self._backend == "gpu": + self.test_concat_forward_reshard() + + +if __name__ == '__main__': + TestSplitAndConcatSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py new file mode 100644 index 0000000000000..07496ec07e506 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from site import getsitepackages + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd + +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. Because the following path is generated after installing +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [] +for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] +extra_nvcc_args = ['-O3'] + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\dist_custom_relu\\dist_custom_relu.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\base\\tests\\auto_parallel" +else: + test_include = "../python/paddle/base/tests/auto_parallel" +paddle_includes.append(test_include) + +custom_ops = load( + name='dist_custom_relu_jit', + sources=[ + '../custom_op/custom_relu_op.cc', + '../custom_op/custom_relu_op_dup.cc', + '../custom_op/custom_relu_op.cu', + ], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True, +) + + +class TestCustomReluForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body(self, x_shape, x_specs): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x_np = np.random.random(size=x_shape).astype(self._dtype) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + y = paddle.add(x, x) + dist_y = paddle.add(dist_x, dist_x) + out = custom_ops.custom_relu(y) + dist_out = custom_ops.custom_relu(dist_y) + out.stop_gradient = False + dist_out.stop_gradient = False + + self.check_tensor_eq(out, dist_out) + + out.backward() + dist_out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + + def test_custom_relu(self): + self.test_body( + x_shape=[64, 32], + x_specs=['x', None], + ) + + def run_test_case(self): + paddle.set_device("gpu:" + str(dist.get_rank())) + self.test_custom_relu() + + +if __name__ == '__main__': + TestCustomReluForSemiAutoParallel().test_custom_relu() diff --git a/test/auto_parallel/semi_auto_parallel_for_elementwise.py b/test/auto_parallel/semi_auto_parallel_for_elementwise.py index b7e3e30b89e56..0e737db45ecaf 100644 --- a/test/auto_parallel/semi_auto_parallel_for_elementwise.py +++ b/test/auto_parallel/semi_auto_parallel_for_elementwise.py @@ -18,6 +18,7 @@ import paddle import paddle.distributed as dist +import paddle.nn.functional as F class TestElementwiseApiForSemiAutoParallel: @@ -26,18 +27,38 @@ def __init__(self): self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._rtol = 1e-6 + self._atol = 0.0 + paddle.seed(self._seed) + np.random.seed(self._seed) def check_tensor_eq(self, a, b): np1 = a.numpy() np2 = b.numpy() - np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + np.testing.assert_allclose( + np1, np2, rtol=self._rtol, atol=self._atol, verbose=True + ) + + def test_unary_body(self, x_shape, out_shape, x_specs, unary_func): + x = paddle.randn(x_shape, self._dtype) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + dist_out = unary_func(dist_x) + out = unary_func(x) + self.check_tensor_eq(out, dist_out) + + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) def test_binary_body( self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func ): - paddle.seed(self._seed) - np.random.seed(self._seed) - x = paddle.randn(x_shape, self._dtype) y = paddle.randn(y_shape, self._dtype) x.stop_gradient = False @@ -82,9 +103,9 @@ def test_sub_x_shard(self): def test_add_x_shard_broadcast(self): self.test_binary_body( - x_shape=[16, 32], - y_shape=[2, 16, 32], - out_shape=[2, 16, 32], + x_shape=[8, 16], + y_shape=[2, 8, 16], + out_shape=[2, 8, 16], x_specs=['x', None], y_specs=[None, None, None], binary_func=paddle.add, @@ -129,6 +150,315 @@ def test_sub_x_y_shard_broadcast(self): binary_func=paddle.subtract, ) + def test_square_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.square, + ) + + def test_relu_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=F.relu, + ) + + def test_maximum_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.maximum, + ) + + def test_maximum_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.maximum, + ) + + def test_maximum_x_y_shard(self): + if self._backend == "cpu": + return + + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.maximum, + ) + + def test_maximum_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.maximum, + ) + + def test_multiply_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.multiply, + ) + + def test_multiply_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.multiply, + ) + + def test_multiply_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.multiply, + ) + + def test_multiply_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 6, 8], + y_shape=[6, 8], + out_shape=[4, 6, 8], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.multiply, + ) + + def test_divide_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.divide, + ) + + def test_divide_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.divide, + ) + + def test_divide_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.divide, + ) + + def test_divide_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[2, 4, 6], + y_shape=[4, 6], + out_shape=[2, 4, 6], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.divide, + ) + + def test_bitwise_and_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_and_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_and_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.bitwise_and, + ) + + def test_bitwise_and_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.bitwise_and, + ) + + def test_elementwise_pow_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.pow, + ) + + def test_elementwise_pow_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.pow, + ) + + def test_elementwise_pow_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.pow, + ) + + def test_elementwise_pow_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[4, 6, 8], + y_shape=[6, 8], + out_shape=[4, 6, 8], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.pow, + ) + + def test_equal_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.equal, + ) + + def test_equal_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.equal, + ) + + def test_equal_x_y_shard(self): + if self._backend == "cpu": + return + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.equal, + ) + + def test_equal_x_y_shard_broadcast(self): + self.test_binary_body( + x_shape=[2, 6, 4], + y_shape=[6, 4], + out_shape=[2, 6, 4], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.equal, + ) + + def test_exp_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.exp, + ) + + def test_rsqrt_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.rsqrt, + ) + + def test_silu_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.nn.functional.silu, + ) + + def test_sin_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.sin, + ) + + def test_cos_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.cos, + ) + def run_test_case(self): if self._backend == "cpu": paddle.set_device("cpu") @@ -141,6 +471,31 @@ def run_test_case(self): self.test_add_x_shard_broadcast() self.test_add_x_y_shard() self.test_add_x_y_shard_broadcast() + self.test_sub_x_shard() + self.test_sub_x_y_shard_broadcast() + self.test_square_x_shard() + self.test_relu_x_shard() + self.test_maximum_x_shard() + self.test_maximum_x_shard_broadcast() + self.test_maximum_x_y_shard() + self.test_maximum_x_y_shard_broadcast() + self.test_multiply_x_shard() + self.test_multiply_x_shard_broadcast() + self.test_multiply_x_y_shard() + self.test_multiply_x_y_shard_broadcast() + self.test_divide_x_shard() + self.test_divide_x_shard_broadcast() + self.test_divide_x_y_shard() + self.test_divide_x_y_shard_broadcast() + self.test_elementwise_pow_x_shard() + self.test_elementwise_pow_x_shard_broadcast() + self.test_elementwise_pow_x_y_shard() + self.test_elementwise_pow_x_y_shard_broadcast() + self.test_exp_x_shard() + self.test_rsqrt_x_shard() + self.test_silu_x_shard() + self.test_sin_x_shard() + self.test_cos_x_shard() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_for_matmul.py b/test/auto_parallel/semi_auto_parallel_for_matmul.py index 279062f483058..470100e9c3bc8 100644 --- a/test/auto_parallel/semi_auto_parallel_for_matmul.py +++ b/test/auto_parallel/semi_auto_parallel_for_matmul.py @@ -30,7 +30,7 @@ def __init__(self): def check_tensor_eq(self, a, b): np1 = a.numpy() np2 = b.numpy() - np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + np.testing.assert_allclose(np1, np2, rtol=1e-04, verbose=True) def test_body( self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False diff --git a/test/auto_parallel/semi_auto_parallel_for_reduction.py b/test/auto_parallel/semi_auto_parallel_for_reduction.py new file mode 100644 index 0000000000000..4b2e7d4bb026b --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_reduction.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestReductionApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body(self, x_shape, out_shape, x_specs, axis, keepdim, op_func): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x = paddle.randn(x_shape, self._dtype) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + dist_out = op_func(dist_x, axis=axis, keepdim=keepdim) + out = op_func(x, axis=axis, keepdim=keepdim) + self.check_tensor_eq(out, dist_out) + np.testing.assert_equal(dist_out.shape, out_shape, verbose=True) + + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + + def test_sum_x_shard(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4, 6], + x_specs=['x', None, None], + axis=1, + keepdim=False, + op_func=paddle.sum, + ) + + def test_sum_x_shard_on_axis(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4], + x_specs=[None, 'x', None], + axis=[1, 2], + keepdim=False, + op_func=paddle.sum, + ) + + def test_sum_x_shard_on_axis_keepdim(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4, 1, 6], + x_specs=[None, 'x', None], + axis=1, + keepdim=True, + op_func=paddle.sum, + ) + + def test_mean_x_shard(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[8, 6], + x_specs=['x', None, None], + axis=-3, + keepdim=False, + op_func=paddle.mean, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_sum_x_shard() + self.test_sum_x_shard_on_axis() + self.test_sum_x_shard_on_axis_keepdim() + self.test_mean_x_shard() + + +if __name__ == '__main__': + TestReductionApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py index f08a4073a3faa..1c52687409336 100644 --- a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py +++ b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py @@ -18,7 +18,6 @@ import paddle import paddle.distributed as dist -import paddle.nn.functional as F class TestReplicatedSPmdApiForSemiAutoParallel: @@ -49,29 +48,6 @@ def create_local_and_dist_tensor_pair(self, np_array, sharding_specs): return local_t, dist_t - # input: phi::Tensor - # output: phi::Tensor - def test_relu(self): - x = np.random.random(size=[4, 4]).astype(self._dtype) - local_in, dist_in = self.create_local_and_dist_tensor_pair( - x, ['x', None] - ) - local_out = F.relu(local_in) - dist_out = F.relu(dist_in) - np.testing.assert_equal( - dist_out.dist_attr.dims_mapping, [-1, -1], verbose=True - ) - self.check_tensor_eq(local_out, dist_out) - - # test backward - local_out.backward() - dist_out.backward() - np.testing.assert_equal(dist_in.grad._local_shape, [2, 4], verbose=True) - np.testing.assert_equal( - dist_in.grad.dist_attr.dims_mapping, [0, -1], verbose=True - ) - self.check_tensor_eq(local_in.grad, dist_in.grad) - # input: phi::Tensor # output: std::vector def test_unbind(self): @@ -91,6 +67,109 @@ def test_unbind(self): dist_out.backward() self.check_tensor_eq(local_in.grad, dist_in.grad) + # input: paddle::optional + # output: phi::Tensor + def test_expand_as(self): + x1 = np.random.random(size=[2, 8]).astype("float32") + x2 = np.random.random(size=[2, 2, 8]).astype("float32") + local_in1, dist_in1 = self.create_local_and_dist_tensor_pair( + x1, ['x', None] + ) + local_in2, dist_in2 = self.create_local_and_dist_tensor_pair( + x2, [None, None, None] + ) + local_out = paddle.expand_as(local_in1, local_in2) + dist_out = paddle.expand_as(dist_in1, dist_in2) + self.check_tensor_eq(local_out, dist_out) + + local_out.backward() + dist_out.backward() + self.check_tensor_eq(local_in1.grad, dist_in1.grad) + + # input: phi::Tensor + # output: inplace paddle::optional + def test_adamax(self): + dtype = np.float32 + mp_dtype = np.float32 + shape = [120, 320] + + beta1 = 0.78 + beta2 = 0.899 + epsilon = 1e-5 + param = np.random.random(shape).astype(dtype) + grad = np.random.random(shape).astype(dtype) + moment = np.random.random(shape).astype(dtype) + inf_norm = np.random.random(shape).astype(dtype) + master_param = param.astype(mp_dtype) + + lr = np.array([0.002]).astype("float32") + beta1_pow = np.array([beta1**10]).astype("float32") + + local_param, dist_param = self.create_local_and_dist_tensor_pair( + param, ['x', None] + ) + local_grad, dist_grad = self.create_local_and_dist_tensor_pair( + grad, ['x', None] + ) + local_lr, dist_lr = self.create_local_and_dist_tensor_pair(lr, [None]) + ( + local_beta1_pow, + dist_beta1_pow, + ) = self.create_local_and_dist_tensor_pair(beta1_pow, [None]) + local_moment, dist_moment = self.create_local_and_dist_tensor_pair( + moment, ['x', None] + ) + local_inf_norm, dist_inf_norm = self.create_local_and_dist_tensor_pair( + inf_norm, ['x', None] + ) + ( + local_master_param, + dist_master_param, + ) = self.create_local_and_dist_tensor_pair(master_param, [None, None]) + + ( + local_param_out, + local_moment_out, + local_inf_norm_out, + local_master_param_out, + ) = paddle._C_ops.adamax_( + local_param, + local_grad, + local_lr, + local_moment, + local_inf_norm, + local_beta1_pow, + local_master_param, + beta1, + beta2, + epsilon, + True, + ) + + ( + dist_param_out, + dist_moment_out, + dist_inf_norm_out, + dist_master_param_out, + ) = paddle._C_ops.adamax_( + dist_param, + dist_grad, + dist_lr, + dist_moment, + dist_inf_norm, + dist_beta1_pow, + dist_master_param, + beta1, + beta2, + epsilon, + True, + ) + + self.check_tensor_eq(local_param_out, dist_param_out) + self.check_tensor_eq(local_moment_out, dist_moment_out) + self.check_tensor_eq(local_inf_norm_out, dist_inf_norm_out) + self.check_tensor_eq(local_master_param_out, dist_master_param_out) + # mutiple operators def test_mse_loss(self): x = np.random.random(size=[4, 4]).astype(self._dtype) @@ -102,9 +181,9 @@ def test_mse_loss(self): y, [None] ) - mes_loss = paddle.nn.loss.MSELoss() - local_out = mes_loss(local_in, local_label) - dist_out = mes_loss(dist_in, dist_label) + mse_loss = paddle.nn.loss.MSELoss() + local_out = mse_loss(local_in, local_label) + dist_out = mse_loss(dist_in, dist_label) self.check_tensor_eq(local_out, dist_out) # test backward @@ -124,9 +203,10 @@ def run_test_case(self): else: raise ValueError("Only support cpu or gpu backend.") - self.test_relu() - self.test_mse_loss() self.test_unbind() + self.test_expand_as() + self.test_adamax() + self.test_mse_loss() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_pylayer.py b/test/auto_parallel/semi_auto_parallel_pylayer.py new file mode 100644 index 0000000000000..5a8f9683c6476 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_pylayer.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.autograd.py_layer import PyLayer + + +class TestNet(PyLayer): + @staticmethod + def forward(ctx, x1, x2, x3): + y1 = paddle.matmul(x1, x2, transpose_x=False, transpose_y=False) + y2 = paddle.matmul(x2, x3, transpose_x=False, transpose_y=False) + return y1, y2 + + @staticmethod + def backward(ctx, dy1, dy2): + return dy1, dy2, dy2 + + +class TestPyLayerForSemiAutoParallel(unittest.TestCase): + def run_test_case(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + x1_np = np.random.random(size=[64, 32]).astype(np.float32) + x2_np = np.random.random(size=[32, 48]).astype(np.float32) + x3_np = np.random.random(size=[48, 64]).astype(np.float32) + x1 = paddle.to_tensor(x1_np) + x2 = paddle.to_tensor(x2_np) + x3 = paddle.to_tensor(x3_np) + x1.stop_gradient = False + x2.stop_gradient = False + x3.stop_gradient = False + + x1_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None]) + x2_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None]) + x3_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None]) + + dist_x1 = dist.shard_tensor(x1_np, dist_attr=x1_dist_attr) + dist_x2 = dist.shard_tensor(x2_np, dist_attr=x2_dist_attr) + dist_x3 = dist.shard_tensor(x3_np, dist_attr=x3_dist_attr) + dist_x1.stop_gradient = False + dist_x2.stop_gradient = False + dist_x3.stop_gradient = False + + y1, y2 = TestNet.apply(x1, x2, x3) + loss = y1.sum() + + dist_y1, dist_y2 = TestNet.apply(dist_x1, dist_x2, dist_x3) + dist_loss = dist_y1.sum() + + np.testing.assert_allclose( + loss.numpy(), dist_loss.numpy(), rtol=1e-04, verbose=True + ) + + loss.backward() + dist_loss.backward() + + np.testing.assert_allclose( + x1.grad.numpy(), dist_x1.grad.numpy(), rtol=1e-04, verbose=True + ) + np.testing.assert_allclose( + x2.grad.numpy(), dist_x2.grad.numpy(), rtol=1e-04, verbose=True + ) + np.testing.assert_allclose( + x3.grad.numpy(), dist_x3.grad.numpy(), rtol=1e-04, verbose=True + ) + + +if __name__ == '__main__': + TestPyLayerForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net.py b/test/auto_parallel/semi_auto_parallel_simple_net.py index 62fec8c906336..a0c78b061dffe 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net.py @@ -19,6 +19,7 @@ import paddle import paddle.distributed as dist from paddle import nn +from paddle.distributed.fleet.utils import recompute BATCH_SIZE = 16 BATCH_NUM = 4 @@ -26,92 +27,44 @@ CLASS_NUM = 10 -# TODO(chenweihang): update to MLP Layer later -class DemoNet(nn.Layer): - def __init__(self, np_w0, np_w1): - super().__init__() - self.w0 = self.create_parameter( - shape=[IMAGE_SIZE, IMAGE_SIZE], - attr=paddle.framework.ParamAttr( - name="demo_weight_1", - initializer=paddle.nn.initializer.Assign(np_w0), - ), - ) - self.w1 = self.create_parameter( - shape=[IMAGE_SIZE, CLASS_NUM], - attr=paddle.framework.ParamAttr( - name="nemo_weight_2", - initializer=paddle.nn.initializer.Assign(np_w1), - ), - ) +def create_numpy_like_random(name): + return paddle.ParamAttr( + name=name, initializer=paddle.nn.initializer.Uniform(0, 1) + ) - def forward(self, x): - y = paddle.matmul(x, self.w0) - z = paddle.matmul(y, self.w1) - return z - -class DPDemoNet(nn.Layer): - def __init__(self, np_w0, np_w1, mesh): +class DemoNet(nn.Layer): + def __init__( + self, + param_prefix="", + is_recompute=False, + is_pp=False, + pp_reshard_dist_attr=None, + ): super().__init__() - self.mesh = mesh - self.w0 = self.create_parameter( - shape=[IMAGE_SIZE, IMAGE_SIZE], - attr=paddle.framework.ParamAttr( - name="dp_demo_weight_1", - initializer=paddle.nn.initializer.Assign(np_w0), - ), - ) - self.w1 = self.create_parameter( - shape=[IMAGE_SIZE, CLASS_NUM], - attr=paddle.framework.ParamAttr( - name="dp_nemo_weight_2", - initializer=paddle.nn.initializer.Assign(np_w1), - ), - ) + weight_attr_0 = create_numpy_like_random(param_prefix + "_0") + weight_attr_1 = create_numpy_like_random(param_prefix + "_1") + + self.is_pp = is_pp + self.is_recompute = is_recompute + self.pp_reshard_dist_attr = pp_reshard_dist_attr + self.linear_0 = nn.Linear(IMAGE_SIZE, IMAGE_SIZE, weight_attr_0) + self.linear_1 = nn.Linear(IMAGE_SIZE, CLASS_NUM, weight_attr_1) + self.relu = nn.ReLU() + + def _inner_forward_fn(self, x): + out = self.linear_0(x) + out = self.relu(out) + if self.is_pp: + out = dist.reshard(out, self.pp_reshard_dist_attr) + out = self.linear_1(out) + return out def forward(self, x): - y = paddle.matmul( - dist.shard_tensor( - x, - dist_attr=dist.DistAttr( - mesh=self.mesh, sharding_specs=['x', None] - ), - ), - self.w0, - ) - z = paddle.matmul(y, self.w1) - return z - - -class MPDemoNet(nn.Layer): - def __init__(self, np_w0, np_w1, mesh): - super().__init__() - self.w0 = dist.shard_tensor( - self.create_parameter( - shape=[IMAGE_SIZE, IMAGE_SIZE], - attr=paddle.framework.ParamAttr( - name="mp_demo_weight_1", - initializer=paddle.nn.initializer.Assign(np_w0), - ), - ), - dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']), - ) - self.w1 = dist.shard_tensor( - self.create_parameter( - shape=[IMAGE_SIZE, CLASS_NUM], - attr=paddle.framework.ParamAttr( - name="mp_nemo_weight_2", - initializer=paddle.nn.initializer.Assign(np_w1), - ), - ), - dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]), - ) - - def forward(self, x): - y = paddle.matmul(x, self.w0) - z = paddle.matmul(y, self.w1) - return z + if self.is_recompute: + return recompute(self._inner_forward_fn, x) + else: + return self._inner_forward_fn(x) class TestSimpleNetForSemiAutoParallel: @@ -120,6 +73,11 @@ def __init__(self): self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._pp_mesh0 = dist.ProcessMesh([0], dim_names=["x"]) + self._pp_mesh1 = dist.ProcessMesh([1], dim_names=["x"]) + self.pp_reshard_dist_attr = dist.DistAttr( + mesh=self._pp_mesh1, sharding_specs=[None, None] + ) paddle.set_device(self._backend) @@ -127,6 +85,45 @@ def __init__(self): self.init_single_card_net_result() + def shard_fn(self, layer_name, layer, process_mesh): + if layer_name == 'linear_0': + dist_attr = dist.DistAttr( + mesh=process_mesh, sharding_specs=[None, 'x'] + ) + layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr) + elif layer_name == 'linear_1': + dist_attr = dist.DistAttr( + mesh=process_mesh, sharding_specs=['x', None] + ) + layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr) + + def pp_shard_fn(self, layer_name, layer, process_mesh): + if layer_name == 'linear_0': + # shard_layer doens't support cross-mesh now. + # input process_mesh of pp_shard_fn is useless, + # it's defined just for unified format. + weight_dist_attr = dist.DistAttr( + mesh=self._pp_mesh0, sharding_specs=[None, None] + ) + bias_dist_attr = dist.DistAttr( + mesh=self._pp_mesh0, sharding_specs=[None] + ) + layer.weight = dist.shard_tensor( + layer.weight, dist_attr=weight_dist_attr + ) + layer.bias = dist.shard_tensor(layer.bias, dist_attr=bias_dist_attr) + elif layer_name == 'linear_1': + weight_dist_attr = dist.DistAttr( + mesh=self._pp_mesh1, sharding_specs=[None, None] + ) + bias_dist_attr = dist.DistAttr( + mesh=self._pp_mesh1, sharding_specs=[None] + ) + layer.weight = dist.shard_tensor( + layer.weight, dist_attr=weight_dist_attr + ) + layer.bias = dist.shard_tensor(layer.bias, dist_attr=bias_dist_attr) + def init_input_data(self): paddle.seed(self._seed) np.random.seed(self._seed) @@ -135,26 +132,40 @@ def init_input_data(self): 'float32' ) self.label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32') - self.w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') - self.w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') - # TODO(chenweihang): optimizer cannot run auto-parallel now - def run_dynamic(self, layer): + def run_dynamic(self, layer, shard_input=False, is_pp=False): + paddle.seed(self._seed) + np.random.seed(self._seed) + # create loss loss_fn = nn.MSELoss() + # run forward and backward image = paddle.to_tensor(self.image) - out = layer(image) + input_mesh = self._pp_mesh0 if is_pp else self._mesh + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=input_mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) loss.backward() - return loss, layer.w0.grad, layer.w1.grad + opt = paddle.optimizer.SGD( + learning_rate=0.1, parameters=layer.parameters() + ) + opt.step() + return loss, layer.parameters() def init_single_card_net_result(self): - self.base_loss, self.base_w0_grad, self.base_w1_grad = self.run_dynamic( - DemoNet(self.w0, self.w1) + self.base_loss, self.base_parameters = self.run_dynamic( + DemoNet("demo_weight") ) def check_tensor_eq(self, a, b): @@ -163,24 +174,69 @@ def check_tensor_eq(self, a, b): np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) def test_dp_demo_net(self): - self.dp_loss, self.dp_w0_grad, self.dp_w1_grad = self.run_dynamic( - DPDemoNet(self.w0, self.w1, self._mesh) + self.dp_loss, self.dp_parameters = self.run_dynamic( + DemoNet("dp_demo_weight"), + shard_input=True, ) self.check_tensor_eq(self.dp_loss, self.base_loss) - self.check_tensor_eq(self.dp_w0_grad, self.base_w0_grad) - self.check_tensor_eq(self.dp_w1_grad, self.base_w1_grad) + for param, param_base in zip(self.dp_parameters, self.base_parameters): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) def test_mp_demo_net(self): - self.mp_loss, self.mp_w0_grad, self.mp_w1_grad = self.run_dynamic( - MPDemoNet(self.w0, self.w1, self._mesh) + mp_layer = dist.shard_layer( + DemoNet("mp_demo_weight"), self._mesh, self.shard_fn ) + + self.mp_loss, self.mp_parameters = self.run_dynamic(mp_layer) self.check_tensor_eq(self.mp_loss, self.base_loss) - self.check_tensor_eq(self.mp_w0_grad, self.base_w0_grad) - self.check_tensor_eq(self.mp_w1_grad, self.base_w1_grad) + + for param, param_base in zip(self.mp_parameters, self.base_parameters): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def test_pp_demo_net(self): + # Send/Recv operators doens't support CPU now. + if self._backend != "gpu": + return + + pp_layer = dist.shard_layer( + DemoNet( + "pp_demo_weight", + is_pp=True, + pp_reshard_dist_attr=self.pp_reshard_dist_attr, + ), + self._pp_mesh0, + self.pp_shard_fn, + ) + + self.pp_loss, self.pp_parameters = self.run_dynamic( + pp_layer, is_pp=True + ) + + rank = dist.get_rank() + # TODO(GhostScreaming): DistTensor.numpy() doesn't support + # cross-mesh now, ReshardXToReplicated function in eager_method + # needs to be fixed later. + if rank == 0: + # linear_0 weight and bias + self.check_tensor_eq(self.pp_parameters[0], self.base_parameters[0]) + self.check_tensor_eq(self.pp_parameters[1], self.base_parameters[1]) + else: + self.check_tensor_eq(self.pp_loss, self.base_loss) + # linear_1 weight and bias + self.check_tensor_eq(self.pp_parameters[2], self.base_parameters[2]) + self.check_tensor_eq(self.pp_parameters[3], self.base_parameters[3]) + + # TODO(GhostScreaming): Enable it later. + # for param, param_base in zip(self.mp_parameters, self.base_parameters): + # self.check_tensor_eq(param, param_base) + # self.check_tensor_eq(param.grad, param_base.grad) def run_test_case(self): self.test_dp_demo_net() self.test_mp_demo_net() + self.test_pp_demo_net() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_amp.py b/test/auto_parallel/semi_auto_parallel_simple_net_amp.py new file mode 100644 index 0000000000000..087bbcc16efb4 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_amp.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithAmpForSemiAutoParallel(TestSimpleNetForSemiAutoParallel): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + self.init_single_card_net_result() + + def run_dynamic_amp(self, layer, level='O1', shard_input=False): + paddle.seed(self._seed) + np.random.seed(self._seed) + + if level == 'O2': + layer = paddle.amp.decorate(models=layer, level='O2') + # create loss + loss_fn = nn.MSELoss() + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + + with paddle.amp.auto_cast(level=level): + out = layer(image) + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + scaled = scaler.scale(loss) + scaled.backward() + return loss, layer.parameters() + + def init_single_card_net_result(self): + ( + self.base_loss_o1, + self.base_parameters_o1, + ) = self.run_dynamic_amp(DemoNet('demo_weight_O1'), 'O1') + ( + self.base_loss_o2, + self.base_parameters_o2, + ) = self.run_dynamic_amp(DemoNet('demo_weight_O2'), 'O2') + + def test_dp_demo_net(self): + ( + self.dp_loss_o1, + self.dp_parameters_o1, + ) = self.run_dynamic_amp( + DemoNet('dp_demo_weight_O1'), 'O1', shard_input=True + ) + self.check_tensor_eq(self.dp_loss_o1, self.base_loss_o1) + for param, param_base in zip( + self.dp_parameters_o1, self.base_parameters_o1 + ): + # self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + ( + self.dp_loss_o2, + self.dp_parameters_o2, + ) = self.run_dynamic_amp(DemoNet('dp_demo_weight_O2'), 'O2') + self.check_tensor_eq(self.dp_loss_o2, self.base_loss_o2) + for param, param_base in zip( + self.dp_parameters_o2, self.base_parameters_o2 + ): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def test_mp_demo_net(self): + mp_layer_o1 = dist.shard_layer( + DemoNet("mp_demo_weight_O1"), self._mesh, self.shard_fn + ) + ( + self.mp_loss_o1, + self.mp_parameters_o1, + ) = self.run_dynamic_amp(mp_layer_o1, 'O1') + self.check_tensor_eq(self.mp_loss_o1, self.base_loss_o1) + for param, param_base in zip( + self.mp_parameters_o1, self.base_parameters_o1 + ): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + mp_layer_o2 = dist.shard_layer( + DemoNet("mp_demo_weight_O2"), self._mesh, self.shard_fn + ) + ( + self.mp_loss_o2, + self.mp_parameters_o2, + ) = self.run_dynamic_amp(mp_layer_o2, 'O2') + self.check_tensor_eq(self.mp_loss_o2, self.base_loss_o2) + for param, param_base in zip( + self.mp_parameters_o2, self.base_parameters_o2 + ): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def run_test_case(self): + self.test_dp_demo_net() + self.test_mp_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithAmpForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_clear_gradient.py b/test/auto_parallel/semi_auto_parallel_simple_net_clear_gradient.py new file mode 100644 index 0000000000000..17a852a779c34 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_clear_gradient.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithClearGradientForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_clear_gradient(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + loss.backward() + + for param in layer.parameters(): + param.clear_gradient() + param.clear_gradient(False) + + def test_demo_net(self): + mp_layer = dist.shard_layer( + DemoNet("clear_gradient_demo"), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_clear_gradient(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithClearGradientForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py new file mode 100644 index 0000000000000..ef8ff6e004c45 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from site import getsitepackages + +from semi_auto_parallel_simple_net import TestSimpleNetForSemiAutoParallel + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd + +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. Because the following path is generated after installing +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [] +for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] +extra_nvcc_args = ['-O3'] + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\dist_custom_relu\\dist_custom_relu.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\base\\tests\\auto_parallel" +else: + test_include = "../python/paddle/base/tests/auto_parallel" +paddle_includes.append(test_include) + +custom_ops = load( + name='dist_custom_relu_jit', + sources=[ + '../custom_op/custom_relu_op.cc', + '../custom_op/custom_relu_op_dup.cc', + '../custom_op/custom_relu_op.cu', + ], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True, +) + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +class PPDemoNet(nn.Layer): + def __init__(self, mesh0, mesh1, param_suffix=""): + super().__init__() + self.replicate_dist_attr0 = dist.DistAttr( + mesh=mesh0, sharding_specs=[None, None] + ) + self.replicate_dist_attr1 = dist.DistAttr( + mesh=mesh1, sharding_specs=[None, None] + ) + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="pp_demo_weight_0" + param_suffix, + initializer=paddle.nn.initializer.Uniform(0, 1), + ), + ), + dist_attr=self.replicate_dist_attr0, + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="pp_nemo_weight_1" + param_suffix, + initializer=paddle.nn.initializer.Uniform(0, 1), + ), + ), + dist_attr=self.replicate_dist_attr1, + ) + + def forward(self, x): + out = F.linear(x, self.w0) + out = custom_ops.custom_relu(out) + # out = F.relu(out) + out = dist.reshard(out, dist_attr=self.replicate_dist_attr1) + out = F.linear(out, self.w1) + return out + + +class TestSimpleNetWithCustomReluForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._pp_mesh0 = dist.ProcessMesh([0], dim_names=["x"]) + self._pp_mesh1 = dist.ProcessMesh([1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_custom_relu(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + loss.backward() + + def test_demo_net(self): + mp_layer = dist.shard_layer( + PPDemoNet(self._pp_mesh0, self._pp_mesh1), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_custom_relu(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ == "__main__": + TestSimpleNetWithCustomReluForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_fill_zero_for_emtpy_grad.py b/test/auto_parallel/semi_auto_parallel_simple_net_fill_zero_for_emtpy_grad.py new file mode 100644 index 0000000000000..f32fab0d69997 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_fill_zero_for_emtpy_grad.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithEmtpyGradForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_empty_grad(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + out = paddle.split(out, 2)[0] + + label = paddle.to_tensor(self.label) + label = paddle.split(label, 2)[0] + loss = loss_fn(out, label) + + loss.backward() + + grads = paddle.base.core.eager.get_grads_types( + [layer.parameters()[0], layer.parameters()[1]] + ) + layer.parameters()[0]._reset_grad_inplace_version() + tmp = layer.parameters()[1]._grad_value() + + def test_demo_net(self): + mp_layer = dist.shard_layer( + DemoNet("empty_grad_demo"), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_empty_grad(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithEmtpyGradForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_grad_api.py b/test/auto_parallel/semi_auto_parallel_simple_net_grad_api.py new file mode 100644 index 0000000000000..2a531a75e8df2 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_grad_api.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithGradApiForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_grad_api(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + loss.backward() + + grads = paddle.base.core.eager.get_grads_types( + [layer.parameters()[0], layer.parameters()[1]] + ) + layer.parameters()[0]._reset_grad_inplace_version() + tmp = layer.parameters()[1]._grad_value() + + def test_demo_net(self): + mp_layer = dist.shard_layer( + DemoNet("grad_api_demo"), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_grad_api(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithGradApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_gradient_hook.py b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_hook.py new file mode 100644 index 0000000000000..4c0e8284b5135 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_hook.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + +hook_triggered = False + + +def backward_hook(): + def trigger_hook(grad): + global hook_triggered + hook_triggered = True + assert grad.is_dist() + return paddle.scale(grad, 1.0) + + return trigger_hook + + +class TestSimpleNetWithGradientHookForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic(self, layer): + loss_fn = nn.MSELoss() + image = paddle.to_tensor(self.image) + + out = layer(image) + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + loss.backward() + + def test_register_grad_hook(self): + paddle.seed(self._seed) + np.random.seed(self._seed) + + model = dist.shard_layer( + DemoNet("mp_demo_register_grad_hook"), self._mesh, self.shard_fn + ) + model.parameters()[0]._register_grad_hook(backward_hook()) + + self.run_dynamic(model) + global hook_triggered + assert hook_triggered + hook_triggered = False + + def test_register_hook(self): + paddle.seed(self._seed) + np.random.seed(self._seed) + + model = dist.shard_layer( + DemoNet("mp_demo_register_hook"), self._mesh, self.shard_fn + ) + model.parameters()[0].register_hook(backward_hook()) + + self.run_dynamic(model) + global hook_triggered + assert hook_triggered + hook_triggered = False + + def run_test_case(self): + self.test_register_grad_hook() + self.test_register_hook() + + +if __name__ == '__main__': + TestSimpleNetWithGradientHookForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py new file mode 100644 index 0000000000000..83403e0d0ecd8 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithGradientMergeForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + self.init_single_card_net_result() + + def run_dynamic_gradient_merge(self, layer, shard_input=False): + paddle.seed(self._seed) + np.random.seed(self._seed) + + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + + for i in range(2): + out = layer(image) + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + loss.backward() + + return loss, layer.parameters() + + def init_single_card_net_result(self): + ( + self.base_loss, + self.base_parameters, + ) = self.run_dynamic_gradient_merge(DemoNet("gradient_merge_demo")) + + def test_dp_demo_net(self): + ( + self.dp_loss, + self.dp_parameters, + ) = self.run_dynamic_gradient_merge( + DemoNet("gradient_merge_dp_demo"), + shard_input=True, + ) + self.check_tensor_eq(self.dp_loss, self.base_loss) + self.check_tensor_eq(self.dp_loss, self.base_loss) + for param, param_base in zip(self.dp_parameters, self.base_parameters): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def test_mp_demo_net(self): + mp_layer = dist.shard_layer( + DemoNet("gradient_merge_mp_demo"), self._mesh, self.shard_fn + ) + ( + self.mp_loss, + self.mp_parameters, + ) = self.run_dynamic_gradient_merge(mp_layer) + + self.check_tensor_eq(self.mp_loss, self.base_loss) + for param, param_base in zip(self.mp_parameters, self.base_parameters): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def run_test_case(self): + self.test_dp_demo_net() + self.test_mp_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithGradientMergeForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py b/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py index 90532a647812a..6626411108e86 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py @@ -15,53 +15,12 @@ import os from semi_auto_parallel_simple_net import ( - CLASS_NUM, - IMAGE_SIZE, + DemoNet, TestSimpleNetForSemiAutoParallel, ) import paddle import paddle.distributed as dist -from paddle import nn - - -class DPAndMPDemoNet(nn.Layer): - def __init__(self, np_w0, np_w1, mesh): - super().__init__() - self.mesh = mesh - self.w0 = dist.shard_tensor( - self.create_parameter( - shape=[IMAGE_SIZE, IMAGE_SIZE], - attr=paddle.framework.ParamAttr( - name="dmp_demo_weight_1", - initializer=paddle.nn.initializer.Assign(np_w0), - ), - ), - dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'y']), - ) - self.w1 = dist.shard_tensor( - self.create_parameter( - shape=[IMAGE_SIZE, CLASS_NUM], - attr=paddle.framework.ParamAttr( - name="dmp_nemo_weight_2", - initializer=paddle.nn.initializer.Assign(np_w1), - ), - ), - dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['y', None]), - ) - - def forward(self, x): - y = paddle.matmul( - dist.shard_tensor( - x, - dist_attr=dist.DistAttr( - mesh=self.mesh, sharding_specs=['x', None] - ), - ), - self.w0, - ) - z = paddle.matmul(y, self.w1) - return z class TestSimpleNetHybridStrategyForSemiAutoParallel( @@ -72,6 +31,15 @@ def __init__(self): self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) + self._pp_mesh0 = dist.ProcessMesh( + [[0, 1], [2, 3]], dim_names=["x", "y"] + ) + self._pp_mesh1 = dist.ProcessMesh( + [[4, 5], [6, 7]], dim_names=["x", "y"] + ) + self.pp_reshard_dist_attr = dist.DistAttr( + mesh=self._pp_mesh1, sharding_specs=["x", "y"] + ) paddle.set_device(self._backend) @@ -79,17 +47,95 @@ def __init__(self): self.init_single_card_net_result() def test_dp_mp_demo_net(self): + model = dist.shard_layer( + DemoNet("dp_mp_hybrid_strategy"), self._mesh, self.shard_fn + ) + ( self.dp_mp_loss, - self.dp_mp_w0_grad, - self.dp_mp_w1_grad, - ) = self.run_dynamic(DPAndMPDemoNet(self.w0, self.w1, self._mesh)) + self.dp_mp_parameters, + ) = self.run_dynamic(model, shard_input=True) + self.check_tensor_eq(self.dp_mp_loss, self.base_loss) - self.check_tensor_eq(self.dp_mp_w0_grad, self.base_w0_grad) - self.check_tensor_eq(self.dp_mp_w1_grad, self.base_w1_grad) + for param, param_base in zip( + self.dp_mp_parameters, self.base_parameters + ): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def dp_mp_pp_shard_fn(self, layer_name, layer, process_mesh): + if layer_name == 'linear_0': + # shard_layer doens't support cross-mesh now. + # input process_mesh of pp_shard_fn is useless, + # it's defined just for unified format. + weight_dist_attr = dist.DistAttr( + mesh=self._pp_mesh0, sharding_specs=[None, 'y'] + ) + bias_dist_attr = dist.DistAttr( + mesh=self._pp_mesh0, sharding_specs=[None] + ) + layer.weight = dist.shard_tensor( + layer.weight, dist_attr=weight_dist_attr + ) + layer.bias = dist.shard_tensor(layer.bias, dist_attr=bias_dist_attr) + elif layer_name == 'linear_1': + weight_dist_attr = dist.DistAttr( + mesh=self._pp_mesh1, sharding_specs=['y', None] + ) + bias_dist_attr = dist.DistAttr( + mesh=self._pp_mesh1, sharding_specs=[None] + ) + layer.weight = dist.shard_tensor( + layer.weight, dist_attr=weight_dist_attr + ) + layer.bias = dist.shard_tensor(layer.bias, dist_attr=bias_dist_attr) + + def dp_mp_pp_demo_net(self): + model = dist.shard_layer( + DemoNet( + "dp_mp_pp_hybrid_strategy", + is_pp=True, + pp_reshard_dist_attr=self.pp_reshard_dist_attr, + ), + self._pp_mesh0, + self.dp_mp_pp_shard_fn, + ) + + ( + self.dp_mp_pp_loss, + self.dp_mp_pp_parameters, + ) = self.run_dynamic(model, shard_input=True, is_pp=True) + + rank = dist.get_rank() + # TODO(GhostScreaming): DistTensor.numpy() doesn't support + # cross-mesh now, ReshardXToReplicated function in eager_method + # needs to be fixed later. + if rank in [0, 1, 2, 3]: + # linear_0 weight and bias + self.check_tensor_eq( + self.dp_mp_pp_parameters[0], self.base_parameters[0] + ) + self.check_tensor_eq( + self.dp_mp_pp_parameters[1], self.base_parameters[1] + ) + else: + self.check_tensor_eq(self.dp_mp_pp_loss, self.base_loss) + # linear_1 weight and bias + self.check_tensor_eq( + self.dp_mp_pp_parameters[2], self.base_parameters[2] + ) + self.check_tensor_eq( + self.dp_mp_pp_parameters[3], self.base_parameters[3] + ) def run_test_case(self): self.test_dp_mp_demo_net() + # TODO(GhostScreaming): Paddle-CI-Coverage doesn't support 8-cards + # testcase now. Enable it later. It can be tested with + # modify test_semi_auto_parallel_hybrid_strategy.py `setUp` function, + # just set num_of_devices=8, nnode =1 and _changeable_envs = {"backend": ["gpu"]} + # to test it. + # self.dp_mp_pp_demo_net() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_recompute.py b/test/auto_parallel/semi_auto_parallel_simple_net_recompute.py new file mode 100644 index 0000000000000..78a9ec2d136f3 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_recompute.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithRecomputeForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + self.init_single_card_net_result() + + def run_dynamic_recompute(self, layer, shard_input=False): + paddle.seed(self._seed) + np.random.seed(self._seed) + + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + image.stop_gradient = False + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + loss.backward() + return loss, layer.parameters() + + def init_single_card_net_result(self): + ( + self.base_loss, + self.base_parameters, + ) = self.run_dynamic_recompute( + DemoNet("recompute_demo", is_recompute=True) + ) + + def test_dp_demo_net(self): + ( + self.dp_loss, + self.dp_parameters, + ) = self.run_dynamic_recompute( + DemoNet("recompute_dp_demo", is_recompute=True), + shard_input=True, + ) + self.check_tensor_eq(self.dp_loss, self.base_loss) + self.check_tensor_eq(self.dp_loss, self.base_loss) + for param, param_base in zip(self.dp_parameters, self.base_parameters): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def test_mp_demo_net(self): + mp_layer = dist.shard_layer( + DemoNet("recompute_mp_demo", is_recompute=True), + self._mesh, + self.shard_fn, + ) + ( + self.mp_loss, + self.mp_parameters, + ) = self.run_dynamic_recompute(mp_layer) + + self.check_tensor_eq(self.mp_loss, self.base_loss) + for param, param_base in zip(self.mp_parameters, self.base_parameters): + self.check_tensor_eq(param, param_base) + self.check_tensor_eq(param.grad, param_base.grad) + + def run_test_case(self): + self.test_dp_demo_net() + self.test_mp_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithRecomputeForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_zero_grads.py b/test/auto_parallel/semi_auto_parallel_simple_net_zero_grads.py new file mode 100644 index 0000000000000..4bdcd540618f2 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_zero_grads.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithZeroGradsForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_zero_grads(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + loss.backward() + + for param in layer.parameters(): + param._zero_grads() + + def test_demo_net(self): + mp_layer = dist.shard_layer( + DemoNet("zero_grads_demo"), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_zero_grads(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ == "__main__": + TestSimpleNetWithZeroGradsForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py new file mode 100644 index 0000000000000..cfb905e8382a2 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class SemiAutoParallelTestBase: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def flatten(self, inputs, terminal_cond): + """ + inputs may be single tensor、tuple + """ + + if terminal_cond(inputs): + return [inputs], "i" + + assert isinstance(inputs, (tuple, list)) + flattened = [] + structure = [] + for i in range(len(inputs)): + tmp, tmp_structure = self.flatten(inputs[i], terminal_cond) + flattened.extend(tmp) + structure.append(tmp_structure) + + if isinstance(inputs, tuple): + structure = tuple(structure) + return flattened, structure + + def unflatten(self, inputs, structure, offset=0): + """ + inputs may be single tensor + """ + assert isinstance(inputs, list) + assert offset < len(inputs) + if structure == "i": + offset = offset + 1 + # return a list + return inputs[offset - 1], offset + assert isinstance(structure, (tuple, list)) + unflattened = [] + for i in range(len(structure)): + tmp, offset = self.unflatten(inputs, structure[i], offset) + unflattened.append(tmp) + if isinstance(structure, tuple): + unflattened = tuple(unflattened) + return unflattened, offset + + def runfunc_and_check( + self, inputs_shape, inputs_specs, op_func, with_backward, **kwargs + ): + paddle.seed(self._seed) + np.random.seed(self._seed) + + flat_inputs = [] + flat_dist_inputs = [] + + def terminal_cond(x): + return isinstance(x, list) and all( + not isinstance(e, (list, tuple)) for e in x + ) + + flat_inputs_specs, inputs_structure = self.flatten( + inputs_specs, terminal_cond + ) + flat_inputs_shape, _ = self.flatten(inputs_shape, terminal_cond) + assert len(flat_inputs_specs) == len(flat_inputs_shape) + + for shape, spec in zip(flat_inputs_shape, flat_inputs_specs): + input_np = np.random.random(size=shape).astype(self._dtype) + input = paddle.to_tensor(input_np) + input.stop_gradient = False + input_dist_attr = dist.DistAttr( + mesh=self._mesh, sharding_specs=spec + ) + dist_input = dist.shard_tensor(input, dist_attr=input_dist_attr) + dist_input.stop_gradient = False + flat_inputs.append(input) + flat_dist_inputs.append(dist_input) + inputs, _ = self.unflatten(flat_inputs, inputs_structure) + dist_inputs, _ = self.unflatten(flat_dist_inputs, inputs_structure) + + def wrap_tuple(e): + return e if isinstance(e, tuple) else (e,) + + op_inputs = wrap_tuple(inputs) + op_dist_input = wrap_tuple(dist_inputs) + + out = op_func(*op_inputs, **kwargs) + dist_out = op_func(*op_dist_input, **kwargs) + + if with_backward: + + def terminal_cond2(x): + return not isinstance(x, (list, tuple)) + + flat_out, _ = self.flatten(out, terminal_cond2) + flat_dist_out, _ = self.flatten(dist_out, terminal_cond2) + assert len(flat_out) == len(flat_dist_out) + for output, dist_output in zip(flat_out, flat_dist_out): + self.check_tensor_eq(out, dist_out) + output.backward() + dist_output.backward() + + for x, dist_x in zip(flat_inputs, flat_dist_inputs): + self.check_tensor_eq(x.grad, dist_x.grad) + + return dist_inputs, dist_out diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index cf034e33678aa..97c2d1dc0205e 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -18,6 +18,10 @@ if(WITH_DISTRIBUTE) py_test_modules(test_default_data_parallel_rule MODULES test_default_data_parallel_rule) py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) + py_test_modules(test_slice_rule MODULES test_slice_rule) + py_test_modules(test_flatten_rule MODULES test_flatten_rule) + py_test_modules(test_unsqueeze_rule MODULES test_unsqueeze_rule) + py_test_modules(test_concat_rule MODULES test_concat_rule) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/test/auto_parallel/spmd_rules/test_concat_rule.py b/test/auto_parallel/spmd_rules/test_concat_rule.py new file mode 100644 index 0000000000000..b1e1c11a0622e --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_concat_rule.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestConcatSPMDRule(unittest.TestCase): + """ + Unit tests for split spmd rule. + """ + + def setUp(self): + self.process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]]) + self.shapes = [[16, 16, 16], [4, 16, 16], [2, 16, 16]] + self.dim_mappings = [[-1, 0, 1], [-1, 1, 0], [-1, -1, 0]] + + def build_inputs(self): + inputs = [] + for shape, dim_mapping in zip(self.shapes, self.dim_mappings): + tensor_dist_attr = TensorDistAttr() + tensor_dist_attr.dims_mapping = dim_mapping + tensor_dist_attr.process_mesh = self.process_mesh + inputs.append(DistTensorSpec(shape, tensor_dist_attr)) + return inputs + + def test_infer_forward(self): + inputs = self.build_inputs() + rule = core.get_phi_spmd_rule("concat") + infered_dist_attrs = rule.infer_forward(inputs, 0) + infered_input_dist_attrs = infered_dist_attrs[0] + self.assertEqual(len(infered_input_dist_attrs), 1) + infered_output_dist_attrs = infered_dist_attrs[1] + self.assertEqual(len(infered_output_dist_attrs), 1) + for input_dist_attr in infered_input_dist_attrs[0]: + self.assertEqual(input_dist_attr.dims_mapping, [-1, 1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py b/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py index 8d69da185246e..f8ceb1b88bf96 100644 --- a/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py +++ b/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py @@ -26,7 +26,7 @@ class TestDefaultDataParallelSPMDRule(unittest.TestCase): def setUp(self): # After replaced all spmd rules by phi impl, we can recover the # api name to `get_spmd_rule` - self.rule = core.get_phi_spmd_rule("unsqueeze") + self.rule = core.get_phi_spmd_rule("default_data_parallel") x_shape = [10, 10, 32, 48] y_shape = [32, 48] diff --git a/test/auto_parallel/spmd_rules/test_flatten_rule.py b/test/auto_parallel/spmd_rules/test_flatten_rule.py new file mode 100644 index 0000000000000..599b2ddf4bf95 --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_flatten_rule.py @@ -0,0 +1,398 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestFlattenSPMDRule(unittest.TestCase): + def setUp(self): + self.rule = core.get_phi_spmd_rule("flatten") + + x_shape = [8, 16, 8, 24] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() + + def test_flatten_infer_forward(self): + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1] [ 0, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [ -1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] + # dims_mapping: [-1, -1, 1, 0] --> [-1, -1, -1, 0] [ -1, -1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] + # dims_mapping: [-1, 0, 1, -1] --> [-1, -1, -1, -1] [ -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, -1] [ 0] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] + # dims_mapping: [1, 0, -1, -1] --> [1, -1, -1, -1] [ 1] + self.x_dist_tensor_spec.set_dims_mapping([1, 0, -1, -1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] + # dims_mapping: [-1, -1, 0, 1] --> [-1, -1, -1, -1] [-1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, -1] [-1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] + # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1, -1] [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + def test_flatten_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 16 * 8, 24], output_tensor_dist_attr + ) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] (input --> output) + # dims_mapping: [0, -1, 1] --> [0, -1, -1, 1], [0, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] (input --> output) + # dims_mapping: [0, 1, -1] --> [0, 1, -1, -1], [0, 1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] (input --> output) + # dims_mapping: [-1, 0, 1] --> [-1, 0, -1, 1], [-1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] (input --> output) + # dims_mapping: [-1] --> [-1, -1, -1, -1], [-1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8 * 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([-1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] (input --> output) + # dims_mapping: [0] --> [0, -1, -1, -1], [0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8 * 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([0]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] (input --> output) + # dims_mapping: [1] --> [1, -1, -1, -1], [1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8 * 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] (input --> output) + # dims_mapping: [-1, -1] --> [-1, -1, -1, -1], [-1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] (input --> output) + # dims_mapping: [0, -1] --> [0, -1, -1, -1], [0, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([0, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] (input --> output) + # dims_mapping: [0, 1] --> [0, 1, -1, -1], [0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py index a370580682d8c..8268c7e768276 100644 --- a/test/auto_parallel/spmd_rules/test_reshape_rule.py +++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py @@ -243,10 +243,58 @@ def test_reshape_infer_forward(self): infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] ) + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] + # dims_mapping: [0, 1, -1] --> [0, 1, -1], [0, 1, -1, -1] + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.attrs["shape"] = [0, 0, -1, 192] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] + # dims_mapping: [0, -1, 1] --> [0, -1, -1], [0, -1, -1, -1] + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.attrs["shape"] = [0, 0, -1, 192] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] + # dims_mapping: [1, -1, 0] --> [1, -1, 0], [1, -1, 0, -1] + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.attrs["shape"] = [0, 0, -1, 192] + self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1] + ) + # shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1] # raise error self.attrs["shape"] = [3, 24, 6, -1, -1] - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): self.rule.infer_forward( self.x_dist_tensor_spec, self.attrs['shape'] ) @@ -454,6 +502,63 @@ def test_reshape_infer_backward(self): infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, 0] ) + # shape: [8, 1024, 3072] --> [0, 0, -1, 192] (input --> output) + # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [8, 1024, 3072] + self.output_dist_tensor_spec.shape = [0, 0, -1, 192] + self.attrs["shape"] = [0, 0, -1, 192] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output) + # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.output_dist_tensor_spec.shape = [0, 0, -1, 192] + self.attrs["shape"] = [0, 0, -1, 192] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output) + # dims_mapping: [0, -1, 1, -1] --> [0, -1, 1], [0, -1, 1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.output_dist_tensor_spec.shape = [0, 0, -1, 192] + self.attrs["shape"] = [0, 0, -1, 192] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1]) + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, 1, -1] + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_slice_rule.py b/test/auto_parallel/spmd_rules/test_slice_rule.py new file mode 100644 index 0000000000000..e5bad8ff9b87e --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_slice_rule.py @@ -0,0 +1,303 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestSliceSPMDRule(unittest.TestCase): + def setUp(self): + self.rule = core.get_phi_spmd_rule("slice") + + x_shape = [8, 8, 16, 16] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() + self.attrs['infer_flags'] = [0] + self.attrs['decrease_axis'] = [0] + + def test_slice_infer_forward(self): + # axes: [-1] + # dims_mapping: [-1, 0, 1, -1] --> [-1, 0, 1, -1] [-1, 0, 1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [8] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1, -1] + ) + + # axes: [-1] + # dims_mapping: [-1, -1, 1, 0] --> [-1, -1, 1, -1] [-1, -1, 1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [-1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] + ) + + # axes: [1, 2] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1] [0, -1, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # axes: [1, 2] + # dims_mapping: [-1, 1, 0, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, 0, -1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + # axes: [0, 1, 2, 3] + # dims_mapping: [0, 1, -1, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + self.attrs['axes'] = [0, 1, 2, 3] + self.attrs['starts'] = [0, 0, 4, 4] + self.attrs['ends'] = [4, 4, -1, 32] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + def test_slice_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 8, 16, 16], output_tensor_dist_attr + ) + + # axes: [-1] + # dims_mapping: [-1, -1, 0, 1] --> [-1, -1, 0, -1], [-1, -1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [8] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, -1] + ) + + # axes: [-1] + # dims_mapping: [-1, 1, 0, -1] --> [-1, 1, 0, -1], [-1, 1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0, -1]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [-1] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, 0, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0, -1] + ) + + # axes: [1, 2] + # dims_mapping: [-1, 1, 0, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0, -1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + # axes: [1, 2] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1], [0, -1, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # axes: [0, 1, 2, 3] + # dims_mapping: [0, 1, -1, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + self.attrs['axes'] = [0, 1, 2, 3] + self.attrs['starts'] = [0, 0, 4, 4] + self.attrs['ends'] = [4, 4, -1, 32] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py new file mode 100644 index 0000000000000..afb851279ca36 --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py @@ -0,0 +1,341 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestUnsqueezeSPMDRule(unittest.TestCase): + def setUp(self): + self.rule = core.get_phi_spmd_rule("unsqueeze") + + x_shape = [8, 16] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() + + def test_unsqueeze_infer_forward(self): + # shape: [8, 16] --> [1, 8, 16] + # dims_mapping: [0, 1] --> [0, 1] [-1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16] --> [8, 16, 1] + # dims_mapping: [0, 1] --> [0, 1] [0, 1, -1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] + # dims_mapping: [0, 1] --> [0, 1] [0, -1, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] + # dims_mapping: [0, 1] --> [0, 1] [-1, -1, -1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 0, 1] + ) + + # shape: [8, 16] --> [1, 8, 16] + # dims_mapping: [1, 0] --> [1, 0] [-1, 1, 0] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [8, 16] --> [8, 16, 1] + # dims_mapping: [1, 0] --> [1, 0] [1, 0, -1] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] + # dims_mapping: [1, 0] --> [1, 0] [1, -1, -1, 0] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, 0] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] + # dims_mapping: [1, 0] --> [1, 0] [-1, -1, -1, 1, 0] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1, 0] + ) + + # shape: [1, 8, 16] --> [1, 1, 8, 16] + # dims_mapping: [0, 1, -1] --> [-1, 1, -1] [-1, -1, 1, -1] + self.x_dist_tensor_spec.shape = [1, 8, 16] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] + ) + + def test_unsqueeze_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 16], output_tensor_dist_attr + ) + + # shape: [8, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 0, 1] --> [0, 1], [-1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16] --> [8, 16, 1] (input --> output) + # dims_mapping: [0, 1, -1] --> [0, 1], [0, 1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16, 1] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] (input --> output) + # dims_mapping: [0, -1, -1, 1] --> [0, 1], [0, -1, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] (input --> output) + # dims_mapping: [-1, -1, -1, 0, 1] --> [0, 1], [-1, -1, -1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 1, 1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, -1, 0, 1]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 0, 1] + ) + + # shape: [8, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 1, 0] --> [1, 0], [-1, 1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [8, 16] --> [8, 16, 1] (input --> output) + # dims_mapping: [1, 0, -1] --> [1, 0], [1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16, 1] + self.output_dist_tensor_spec.set_dims_mapping([1, 0, -1]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] (input --> output) + # dims_mapping: [1, -1, -1, 0] --> [1, 0], [1, -1, -1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, 0] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] (input --> output) + # dims_mapping: [-1, -1, -1, 1, 0] --> [1, 0], [-1, -1, -1, 1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 1, 1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, -1, 1, 0]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1, 0] + ) + + # shape: [1, 8, 16] --> [1, 1, 8, 16] (input --> output) + # dims_mapping: [-1, 0, 1, -1] --> [-1, 1, -1], [-1, -1, 1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.shape = [1, 1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_api_dist_branch.py b/test/auto_parallel/test_api_dist_branch.py index dbeec8ec36222..323010166db24 100644 --- a/test/auto_parallel/test_api_dist_branch.py +++ b/test/auto_parallel/test_api_dist_branch.py @@ -113,6 +113,7 @@ def test_concat_for_dist_tensor(self): self.check_tensor_eq(local_in2.grad, dist_in2.grad) self.check_tensor_eq(local_in3.grad, dist_in3.grad) + # TODO(GhostScreaming): Support paddle.concat backward later. # input: std::vector # output: std::vector def test_broadcast_tensors_for_dist_tensor(self): @@ -136,24 +137,6 @@ def test_broadcast_tensors_for_dist_tensor(self): self.check_tensor_eq(local_in1.grad, dist_in1.grad) self.check_tensor_eq(local_in2.grad, dist_in2.grad) - # input: paddle::optional - # output: phi::Tensor - def test_expand_as_for_dist_tensor(self): - x1 = np.random.random(size=[2, 8]).astype("float32") - x2 = np.random.random(size=[2, 2, 8]).astype("float32") - local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1) - local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2) - local_out = paddle.expand_as(local_in1, local_in2) - dist_out = paddle.expand_as(dist_in1, dist_in2) - self.check_tensor_eq(local_out, dist_out) - - # TODO(chenweihang): expand_as is a special case, the forward contains - # optional input, but backward not, open this case after dist support - # optional input - # local_out.backward() - # dist_out.backward() - # self.check_tensor_eq(local_in1.grad, dist_in1.grad) - # input: paddle::optional # output: phi::Tensor def test_bincount_api_for_dist_tensor(self): @@ -255,86 +238,6 @@ def test_check_finite_and_unscale_for_dist_tensor(self): self.check_tensor_eq(local_x, dist_x) self.check_tensor_eq(local_found_inf, dist_found_inf) - # input: phi::Tensor - # output: inplace paddle::optional - def test_adamax_for_dist_tensor(self): - dtype = np.float32 - mp_dtype = np.float32 - shape = [123, 321] - - beta1 = 0.78 - beta2 = 0.899 - epsilon = 1e-5 - param = np.random.random(shape).astype(dtype) - grad = np.random.random(shape).astype(dtype) - moment = np.random.random(shape).astype(dtype) - inf_norm = np.random.random(shape).astype(dtype) - master_param = param.astype(mp_dtype) - - lr = np.array([0.002]).astype("float32") - beta1_pow = np.array([beta1**10]).astype("float32") - - local_param, dist_param = self.create_local_and_dist_tensor_pair(param) - local_grad, dist_grad = self.create_local_and_dist_tensor_pair(grad) - local_lr, dist_lr = self.create_local_and_dist_tensor_pair(lr) - ( - local_beta1_pow, - dist_beta1_pow, - ) = self.create_local_and_dist_tensor_pair(beta1_pow) - local_moment, dist_moment = self.create_local_and_dist_tensor_pair( - moment - ) - local_inf_norm, dist_inf_norm = self.create_local_and_dist_tensor_pair( - inf_norm - ) - ( - local_master_param, - dist_master_param, - ) = self.create_local_and_dist_tensor_pair(master_param) - - ( - local_param_out, - local_moment_out, - local_inf_norm_out, - local_master_param_out, - ) = paddle._C_ops.adamax_( - local_param, - local_grad, - local_lr, - local_moment, - local_inf_norm, - local_beta1_pow, - local_master_param, - beta1, - beta2, - epsilon, - True, - ) - - ( - dist_param_out, - dist_moment_out, - dist_inf_norm_out, - dist_master_param_out, - ) = paddle._C_ops.adamax_( - dist_param, - dist_grad, - dist_lr, - dist_moment, - dist_inf_norm, - dist_beta1_pow, - dist_master_param, - beta1, - beta2, - epsilon, - True, - ) - - self.check_tensor_eq(local_param_out, dist_param_out) - self.check_tensor_eq(local_moment_out, dist_moment_out) - self.check_tensor_eq(local_inf_norm_out, dist_inf_norm_out) - self.check_tensor_eq(local_master_param_out, dist_master_param_out) - # multi kernel functions def test_adagrad_for_dist_tensor(self): dtype = np.float16 diff --git a/test/auto_parallel/test_gpt_with_newir.py b/test/auto_parallel/test_gpt_with_pir.py similarity index 91% rename from test/auto_parallel/test_gpt_with_newir.py rename to test/auto_parallel/test_gpt_with_pir.py index 2f736d8a3b297..fd59c7471eb39 100644 --- a/test/auto_parallel/test_gpt_with_newir.py +++ b/test/auto_parallel/test_gpt_with_pir.py @@ -19,10 +19,10 @@ import unittest -class TestGPTNewIR(unittest.TestCase): - def test_gpt_newir(self): +class TestGPTPir(unittest.TestCase): + def test_gpt_pir(self): file_dir = os.path.dirname(os.path.abspath(__file__)) - launch_model_path = os.path.join(file_dir, "gpt_with_newir.py") + launch_model_path = os.path.join(file_dir, "gpt_with_pir.py") if os.environ.get("WITH_COVERAGE", "OFF") == "ON": coverage_args = ["-m", "coverage", "run", "--branch", "-p"] diff --git a/test/auto_parallel/test_mp_allreduce_matmul_grad_overlapping.py b/test/auto_parallel/test_mp_allreduce_matmul_grad_overlapping.py new file mode 100644 index 0000000000000..168836b263f5c --- /dev/null +++ b/test/auto_parallel/test_mp_allreduce_matmul_grad_overlapping.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + + +class TestMPAllreduceMatmulGradOverlapping(unittest.TestCase): + def test_mp_allreduce_matmul_grad_overlapping(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join( + file_dir, "mp_allreduce_matmul_grad_overlapping_unittest.py" + ) + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_reshard_r_to_s.py b/test/auto_parallel/test_reshard_r_to_s.py index 68699885094de..b951508f8c1c9 100644 --- a/test/auto_parallel/test_reshard_r_to_s.py +++ b/test/auto_parallel/test_reshard_r_to_s.py @@ -22,7 +22,7 @@ def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { "dtype": "float32", - "seeds": str(self._seeds), + "seeds": "2023", } self._changeable_envs = { "shape": ["(10, 20)", "(5, 7)"], @@ -40,6 +40,17 @@ def test_reshard_r_to_s(self): user_defined_envs=envs, ) + def test_reshard_r_to_s_cross_mesh(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + if envs["backend"] != "cpu": + self.run_test_case( + "reshard_r_to_s_cross_mesh.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/test_reshard_s_to_r.py b/test/auto_parallel/test_reshard_s_to_r.py index fd67df648a9b0..ec61fbb2a3358 100644 --- a/test/auto_parallel/test_reshard_s_to_r.py +++ b/test/auto_parallel/test_reshard_s_to_r.py @@ -40,6 +40,18 @@ def test_reshard_s_to_r(self): user_defined_envs=envs, ) + def test_reshard_s_to_r_cross_mesh(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + + for envs in envs_list: + if envs["backend"] != "cpu": + self.run_test_case( + "reshard_s_to_r_cross_mesh.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 8040b97d43ac9..2589566cb670e 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -46,6 +46,36 @@ def test_elementwise_api(self): user_defined_envs=envs, ) + def test_concat_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_concat.py", + user_defined_envs=envs, + ) + + def test_reduction_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_reduction.py", + user_defined_envs=envs, + ) + + def test_bitwise_api(self): + envs_list = test_base.gen_product_envs_list( + {"dtype": "int32", "seed": "2023"}, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_bitwise.py", + user_defined_envs=envs, + ) + def test_several_replicated_spmd_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -56,6 +86,26 @@ def test_several_replicated_spmd_api(self): user_defined_envs=envs, ) + def test_add_n_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_add_n.py", + user_defined_envs=envs, + ) + + def test_custom_relu_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_custom_relu.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_dygraph_inplace.py b/test/auto_parallel/test_semi_auto_parallel_dygraph_inplace.py new file mode 100644 index 0000000000000..e649b678ea882 --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_dygraph_inplace.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelInplace(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp( + num_of_devices=2, + timeout=120, + ) + self._default_envs = { + "dtype": "float32", + "seed": "2023", + } + self._changeable_envs = {"backend": ["cpu", "gpu"]} + + def test_simple_net_single_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_dygraph_inplace.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_functional_in_single_card.py b/test/auto_parallel/test_semi_auto_parallel_functional_in_single_card.py new file mode 100644 index 0000000000000..9f20eb77f41d8 --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_functional_in_single_card.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestSemiAutoParallelFunctionalInSingleCard(unittest.TestCase): + def test_tensor_use_gpudnn(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dist_tensor._use_gpudnn(False) + + def test_tensor_data_ptr(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + prt = dist_tensor.data_ptr() + + def test_tensor_offset(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + offset = dist_tensor._offset() + + def test_tensor_copy_to(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dist_tensor._copy_to(paddle.CUDAPlace(0), True) + + def test_tensor__share_buffer_to(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dense_tensor2 = paddle.randn([10, 10]) + to = dist.shard_tensor( + dense_tensor2, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dist_tensor._share_buffer_to(to) + + def test_tensor__is_shared_buffer_with(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dense_tensor2 = paddle.randn([10, 10]) + to = dist.shard_tensor( + dense_tensor2, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dist_tensor._share_buffer_to(to) + self.assertTrue(dist_tensor._is_shared_buffer_with(to)) + + def test_tensor_strides(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + dense_tensor = paddle.randn([10, 20]) + dense_tensor = dense_tensor.reshape([20, 10]) + dist_tensor = dist.shard_tensor( + dense_tensor, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + strides = dist_tensor.get_strides() + is_contiguous = dist_tensor.is_contiguous() + dist_tensor = dist_tensor.contiguous() + + def test_tensor_uva(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + place = paddle.CPUPlace() + np_value = np.random.random(size=[10, 30]).astype('float32') + dense_tensor = paddle.to_tensor(np_value, place=place) + dist_tensor = dist.shard_tensor( + dense_tensor, + place=place, + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]), + ) + dist_tensor._uva() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_pylayer.py b/test/auto_parallel/test_semi_auto_parallel_pylayer.py new file mode 100644 index 0000000000000..8b0e6b67701cf --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_pylayer.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelPyLayer(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp( + num_of_devices=2, + timeout=120, + ) + self._default_envs = { + "dtype": "float32", + "seed": "2023", + } + self._changeable_envs = {"backend": ["cpu", "gpu"]} + + def test_pylayer(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_pylayer.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 89ef4ac6a1a10..8d704c01a5d83 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -19,7 +19,10 @@ class TestSemiAutoParallelSingleStrategy(test_base.CommunicationTestDistBase): def setUp(self): - super().setUp(num_of_devices=2, timeout=120) + super().setUp( + num_of_devices=2, + timeout=120, + ) self._default_envs = { "dtype": "float32", "seed": "2023", @@ -36,6 +39,100 @@ def test_simple_net_single_strategy(self): user_defined_envs=envs, ) + def test_simple_net_single_strategy_with_amp(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_amp.py", + user_defined_envs=envs, + ) + + def test_simple_net_single_strategy_with_gradient_merge(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_gradient_merge.py", + user_defined_envs=envs, + ) + + def test_simple_net_recompute(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_recompute.py", + user_defined_envs=envs, + ) + + def test_simple_net_single_strategy_with_gradient_hook(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_gradient_hook.py", + user_defined_envs=envs, + ) + + def test_simple_net_clear_gradient(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_clear_gradient.py", + user_defined_envs=envs, + ) + + def test_simple_net_several_grad_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_grad_api.py", + user_defined_envs=envs, + ) + + def test_simple_net_empty_grad(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_fill_zero_for_emtpy_grad.py", + user_defined_envs=envs, + ) + + def test_simple_net_zero_grads(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_zero_grads.py", + user_defined_envs=envs, + ) + + def test_simple_net_custom_relu(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_custom_relu.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py index 5e59a7c9480e4..fa1a19596d71b 100644 --- a/test/auto_parallel/test_shard_tensor_api.py +++ b/test/auto_parallel/test_shard_tensor_api.py @@ -133,7 +133,7 @@ def test_static_mode(self): class TestShardTensorStaticDy2Static(unittest.TestCase): def test_dy2static(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def func(): mesh = dist.ProcessMesh( [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"] diff --git a/test/cinn/ops/test_matmul_op.py b/test/cinn/ops/test_matmul_op.py index 5a1fe42e36ef5..8abb2467e1d66 100755 --- a/test/cinn/ops/test_matmul_op.py +++ b/test/cinn/ops/test_matmul_op.py @@ -136,8 +136,6 @@ def init_attrs(self): # }, { "dtype": "float16", - "max_relative_error": 1e-2, - "max_absolute_error": 1e-2, }, { "dtype": "float32", diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index 92e6ce22b1f1b..5a0e2c0d859ec 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -297,6 +297,20 @@ if(WITH_NCCL) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT "300") endif() endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_dygraph_dataparallel_bf16 + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "200" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=22024;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_dygraph_dataparallel_bf16 PROPERTIES TIMEOUT "200") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_dygraph_sharding_stage2 @@ -326,6 +340,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) set_tests_properties(test_dygraph_sharding_stage2_bf16 PROPERTIES TIMEOUT "200") endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_dygraph_sharding_stage1_bf16 + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "200" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=22024;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_dygraph_sharding_stage1_bf16 PROPERTIES TIMEOUT + "200") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_dygraph_sharding_stage1_fp16 @@ -695,11 +724,6 @@ if((WITH_GPU OR WITH_XPU) AND (LINUX OR WIN32)) test_fleet_recompute_meta_optimizer ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") endif() -if(LOCAL_ALL_ARCH AND (LINUX OR WIN32)) - py_test_modules( - test_fleet_private_function MODULES test_fleet_private_function ENVS - "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") -endif() if((WITH_GPU OR WITH_XPU) AND LOCAL_ALL_PLAT) bash_test_modules( test_new_group diff --git a/test/collective/fleet/dygraph_dataparallel_bf16.py b/test/collective/fleet/dygraph_dataparallel_bf16.py new file mode 100644 index 0000000000000..efc7b6f993d98 --- /dev/null +++ b/test/collective/fleet/dygraph_dataparallel_bf16.py @@ -0,0 +1,198 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) +from paddle.nn import Linear, ReLU + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + + def __getitem__(self, idx): + img = np.random.rand(self.linear_size).astype('float32') + return img + + def __len__(self): + return self.num_samples + + +def optimizer_setting(model, use_pure_bf16, use_main_grad): + if use_main_grad: + assert use_pure_bf16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_bf16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def train_mlp( + model, use_pure_bf16=False, use_main_grad=False, accumulate_grad=False +): + optimizer = optimizer_setting( + model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad + ) + if use_pure_bf16: + level = 'O2' + custom_white_list = None + model = paddle.amp.decorate( + models=model, + dtype="bfloat16", + level=level, + ) + else: + level = 'O1' + custom_white_list = [ + "matmul_v2", + "elementwise_add", + "relu", + "reduce_mean", + ] + model = paddle.DataParallel(model) + + paddle.seed(2023) + np.random.seed(2023) + train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) + if not use_pure_bf16: + for param in model.parameters(): + t = paddle.cast( + paddle.cast(param, dtype='bfloat16'), dtype='float32' + ) + param.set_value(t) + + losses = [] + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + data.stop_gradient = True + + with model.no_sync(): + with paddle.amp.auto_cast( + True, + level=level, + dtype="bfloat16", + custom_white_list=custom_white_list, + ): + out = model(data) + loss = paddle.mean(out) + + losses.append(loss) + + loss.backward() + + if not accumulate_grad: + fused_allreduce_gradients(list(model.parameters()), None) + + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + fused_allreduce_gradients(list(model.parameters()), None) + + optimizer.step() + optimizer.clear_grad() + + return losses + + +def test_dp_bf16(): + if not paddle.amp.is_bfloat16_supported(): + return + paddle.distributed.init_parallel_env() + mlp = MLP() + state_dict = mlp.state_dict() + + # dp bf16 O1 vs dp bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + losses_o1 = train_mlp(mlp1, use_pure_bf16=False) + losses_o2 = train_mlp(mlp2, use_pure_bf16=True, use_main_grad=True) + for i in range(len(losses_o2)): + loss_o2 = paddle.cast(losses_o2[i], dtype='float32').detach() + loss_o1 = paddle.cast(losses_o1[i], dtype='float32').detach() + np.testing.assert_array_equal(loss_o2, loss_o1) + + # grad accumulation test + mlp3 = MLP() + mlp4 = MLP() + mlp3.set_state_dict(state_dict) + mlp4.set_state_dict(state_dict) + losses_acc_grad_o1 = train_mlp( + mlp3, use_pure_bf16=False, accumulate_grad=True + ) + losses_acc_grad_o2 = train_mlp( + mlp4, use_pure_bf16=True, use_main_grad=True, accumulate_grad=True + ) + for i in range(len(losses_acc_grad_o2)): + loss_acc_grad_o2 = paddle.cast( + losses_acc_grad_o2[i], dtype='float32' + ).detach() + loss_acc_grad_o1 = paddle.cast( + losses_acc_grad_o1[i], dtype='float32' + ).detach() + np.testing.assert_array_equal(loss_acc_grad_o2, loss_acc_grad_o1) + + +if __name__ == '__main__': + test_dp_bf16() diff --git a/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py b/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py new file mode 100644 index 0000000000000..9a69976b830cc --- /dev/null +++ b/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py @@ -0,0 +1,279 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.nn import Linear, ReLU + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + + def __getitem__(self, idx): + img = np.random.rand(self.linear_size).astype('float32') + return img + + def __len__(self): + return self.num_samples + + +def optimizer_setting(model, use_pure_bf16, use_main_grad): + if use_main_grad: + assert use_pure_bf16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_bf16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def train_mlp( + model, + sharding_stage, + use_pure_bf16=False, + accumulate_grad=False, + use_main_grad=False, + test_scaler=False, +): + # bf16 not support dynamic loss scaling + # disable dynamic_loss_scaling to coverage distributed_scaler + dynamic_loss_scaling = False + scaler = None + scale_loss = 1024 + if test_scaler: + assert sharding_stage == 1 + assert not accumulate_grad + scaler = paddle.amp.GradScaler( + init_loss_scaling=scale_loss, + use_dynamic_loss_scaling=dynamic_loss_scaling, + ) + scaler = fleet.distributed_scaler(scaler) + optimizer = optimizer_setting( + model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad + ) + + strategy = fleet.DistributedStrategy() + if use_pure_bf16: + level = 'O2' + custom_white_list = None + + amp_configs = { + "init_loss_scaling": scale_loss, + "use_pure_bf16": True, + "use_dynamic_loss_scaling": dynamic_loss_scaling, + } + strategy.amp = True + strategy.amp_configs = amp_configs + else: + level = 'O1' + custom_white_list = [ + "matmul_v2", + "elementwise_add", + "relu", + "reduce_mean", + ] + + if sharding_stage == 1: + hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 2, + } + strategy.hybrid_configs = hybrid_configs + + fleet.init(is_collective=True, strategy=strategy) + model = fleet.distributed_model(model) + + if sharding_stage == 1: + optimizer = fleet.distributed_optimizer(optimizer) + + paddle.seed(2023) + np.random.seed(2023) + train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) + + if sharding_stage == 1: + model.to(device="gpu") + + if not use_pure_bf16: + for param in model.parameters(): + t = paddle.cast( + paddle.cast(param, dtype='bfloat16'), dtype='float32' + ) + param.set_value(t) + + losses = [] + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + data.stop_gradient = True + + with paddle.amp.auto_cast( + True, + level=level, + dtype="bfloat16", + custom_white_list=custom_white_list, + ): + out = model(data) + loss = paddle.mean(out) + + losses.append(loss) + + if test_scaler: + assert scaler is not None + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + else: + loss.backward() + if not accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + return losses + + +def test_stage1_bf16(): + if not paddle.amp.is_bfloat16_supported(): + return + paddle.distributed.init_parallel_env() + + mlp = MLP() + state_dict = mlp.state_dict() + + # stage1 bf16 O1 vs stage1 bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + o1_losses = train_mlp( + mlp1, + sharding_stage=1, + use_pure_bf16=False, + ) + o2_losses = train_mlp( + mlp2, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=True, + ) + for i in range(len(o1_losses)): + o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() + o2_32_loss = paddle.cast(o2_losses[i], dtype='float32').detach() + np.testing.assert_array_equal(o1_32_loss, o2_32_loss) + + # stage1 scaler test with main_grad + mlp3 = MLP() + mlp3.set_state_dict(state_dict) + train_mlp( + mlp3, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=True, + test_scaler=True, + ) + + # stage1 scaler test without main_grad + mlp4 = MLP() + mlp4.set_state_dict(state_dict) + train_mlp( + mlp4, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=False, + test_scaler=True, + ) + + # grad accumulation test + mlp5 = MLP() + mlp6 = MLP() + mlp5.set_state_dict(state_dict) + mlp6.set_state_dict(state_dict) + o1_losses_grad_acc = train_mlp( + mlp5, + sharding_stage=1, + use_pure_bf16=False, + accumulate_grad=True, + ) + o2_losses_grad_acc = train_mlp( + mlp6, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=True, + accumulate_grad=True, + ) + for i in range(len(o2_losses_grad_acc)): + o2_loss_grad_acc = paddle.cast( + o2_losses_grad_acc[i], dtype='float32' + ).detach() + o1_loss_grad_acc = paddle.cast( + o1_losses_grad_acc[i], dtype='float32' + ).detach() + np.testing.assert_array_equal(o2_loss_grad_acc, o1_loss_grad_acc) + + return + + +if __name__ == '__main__': + test_stage1_bf16() diff --git a/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py b/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py index 601659e0fb98b..93e163b9facca 100644 --- a/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py +++ b/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py @@ -83,9 +83,9 @@ def train_mlp( accumulate_grad=False, use_main_grad=False, test_scaler=False, - scale_loss=1024, ): scaler = None + scale_loss = 1024 if test_scaler: assert sharding_stage == 1 assert not accumulate_grad @@ -94,10 +94,15 @@ def train_mlp( optimizer = optimizer_setting( model=model, use_pure_fp16=use_pure_fp16, use_main_grad=use_main_grad ) + + strategy = fleet.DistributedStrategy() if use_pure_fp16: level = 'O2' custom_white_list = None - model = paddle.amp.decorate(models=model, dtype="float16", level=level) + + amp_configs = {"init_loss_scaling": scale_loss, "use_pure_fp16": True} + strategy.amp_configs = amp_configs + strategy.amp = True else: level = 'O1' custom_white_list = [ @@ -108,11 +113,19 @@ def train_mlp( ] if sharding_stage == 1: - optimizer = fleet.distributed_optimizer(optimizer) + hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 2, + } + strategy.hybrid_configs = hybrid_configs - model = fleet.distributed_model(model) - else: - model = paddle.DataParallel(model) + fleet.init(is_collective=True, strategy=strategy) + model = fleet.distributed_model(model) + + if sharding_stage == 1: + optimizer = fleet.distributed_optimizer(optimizer) paddle.seed(2023) np.random.seed(2023) @@ -176,19 +189,6 @@ def test_stage1_fp16(): return paddle.distributed.init_parallel_env() - strategy = fleet.DistributedStrategy() - hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": 1, - "sharding_degree": 2, - } - scale_loss = 1024 - amp_configs = {"init_loss_scaling": scale_loss, "use_pure_fp16": True} - strategy.hybrid_configs = hybrid_configs - strategy.amp_configs = amp_configs - - fleet.init(is_collective=True, strategy=strategy) mlp = MLP() state_dict = mlp.state_dict() @@ -201,14 +201,12 @@ def test_stage1_fp16(): mlp1, sharding_stage=1, use_pure_fp16=False, - scale_loss=scale_loss, ) o2_losses = train_mlp( mlp2, sharding_stage=1, use_pure_fp16=True, use_main_grad=True, - scale_loss=scale_loss, ) for i in range(len(o1_losses)): o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() @@ -224,7 +222,6 @@ def test_stage1_fp16(): use_pure_fp16=True, use_main_grad=True, test_scaler=True, - scale_loss=scale_loss, ) # grad accumulation test @@ -237,7 +234,6 @@ def test_stage1_fp16(): sharding_stage=1, use_pure_fp16=False, accumulate_grad=True, - scale_loss=scale_loss, ) o2_losses_grad_acc = train_mlp( mlp6, @@ -245,7 +241,6 @@ def test_stage1_fp16(): use_pure_fp16=True, use_main_grad=True, accumulate_grad=True, - scale_loss=scale_loss, ) for i in range(len(o2_losses_grad_acc)): o2_loss_grad_acc = paddle.cast( diff --git a/test/collective/fleet/hybrid_parallel_sharding_model.py b/test/collective/fleet/hybrid_parallel_sharding_model.py index 41343d2dbda9e..3fe139f84a402 100644 --- a/test/collective/fleet/hybrid_parallel_sharding_model.py +++ b/test/collective/fleet/hybrid_parallel_sharding_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random import unittest @@ -22,12 +23,15 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( DygraphShardingOptimizer, + DygraphShardingOptimizerV2, ) from paddle.distributed.fleet.utils.mix_precision_utils import ( MixPrecisionLayer, MixPrecisionOptimizer, ) +g_shard_split_param = int(os.environ.get("FLAGS_shard_split_param", 0)) + vocab_size = 20 hidden_size = 10 inner_size = 8 @@ -193,6 +197,10 @@ def setUp(self): "mp_degree": 1, "pp_degree": 1, } + self.strategy.hybrid_configs[ + "sharding_configs" + ].split_param = g_shard_split_param + fleet.init(is_collective=True, strategy=self.strategy) self.data = [ np.random.randint( @@ -274,13 +282,19 @@ def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None): model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( Optimizer=Optimizer, amp_level=amp_level ) - - self.assertTrue( - isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer) + shard_opt_cls = ( + DygraphShardingOptimizerV2 + if g_shard_split_param + else DygraphShardingOptimizer ) + self.assertTrue(isinstance(optimizer_a._inner_opt, shard_opt_cls)) for idx in range(STEPS): - if idx == 2 and paddle.distributed.get_rank() == 0: + if ( + idx == 2 + and paddle.distributed.get_rank() == 0 + and not g_shard_split_param + ): self.assertTrue( set(optimizer_a._inner_opt._inner_opt.state_dict().keys()) == sharded_accumulators @@ -303,38 +317,40 @@ def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None): ) def test_sharding_adam(self): - sharded_accumulators = { - 'linear_0.w_0_moment1_0', - 'linear_1.b_0_moment1_0', - 'linear_2.b_0_moment1_0', - 'embedding_0.w_0_moment1_0', - 'linear_0.w_0_moment2_0', - 'linear_1.b_0_moment2_0', - 'linear_2.b_0_moment2_0', - 'embedding_0.w_0_moment2_0', - 'linear_0.w_0_beta1_pow_acc_0', - 'linear_1.b_0_beta1_pow_acc_0', - 'linear_2.b_0_beta1_pow_acc_0', - 'embedding_0.w_0_beta1_pow_acc_0', - 'linear_0.w_0_beta2_pow_acc_0', - 'linear_1.b_0_beta2_pow_acc_0', - 'linear_2.b_0_beta2_pow_acc_0', - 'embedding_0.w_0_beta2_pow_acc_0', - } - self.sharding_model( - Optimizer="adam", sharded_accumulators=sharded_accumulators - ) + if not g_shard_split_param: + sharded_accumulators = { + 'linear_0.w_0_moment1_0', + 'linear_1.b_0_moment1_0', + 'linear_2.b_0_moment1_0', + 'embedding_0.w_0_moment1_0', + 'linear_0.w_0_moment2_0', + 'linear_1.b_0_moment2_0', + 'linear_2.b_0_moment2_0', + 'embedding_0.w_0_moment2_0', + 'linear_0.w_0_beta1_pow_acc_0', + 'linear_1.b_0_beta1_pow_acc_0', + 'linear_2.b_0_beta1_pow_acc_0', + 'embedding_0.w_0_beta1_pow_acc_0', + 'linear_0.w_0_beta2_pow_acc_0', + 'linear_1.b_0_beta2_pow_acc_0', + 'linear_2.b_0_beta2_pow_acc_0', + 'embedding_0.w_0_beta2_pow_acc_0', + } + self.sharding_model( + Optimizer="adam", sharded_accumulators=sharded_accumulators + ) def test_sharding_momentum(self): - sharded_accumulators = { - 'linear_6.w_0_velocity_0', - 'linear_7.b_0_velocity_0', - 'linear_8.b_0_velocity_0', - 'embedding_2.w_0_velocity_0', - } - self.sharding_model( - Optimizer="Momentum", sharded_accumulators=sharded_accumulators - ) + if not g_shard_split_param: + sharded_accumulators = { + 'linear_6.w_0_velocity_0', + 'linear_7.b_0_velocity_0', + 'linear_8.b_0_velocity_0', + 'embedding_2.w_0_velocity_0', + } + self.sharding_model( + Optimizer="Momentum", sharded_accumulators=sharded_accumulators + ) def test_sharding_momentum_amp(self): sharded_accumulators = { diff --git a/test/dygraph_to_static/test_ifelse_basic.py b/test/collective/fleet/test_dygraph_dataparallel_bf16.py similarity index 60% rename from test/dygraph_to_static/test_ifelse_basic.py rename to test/collective/fleet/test_dygraph_dataparallel_bf16.py index 97043fd7ba688..1401399e8fc4c 100644 --- a/test/dygraph_to_static/test_ifelse_basic.py +++ b/test/collective/fleet/test_dygraph_dataparallel_bf16.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphDataParallel(TestMultipleGpus): + def test_dygraph_dataparallel_bf16(self): + self.run_mnist_2gpu('dygraph_dataparallel_bf16.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_dygraph_sharding_stage1_bf16.py b/test/collective/fleet/test_dygraph_sharding_stage1_bf16.py new file mode 100644 index 0000000000000..bd15963edd263 --- /dev/null +++ b/test/collective/fleet/test_dygraph_sharding_stage1_bf16.py @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphShardingStage1(TestMultipleGpus): + # check sharding logic as well as the accuracy with single mode + def test_dygraph_sharding_stage1_bf16(self): + self.run_mnist_2gpu('dygraph_group_sharded_stage1_bf16.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_fleet_private_function.py b/test/collective/fleet/test_fleet_private_function.py deleted file mode 100644 index c6a3a197c09ac..0000000000000 --- a/test/collective/fleet/test_fleet_private_function.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import socket -import threading -import unittest - - -class TestFleetPrivateFunction(unittest.TestCase): - def test_wait_port(self): - def init_server(port): - import time - - time.sleep(5) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(("127.0.0.1", port)) - sock.listen(10) - while True: - c, addr = sock.accept() - c.send("0") - c.close() - break - - thr = threading.Thread(target=init_server, args=(9292,)) - thr.start() - - from paddle.distributed import fleet - - ep = ["127.0.0.1:9292"] - fleet.base.private_helper_function.wait_server_ready(ep) - - thr.join() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py index f6152a13e10c1..ac9d32036ae27 100644 --- a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py +++ b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py @@ -21,22 +21,32 @@ class TestHybridParallel(TestMultipleGpus): # check sharding logic as well as the accuracy with single mode def test_hybrid_parallel_sharding_logic(self): + # test shard v2 + os.environ["FLAGS_shard_use_reduce"] = "1" + os.environ["FLAGS_shard_norm_align_dp"] = "0" + os.environ["FLAGS_shard_split_param"] = "1" + self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') # test shard grad reduce os.environ["FLAGS_shard_use_reduce"] = "1" os.environ["FLAGS_shard_norm_align_dp"] = "0" + os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') # test shard grad allreduce os.environ["FLAGS_shard_use_reduce"] = "0" os.environ["FLAGS_shard_norm_align_dp"] = "1" + os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') def test_hybrid_parallel_sharding_tensor_fusion(self): + os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py') def test_hybrid_parallel_sharding_tensor_fusion_amp(self): + os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion_amp.py') def test_hybrid_parallel_sharding_state_dict(self): + os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py') def test_group_param_tensor_fusion(self): diff --git a/test/collective/fleet/test_recv_save_op.py b/test/collective/fleet/test_recv_save_op.py index b032cae1f5dcf..a738d5be75724 100644 --- a/test/collective/fleet/test_recv_save_op.py +++ b/test/collective/fleet/test_recv_save_op.py @@ -87,7 +87,7 @@ def _wait_ps_ready(self, pid): # on the /tmp directory until it was ready to process all the RPC call. os.stat("/tmp/paddle.%d.port" % pid) return - except os.error: + except OSError: start_left_time -= sleep_time def _get_pserver_port(self, pid): diff --git a/test/collective/fleet/testslist.csv b/test/collective/fleet/testslist.csv index 08b45e1454209..8b7a3b7a2f4c4 100644 --- a/test/collective/fleet/testslist.csv +++ b/test/collective/fleet/testslist.csv @@ -23,8 +23,10 @@ test_pipeline,,,160,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_pro test_fleet_utils,LINUX;APPLE,,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_static_model_parallel,,,240,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_no_sync,,GPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL +test_dygraph_dataparallel_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage2,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage2_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_sharding_stage1_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage1_fp16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_control_flow,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., @@ -60,7 +62,6 @@ test_parallel_dygraph_sparse_embedding_over_height,,ROCM,350,DIST,../../legacy_t test_distributed_strategy,LINUX;APPLE,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_auto_parallel_parallelizer,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_new_group,,GPU;XPU,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=, test_c_comm_init_op,LINUX,GPU;XPU,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=, test_fused_attention_pass_with_mp,LINUX,GPU,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=, diff --git a/test/collective/test_collective_alltoall_api.py b/test/collective/test_collective_alltoall_api.py index 01864126a96e9..21d01075aa729 100644 --- a/test/collective/test_collective_alltoall_api.py +++ b/test/collective/test_collective_alltoall_api.py @@ -57,7 +57,7 @@ def test_alltoall_nccl_with_new_comm(self): "alltoall", "nccl", dtype=dtype, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_alltoall_nccl_dygraph(self): diff --git a/test/collective/test_collective_barrier_api.py b/test/collective/test_collective_barrier_api.py index 74e5cebc873c1..75b0e80905365 100644 --- a/test/collective/test_collective_barrier_api.py +++ b/test/collective/test_collective_barrier_api.py @@ -33,7 +33,7 @@ def test_barrier_nccl_with_new_comm(self): "collective_barrier_api.py", "barrier", "nccl", - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_barrier_gloo(self): diff --git a/test/collective/test_collective_global_gather.py b/test/collective/test_collective_global_gather.py index c4c2e42c0b561..c5110b6519801 100644 --- a/test/collective/test_collective_global_gather.py +++ b/test/collective/test_collective_global_gather.py @@ -44,7 +44,7 @@ def test_global_gather_nccl_new_comm(self): "collective_global_gather.py", "global_gather", "nccl", - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/collective/test_collective_global_scatter.py b/test/collective/test_collective_global_scatter.py index 7eb34abe6cf5a..26a267a98d349 100644 --- a/test/collective/test_collective_global_scatter.py +++ b/test/collective/test_collective_global_scatter.py @@ -43,7 +43,7 @@ def test_global_scatter_nccl_new_comm(self): "collective_global_scatter.py", "global_scatter", "nccl", - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/collective/test_collective_reduce_api.py b/test/collective/test_collective_reduce_api.py index 9759b50028835..aafda45aea976 100644 --- a/test/collective/test_collective_reduce_api.py +++ b/test/collective/test_collective_reduce_api.py @@ -78,7 +78,7 @@ def test_reduce_nccl_with_new_comm(self): "nccl", dtype=dtype, reduce_type=red_type, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_reduce_bkcl(self): diff --git a/test/collective/test_collective_reduce_scatter_api.py b/test/collective/test_collective_reduce_scatter_api.py index 4ec909e8d2b44..bd3dd14df88df 100644 --- a/test/collective/test_collective_reduce_scatter_api.py +++ b/test/collective/test_collective_reduce_scatter_api.py @@ -59,7 +59,7 @@ def test_reduce_scatter_nccl_with_new_comm(self): "reduce_scatter", "nccl", dtype=dtype, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_reduce_scatter_nccl_dygraph(self): diff --git a/test/collective/test_collective_scatter_api.py b/test/collective/test_collective_scatter_api.py index b21e06c6c75d0..7ac51e99b5593 100644 --- a/test/collective/test_collective_scatter_api.py +++ b/test/collective/test_collective_scatter_api.py @@ -47,7 +47,7 @@ def test_scatter_nccl_with_new_comm(self): "scatter", "nccl", dtype=dtype, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_scatter_nccl_dygraph(self): diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index 42476d7bb323f..eb6d08542b04a 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include "glog/logging.h" #include "gtest/gtest.h" @@ -23,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" #include "paddle/phi/infermeta/spmd_rules/rules.h" @@ -30,6 +32,68 @@ namespace paddle { namespace distributed { namespace auto_parallel { +auto& get_dims_mapping(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.dims_mapping(); +} + +bool is_partial(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.is_partial(); +} + +auto get_partial_dims(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.partial_dims(); +} + +void check_dim_mapping(const phi::distributed::ArgDistAttr& dist_attr, + const std::vector& dim_mapping, + const std::string& line = "") { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)) + << line; + EXPECT_EQ(get_dims_mapping(dist_attr), dim_mapping) << line; +} + +void check_partial_dims(const phi::distributed::ArgDistAttr& dist_attr, + const std::set& dims, + const std::string& line = "") { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)) + << line; + EXPECT_EQ(get_partial_dims(dist_attr), dims) << line; +} + +void clean_partial_status(phi::distributed::ArgDistAttr* dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.clean_partial_status(); +} + +void clean_partial_dims(phi::distributed::ArgDistAttr* dist_attr, + std::vector dims) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.clean_partial_dims(dims); +} + +void set_partial_status(phi::distributed::ArgDistAttr* dist_attr, + std::vector dims) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.set_partial_status(dims); +} + TEST(MatmulSPMDRule, Ctor) { // build input data class std::vector x_shape = {64, 32}; @@ -66,14 +130,10 @@ TEST(MatmulSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[] @@ -84,15 +144,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, 0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; - // mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done x_dist_attr.set_dims_mapping({1, 0}); y_dist_attr.set_dims_mapping({-1, -1}); @@ -101,15 +157,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.first[0], {1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test3 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done @@ -120,15 +172,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, 1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({1})); + check_dim_mapping(infered_dist_attrs.first[0], {-1, 1}); + check_dim_mapping(infered_dist_attrs.first[1], {1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, 0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {1}); VLOG(4) << "test4 done." << std::endl << std::endl << std::endl; // abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = @@ -141,13 +189,10 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({0, 1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({0, 1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {0, 1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {0, 1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test5 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1, @@ -159,15 +204,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test6 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = @@ -179,13 +220,12 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, 0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + VLOG(4) << "test7 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -197,17 +237,13 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, -1, 1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, -1, 1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); + clean_partial_dims(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test8 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, 0, 1]+trans_x=true, kn[1, 0]+trans_y=true --> abcmk[-1, -1, @@ -219,20 +255,16 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, 0, 1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector( - {-1, 0})); // confilct and should be changed to [-1, 0] - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - VLOG(4) << infered_dist_attrs.second[0].to_string(); - infered_dist_attrs.second[0].clean_partial_status(); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status( - std::vector({1}))); + + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, 0, 1}); + check_dim_mapping(infered_dist_attrs.first[1], + {-1, 0}); // confilct and should be changed to [-1, 0] + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + check_partial_dims(infered_dist_attrs.second[0], {0}); + + clean_partial_status(&infered_dist_attrs.second[0]); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + EXPECT_ANY_THROW(set_partial_status(&infered_dist_attrs.second[0], {1})); VLOG(4) << "test9 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -256,29 +288,21 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); // try to clean partial on a dim which is not partial - EXPECT_ANY_THROW(infered_dist_attrs.second[0].clean_partial_dims( - std::vector({1}))); - + EXPECT_ANY_THROW(clean_partial_dims(&infered_dist_attrs.second[0], {1})); // try to clean partial on a dims which is sharded - EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status( - std::vector({1}))); + EXPECT_ANY_THROW(set_partial_status(&infered_dist_attrs.second[0], {1})); // clean partial and then re-set again - infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - infered_dist_attrs.second[0].set_partial_status(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - + clean_partial_dims(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + set_partial_status(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test11 done." << std::endl << std::endl << std::endl; } @@ -328,26 +352,18 @@ TEST(LayerNormSPMDRule, Ctor) { bias_dist_attr); phi::distributed::InferSpmdContext ctx({x, scale, bias}, {epsilon, begin_norm_axis}); - std::pair, std::vector> - infered_dist_attrs = layer_norm_rule.InferForward(ctx); + auto infered_dist_attrs = layer_norm_rule.InferForward(ctx); size_t input_size = 3; size_t output_size = 3; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({1})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({1})); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {1}); + check_dim_mapping(infered_dist_attrs.second[2], {1}); VLOG(4) << "test1 done."; // ijk[1, 0, -1],k[0],k[0] --> ijk[1, -1, -1],z[1],z[1], @@ -364,18 +380,13 @@ TEST(LayerNormSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext({x, scale, bias}, {epsilon, begin_norm_axis}); infered_dist_attrs = layer_norm_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({1})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({1})); + + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {1}); + check_dim_mapping(infered_dist_attrs.second[2], {1}); VLOG(4) << "test2 done."; // ijk[0, -1, -1],y[-1],y[1] --> ijk[0, 1, -1], i[0], i[0], y=jk, @@ -392,18 +403,13 @@ TEST(LayerNormSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext({x, scale, bias}, {epsilon, begin_norm_axis}); infered_dist_attrs = layer_norm_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({0})); + + check_dim_mapping(infered_dist_attrs.first[0], {0, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {0, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {0}); + check_dim_mapping(infered_dist_attrs.second[2], {0}); VLOG(4) << "test3 done."; } @@ -449,24 +455,19 @@ TEST(MatmulSPMDRuleInferBackward, Ctor) { // -1] phi::distributed::InferSpmdContext ctx( {x, y, out}, {/*trans_x=*/false, /*trans_x=*/false}); - std::pair, std::vector> - infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); + auto infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); size_t input_size = 2; size_t output_size = 1; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, 1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; } @@ -524,18 +525,14 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_st.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[1].dims_mapping(), - std::vector({-1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[1].is_partial(), false); + check_dim_mapping(infered_dist_attrs_st.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[1], {-1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[1]), false); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; @@ -554,15 +551,10 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[2].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[2], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; @@ -582,14 +574,10 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[2].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[2], {-1, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test3 done." << std::endl << std::endl << std::endl; @@ -649,19 +637,15 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); + check_dim_mapping(infered_dist_attrs_st.first[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs_st.second[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[1], {0, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[1]), false); - EXPECT_EQ(infered_dist_attrs_st.first[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[1].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[1].is_partial(), false); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; @@ -682,14 +666,11 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[2].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[2], {-1, -1, -1}); + EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; @@ -735,19 +716,101 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({0, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {0, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test4 done." << std::endl << std::endl << std::endl; } +TEST(ConcatRule, Ctor) { + std::vector mesh_shape = {2, 2}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x", "y"}; + ProcessMesh process_mesh(mesh_shape, process_ids, dim_names); + + std::vector> shapes = { + {16, 16, 16}, {4, 16, 16}, {2, 16, 16}}; + std::vector> dim_mappings = { + {-1, 0, 1}, {-1, 1, 0}, {-1, -1, 0}}; + std::vector> partial_status = {{}, {}, {1}}; + + auto build_inputs = [&] { + std::vector inputs; + for (int i = 0; i < 3; i++) { + auto t_dist_attr = TensorDistAttr(); + t_dist_attr.set_process_mesh(process_mesh); + t_dist_attr.set_dims_mapping(dim_mappings[i]); + t_dist_attr.set_dynamic_dims({false, false, false}); + auto input = phi::distributed::DistMetaTensor(phi::make_ddim(shapes[i]), + t_dist_attr); + inputs.push_back(input); + } + return inputs; + }; + + // test 1, inputs are aligned according to cost, and partial status is cleared + auto inputs = build_inputs(); + auto infered_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 0); + // list of tensor => sigle tensor + EXPECT_EQ(infered_dist_attrs.first.size(), static_cast(1)); + EXPECT_EQ(infered_dist_attrs.second.size(), static_cast(1)); + EXPECT_TRUE( + paddle::holds_alternative>( + infered_dist_attrs.first[0])); + EXPECT_TRUE(paddle::holds_alternative( + infered_dist_attrs.second[0])); + auto& inputs_infer1 = paddle::get<1>(infered_dist_attrs.first[0]); + for (auto e : inputs_infer1) { + check_dim_mapping(e, {-1, 1, 0}); + check_partial_dims(e, {}); + } + check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0}); + check_partial_dims(infered_dist_attrs.second[0], {}); + + // test 2,force replicate along concat axis + inputs = build_inputs(); + infered_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 1); + // list of tensor => sigle tensor + EXPECT_EQ(infered_dist_attrs.first.size(), static_cast(1)); + EXPECT_EQ(infered_dist_attrs.second.size(), static_cast(1)); + EXPECT_TRUE( + paddle::holds_alternative>( + infered_dist_attrs.first[0])); + EXPECT_TRUE(paddle::holds_alternative( + infered_dist_attrs.second[0])); + auto& inputs_infer2 = paddle::get<1>(infered_dist_attrs.first[0]); + for (auto e : inputs_infer2) { + check_dim_mapping(e, {1, -1, 0}); + check_partial_dims(e, {}); + } + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0}); + check_partial_dims(infered_dist_attrs.second[0], {}); +} +TEST(Util, Ctor) { + // test equal test not equal + using phi::distributed::PartialStatus; + using phi::distributed::PlacementEqual; + using phi::distributed::ReplicatedStatus; + using phi::distributed::ShardStatus; + auto a = std::make_shared(phi::ReduceType::kRedSum); + auto b = std::make_shared(phi::ReduceType::kRedMin); + EXPECT_TRUE(PlacementEqual(a, a)); + EXPECT_TRUE(!PlacementEqual(a, b)); + auto c = std::make_shared(0); + auto d = std::make_shared(1); + EXPECT_TRUE(!PlacementEqual(a, c)); + EXPECT_TRUE(!PlacementEqual(b, c)); + EXPECT_TRUE(PlacementEqual(c, c)); + EXPECT_TRUE(!PlacementEqual(c, d)); + auto e = std::make_shared(); + EXPECT_TRUE(PlacementEqual(e, e)); + EXPECT_TRUE(!PlacementEqual(a, e)); + EXPECT_TRUE(!PlacementEqual(b, e)); + EXPECT_TRUE(!PlacementEqual(c, e)); + EXPECT_TRUE(!PlacementEqual(d, e)); +} } // namespace auto_parallel } // namespace distributed diff --git a/test/cpp/eager/data_structure_tests/CMakeLists.txt b/test/cpp/eager/data_structure_tests/CMakeLists.txt index c57ba405881dd..20676d5ae4aaa 100755 --- a/test/cpp/eager/data_structure_tests/CMakeLists.txt +++ b/test/cpp/eager/data_structure_tests/CMakeLists.txt @@ -1,58 +1,30 @@ if(WITH_CINN) set(eager_deps ${eager_deps} cinn_compiler python) endif() -cc_test_old( +cc_test( test_egr_ds_eager_tensor - SRCS - eager_tensor_test.cc - DEPS - fleet_executor - final_dygraph_function - ${eager_deps}) -cc_test_old( + SRCS eager_tensor_test.cc + DEPS fleet_executor final_dygraph_function ${eager_deps}) +cc_test( test_egr_ds_auotgrad_meta - SRCS - autograd_meta_test.cc - DEPS - fleet_executor - final_dygraph_function - ${eager_deps}) + SRCS autograd_meta_test.cc + DEPS fleet_executor final_dygraph_function ${eager_deps}) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) - cc_test_old( + cc_test( test_egr_ds_grad_tensor_holder - SRCS - grad_tensor_holder_test.cc - DEPS - fleet_executor - conditional_block_op - ${eager_deps} - ${generated_deps}) - cc_test_old( + SRCS grad_tensor_holder_test.cc + DEPS fleet_executor conditional_block_op ${eager_deps} ${generated_deps}) + cc_test( test_egr_ds_grad_node_info - SRCS - grad_node_info_test.cc - DEPS - fleet_executor - conditional_block_op - ${eager_deps} - ${generated_deps}) - cc_test_old( + SRCS grad_node_info_test.cc + DEPS fleet_executor conditional_block_op ${eager_deps} ${generated_deps}) + cc_test( test_egr_ds_accumulation_node - SRCS - accumulation_node_test.cc - DEPS - fleet_executor - conditional_block_op - ${eager_deps} - ${generated_deps}) - cc_test_old( + SRCS accumulation_node_test.cc + DEPS fleet_executor conditional_block_op ${eager_deps} ${generated_deps}) + cc_test( test_egr_ds_tensor_wrapper - SRCS - tensor_wrapper_test.cc - DEPS - fleet_executor - conditional_block_op - ${eager_deps} - ${generated_deps}) + SRCS tensor_wrapper_test.cc + DEPS fleet_executor conditional_block_op ${eager_deps} ${generated_deps}) endif() diff --git a/test/cpp/fluid/CMakeLists.txt b/test/cpp/fluid/CMakeLists.txt index ca62b7c1c7c03..324043b0746fe 100644 --- a/test/cpp/fluid/CMakeLists.txt +++ b/test/cpp/fluid/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(benchmark) +add_subdirectory(framework) if(WITH_CINN) add_subdirectory(cinn) endif() diff --git a/test/cpp/fluid/cinn/CMakeLists.txt b/test/cpp/fluid/cinn/CMakeLists.txt index 0feb905a83902..96c38feb32ba7 100644 --- a/test/cpp/fluid/cinn/CMakeLists.txt +++ b/test/cpp/fluid/cinn/CMakeLists.txt @@ -46,7 +46,13 @@ if(WITH_TESTING) elementwise_add_op paddle_flags) target_link_libraries(cinn_instruction_run_op_test ${PYTHON_LIBRARIES}) - set_tests_properties( - cinn_instruction_run_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT - "${CINN_RUN_ENVIRONMENT}") + + get_property( + env + TEST cinn_instruction_run_op_test + PROPERTY ENVIRONMENT) + set_property(TEST cinn_instruction_run_op_test + PROPERTY ENVIRONMENT "${CINN_RUN_ENVIRONMENT}" ${env}) + set_tests_properties(cinn_instruction_run_op_test PROPERTIES LABELS + "RUN_TYPE=CINN") endif() diff --git a/test/cpp/fluid/framework/CMakeLists.txt b/test/cpp/fluid/framework/CMakeLists.txt new file mode 100644 index 0000000000000..663dae547625b --- /dev/null +++ b/test/cpp/fluid/framework/CMakeLists.txt @@ -0,0 +1,289 @@ +# add_subdirectory(details) + +cc_test( + data_type_test + SRCS data_type_test.cc + DEPS data_type place tensor) + +cc_test( + tensor_test + SRCS tensor_test.cc + DEPS tensor isfinite_op) +if(WITH_GPU) + nv_test( + tensor_util_test + SRCS tensor_util_test.cc tensor_util_test.cu + DEPS tensor dlpack_tensor isfinite_op) +elseif(WITH_ROCM) + hip_test( + tensor_util_test + SRCS tensor_util_test.cc tensor_util_test.cu + DEPS tensor dlpack_tensor isfinite_op) +else() + cc_test( + tensor_util_test + SRCS tensor_util_test.cc + DEPS tensor dlpack_tensor isfinite_op) +endif() + +cc_test( + copy_same_tensor_test + SRCS copy_same_tensor_test.cc + DEPS tensor) + +cc_test( + eigen_test + SRCS eigen_test.cc + DEPS tensor) + +cc_test( + lod_tensor_test + SRCS lod_tensor_test.cc + DEPS phi lod_tensor memory) + +if(WITH_GPU) + nv_test( + lod_tensor_gpu_test + SRCS lod_tensor_test.cu + DEPS lod_tensor) +elseif(WITH_ROCM) + hip_test( + lod_tensor_gpu_test + SRCS lod_tensor_test.cu + DEPS lod_tensor) +endif() + +cc_test( + reader_test + SRCS reader_test.cc + DEPS reader) + +cc_test( + threadpool_test + SRCS threadpool_test.cc + DEPS phi) + +cc_test( + var_type_traits_test + SRCS var_type_traits_test.cc + DEPS var_type_traits) + +cc_test( + device_worker_test + SRCS device_worker_test.cc + DEPS device_worker) + +cc_test( + scope_test + SRCS scope_test.cc + DEPS scope) + +cc_test( + variable_test + SRCS variable_test.cc + DEPS tensor var_type_traits) + +if(WITH_GPU) + nv_test( + data_device_transform_test + SRCS data_device_transform_test.cu + DEPS operator op_registry device_context phi scope) +elseif(WITH_ROCM) + hip_test( + data_device_transform_test + SRCS data_device_transform_test.cu + DEPS operator op_registry device_context phi scope) +endif() + +if(WITH_GPU) + nv_test( + data_type_transform_test + SRCS data_type_transform_test.cc data_type_transform_test.cu + DEPS data_type_transform) +elseif(WITH_ROCM) + hip_test( + data_type_transform_test + SRCS data_type_transform_test.cc data_type_transform_test.cu + DEPS data_type_transform) +elseif(WITH_XPU) + cc_test( + data_type_transform_test + SRCS data_type_transform_test.cc + DEPS data_type_transform) +else() + cc_test( + data_type_transform_test + SRCS data_type_transform_test.cc + DEPS data_type_transform) +endif() + +cc_test( + data_layout_transform_test + SRCS data_layout_transform_test.cc + DEPS data_layout_transform) + +cc_test( + attribute_test + SRCS attribute_test.cc + DEPS attribute framework_proto proto_desc) + +cc_test( + program_desc_test + SRCS program_desc_test.cc + DEPS proto_desc device_context) + +cc_test( + op_desc_test + SRCS op_desc_test.cc + DEPS proto_desc) + +cc_test( + op_version_registry_test + SRCS op_version_registry_test.cc + DEPS op_version_registry) + +cc_test( + op_proto_maker_test + SRCS op_proto_maker_test.cc + DEPS op_proto_maker) + +cc_test( + no_need_buffer_vars_inference_test + SRCS no_need_buffer_vars_inference_test.cc + DEPS no_need_buffer_vars_inference layer) + +cc_test( + operator_test + SRCS operator_test.cc + DEPS operator op_registry device_context) +cc_test( + operator_exception_test + SRCS operator_exception_test.cc + DEPS operator op_registry device_context) + +cc_test( + version_test + SRCS version_test.cc + DEPS version) + +cc_test( + op_call_stack_test + SRCS op_call_stack_test.cc + DEPS op_call_stack) + +cc_test( + program_utils_test + SRCS program_utils_test.cc + DEPS proto_desc program_utils) + +if(WITH_GPU) + nv_test( + op_registry_test + SRCS op_registry_test.cc + DEPS op_registry) +elseif(WITH_ROCM) + hip_test( + op_registry_test + SRCS op_registry_test.cc + DEPS op_registry) +endif() + +if(WITH_PSCORE) + get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + if(WITH_HETERPS) + cc_test( + dist_multi_trainer_test + SRCS dist_multi_trainer_test.cc + DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS} + graph_gpu_wrapper) + cc_test( + heter_pipeline_trainer_test + SRCS heter_pipeline_trainer_test.cc + DEPS conditional_block_op + generated_op + heter_listen_and_serv_op + executor + heter_server + gloo_wrapper + phi + ${RPC_DEPS} + graph_gpu_wrapper) + else() + cc_test( + dist_multi_trainer_test + SRCS dist_multi_trainer_test.cc + DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS}) + cc_test( + heter_pipeline_trainer_test + SRCS heter_pipeline_trainer_test.cc + DEPS conditional_block_op + generated_op + heter_listen_and_serv_op + executor + heter_server + gloo_wrapper + phi + ${RPC_DEPS}) + endif() +else() + cc_test( + dist_multi_trainer_test + SRCS dist_multi_trainer_test.cc + DEPS conditional_block_op executor gloo_wrapper) +endif() + +cc_test( + prune_test + SRCS prune_test.cc + DEPS op_info prune recurrent_op device_context) +cc_test( + var_type_inference_test + SRCS var_type_inference_test.cc + DEPS op_registry proto_desc) + +cc_test( + selected_rows_utils_test + SRCS selected_rows_utils_test.cc + DEPS selected_rows_utils) + +cc_test( + op_kernel_type_test + SRCS op_kernel_type_test.cc + DEPS place device_context framework_proto op_kernel_type) +cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) + +cc_test(tuple_test SRCS tuple_test.cc) + +cc_test(inlined_vector_test SRCS inlined_vector_test.cc) + +cc_test( + dlpack_tensor_test + SRCS dlpack_tensor_test.cc + DEPS dlpack_tensor glog) + +cc_test_old( + op_compatible_info_test + SRCS + op_compatible_info_test.cc + DEPS + op_compatible_info + proto_desc + string_helper + glog) + +cc_test( + infershape_utils_test + SRCS infershape_utils_test.cc + DEPS infershape_utils phi) + +if(WITH_TESTING AND TEST selected_rows_utils_test) + set_tests_properties(selected_rows_utils_test PROPERTIES TIMEOUT 120) +endif() + +cc_test(scope_guard_test SRCS scope_guard_test.cc) +cc_test( + phi_utils_test + SRCS phi_utils_test.cc + DEPS phi_utils) + +cc_test(convert_utils_test SRCS convert_utils_test.cc) diff --git a/paddle/fluid/framework/attribute_test.cc b/test/cpp/fluid/framework/attribute_test.cc similarity index 100% rename from paddle/fluid/framework/attribute_test.cc rename to test/cpp/fluid/framework/attribute_test.cc diff --git a/paddle/fluid/framework/convert_utils_test.cc b/test/cpp/fluid/framework/convert_utils_test.cc similarity index 100% rename from paddle/fluid/framework/convert_utils_test.cc rename to test/cpp/fluid/framework/convert_utils_test.cc diff --git a/paddle/fluid/framework/copy_same_tensor_test.cc b/test/cpp/fluid/framework/copy_same_tensor_test.cc similarity index 100% rename from paddle/fluid/framework/copy_same_tensor_test.cc rename to test/cpp/fluid/framework/copy_same_tensor_test.cc diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/test/cpp/fluid/framework/data_device_transform_test.cu similarity index 100% rename from paddle/fluid/framework/data_device_transform_test.cu rename to test/cpp/fluid/framework/data_device_transform_test.cu diff --git a/paddle/fluid/framework/data_feed_test.cc b/test/cpp/fluid/framework/data_feed_test.cc similarity index 100% rename from paddle/fluid/framework/data_feed_test.cc rename to test/cpp/fluid/framework/data_feed_test.cc diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/test/cpp/fluid/framework/data_layout_transform_test.cc similarity index 100% rename from paddle/fluid/framework/data_layout_transform_test.cc rename to test/cpp/fluid/framework/data_layout_transform_test.cc diff --git a/paddle/fluid/framework/data_type_test.cc b/test/cpp/fluid/framework/data_type_test.cc similarity index 100% rename from paddle/fluid/framework/data_type_test.cc rename to test/cpp/fluid/framework/data_type_test.cc diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/test/cpp/fluid/framework/data_type_transform_test.cc similarity index 100% rename from paddle/fluid/framework/data_type_transform_test.cc rename to test/cpp/fluid/framework/data_type_transform_test.cc diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/test/cpp/fluid/framework/data_type_transform_test.cu similarity index 100% rename from paddle/fluid/framework/data_type_transform_test.cu rename to test/cpp/fluid/framework/data_type_transform_test.cu diff --git a/paddle/fluid/framework/details/cow_ptr_test.cc b/test/cpp/fluid/framework/details/cow_ptr_test.cc similarity index 100% rename from paddle/fluid/framework/details/cow_ptr_test.cc rename to test/cpp/fluid/framework/details/cow_ptr_test.cc diff --git a/paddle/fluid/framework/device_worker_test.cc b/test/cpp/fluid/framework/device_worker_test.cc similarity index 100% rename from paddle/fluid/framework/device_worker_test.cc rename to test/cpp/fluid/framework/device_worker_test.cc diff --git a/paddle/fluid/framework/dist_multi_trainer_test.cc b/test/cpp/fluid/framework/dist_multi_trainer_test.cc similarity index 100% rename from paddle/fluid/framework/dist_multi_trainer_test.cc rename to test/cpp/fluid/framework/dist_multi_trainer_test.cc diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/test/cpp/fluid/framework/dlpack_tensor_test.cc similarity index 100% rename from paddle/fluid/framework/dlpack_tensor_test.cc rename to test/cpp/fluid/framework/dlpack_tensor_test.cc diff --git a/paddle/fluid/framework/eigen_test.cc b/test/cpp/fluid/framework/eigen_test.cc similarity index 100% rename from paddle/fluid/framework/eigen_test.cc rename to test/cpp/fluid/framework/eigen_test.cc diff --git a/paddle/fluid/framework/heter_pipeline_trainer_test.cc b/test/cpp/fluid/framework/heter_pipeline_trainer_test.cc similarity index 100% rename from paddle/fluid/framework/heter_pipeline_trainer_test.cc rename to test/cpp/fluid/framework/heter_pipeline_trainer_test.cc diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/test/cpp/fluid/framework/infershape_utils_test.cc similarity index 100% rename from paddle/fluid/framework/infershape_utils_test.cc rename to test/cpp/fluid/framework/infershape_utils_test.cc diff --git a/paddle/fluid/framework/inlined_vector_test.cc b/test/cpp/fluid/framework/inlined_vector_test.cc similarity index 100% rename from paddle/fluid/framework/inlined_vector_test.cc rename to test/cpp/fluid/framework/inlined_vector_test.cc diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/test/cpp/fluid/framework/lod_tensor_test.cc similarity index 100% rename from paddle/fluid/framework/lod_tensor_test.cc rename to test/cpp/fluid/framework/lod_tensor_test.cc diff --git a/paddle/fluid/framework/lod_tensor_test.cu b/test/cpp/fluid/framework/lod_tensor_test.cu similarity index 100% rename from paddle/fluid/framework/lod_tensor_test.cu rename to test/cpp/fluid/framework/lod_tensor_test.cu diff --git a/paddle/fluid/framework/naive_executor_test.cc b/test/cpp/fluid/framework/naive_executor_test.cc similarity index 100% rename from paddle/fluid/framework/naive_executor_test.cc rename to test/cpp/fluid/framework/naive_executor_test.cc diff --git a/paddle/fluid/framework/no_need_buffer_vars_inference_test.cc b/test/cpp/fluid/framework/no_need_buffer_vars_inference_test.cc similarity index 100% rename from paddle/fluid/framework/no_need_buffer_vars_inference_test.cc rename to test/cpp/fluid/framework/no_need_buffer_vars_inference_test.cc diff --git a/paddle/fluid/framework/op_call_stack_test.cc b/test/cpp/fluid/framework/op_call_stack_test.cc similarity index 100% rename from paddle/fluid/framework/op_call_stack_test.cc rename to test/cpp/fluid/framework/op_call_stack_test.cc diff --git a/paddle/fluid/framework/op_compatible_info_test.cc b/test/cpp/fluid/framework/op_compatible_info_test.cc similarity index 100% rename from paddle/fluid/framework/op_compatible_info_test.cc rename to test/cpp/fluid/framework/op_compatible_info_test.cc diff --git a/paddle/fluid/framework/op_desc_test.cc b/test/cpp/fluid/framework/op_desc_test.cc similarity index 100% rename from paddle/fluid/framework/op_desc_test.cc rename to test/cpp/fluid/framework/op_desc_test.cc diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/test/cpp/fluid/framework/op_kernel_type_test.cc similarity index 100% rename from paddle/fluid/framework/op_kernel_type_test.cc rename to test/cpp/fluid/framework/op_kernel_type_test.cc diff --git a/paddle/fluid/framework/op_proto_maker_test.cc b/test/cpp/fluid/framework/op_proto_maker_test.cc similarity index 100% rename from paddle/fluid/framework/op_proto_maker_test.cc rename to test/cpp/fluid/framework/op_proto_maker_test.cc diff --git a/paddle/fluid/framework/op_registry_test.cc b/test/cpp/fluid/framework/op_registry_test.cc similarity index 100% rename from paddle/fluid/framework/op_registry_test.cc rename to test/cpp/fluid/framework/op_registry_test.cc diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/test/cpp/fluid/framework/op_version_registry_test.cc similarity index 100% rename from paddle/fluid/framework/op_version_registry_test.cc rename to test/cpp/fluid/framework/op_version_registry_test.cc diff --git a/paddle/fluid/framework/operator_exception_test.cc b/test/cpp/fluid/framework/operator_exception_test.cc similarity index 100% rename from paddle/fluid/framework/operator_exception_test.cc rename to test/cpp/fluid/framework/operator_exception_test.cc diff --git a/paddle/fluid/framework/operator_test.cc b/test/cpp/fluid/framework/operator_test.cc similarity index 100% rename from paddle/fluid/framework/operator_test.cc rename to test/cpp/fluid/framework/operator_test.cc diff --git a/paddle/fluid/framework/phi_utils_test.cc b/test/cpp/fluid/framework/phi_utils_test.cc similarity index 100% rename from paddle/fluid/framework/phi_utils_test.cc rename to test/cpp/fluid/framework/phi_utils_test.cc diff --git a/paddle/fluid/framework/program_desc_test.cc b/test/cpp/fluid/framework/program_desc_test.cc similarity index 100% rename from paddle/fluid/framework/program_desc_test.cc rename to test/cpp/fluid/framework/program_desc_test.cc diff --git a/paddle/fluid/framework/program_utils_test.cc b/test/cpp/fluid/framework/program_utils_test.cc similarity index 100% rename from paddle/fluid/framework/program_utils_test.cc rename to test/cpp/fluid/framework/program_utils_test.cc diff --git a/paddle/fluid/framework/prune_test.cc b/test/cpp/fluid/framework/prune_test.cc similarity index 100% rename from paddle/fluid/framework/prune_test.cc rename to test/cpp/fluid/framework/prune_test.cc diff --git a/paddle/fluid/framework/reader_test.cc b/test/cpp/fluid/framework/reader_test.cc similarity index 100% rename from paddle/fluid/framework/reader_test.cc rename to test/cpp/fluid/framework/reader_test.cc diff --git a/paddle/fluid/framework/scope_guard_test.cc b/test/cpp/fluid/framework/scope_guard_test.cc similarity index 100% rename from paddle/fluid/framework/scope_guard_test.cc rename to test/cpp/fluid/framework/scope_guard_test.cc diff --git a/paddle/fluid/framework/scope_test.cc b/test/cpp/fluid/framework/scope_test.cc similarity index 100% rename from paddle/fluid/framework/scope_test.cc rename to test/cpp/fluid/framework/scope_test.cc diff --git a/paddle/fluid/framework/selected_rows_utils_test.cc b/test/cpp/fluid/framework/selected_rows_utils_test.cc similarity index 100% rename from paddle/fluid/framework/selected_rows_utils_test.cc rename to test/cpp/fluid/framework/selected_rows_utils_test.cc diff --git a/paddle/fluid/framework/tensor_test.cc b/test/cpp/fluid/framework/tensor_test.cc similarity index 100% rename from paddle/fluid/framework/tensor_test.cc rename to test/cpp/fluid/framework/tensor_test.cc diff --git a/paddle/fluid/framework/tensor_util_test.cc b/test/cpp/fluid/framework/tensor_util_test.cc similarity index 100% rename from paddle/fluid/framework/tensor_util_test.cc rename to test/cpp/fluid/framework/tensor_util_test.cc diff --git a/paddle/fluid/framework/tensor_util_test.cu b/test/cpp/fluid/framework/tensor_util_test.cu similarity index 100% rename from paddle/fluid/framework/tensor_util_test.cu rename to test/cpp/fluid/framework/tensor_util_test.cu diff --git a/paddle/fluid/framework/threadpool_test.cc b/test/cpp/fluid/framework/threadpool_test.cc similarity index 100% rename from paddle/fluid/framework/threadpool_test.cc rename to test/cpp/fluid/framework/threadpool_test.cc diff --git a/paddle/fluid/framework/trainer_test.cc b/test/cpp/fluid/framework/trainer_test.cc similarity index 100% rename from paddle/fluid/framework/trainer_test.cc rename to test/cpp/fluid/framework/trainer_test.cc diff --git a/paddle/fluid/framework/tuple_test.cc b/test/cpp/fluid/framework/tuple_test.cc similarity index 100% rename from paddle/fluid/framework/tuple_test.cc rename to test/cpp/fluid/framework/tuple_test.cc diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/test/cpp/fluid/framework/var_type_inference_test.cc similarity index 100% rename from paddle/fluid/framework/var_type_inference_test.cc rename to test/cpp/fluid/framework/var_type_inference_test.cc diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/test/cpp/fluid/framework/var_type_traits_test.cc similarity index 100% rename from paddle/fluid/framework/var_type_traits_test.cc rename to test/cpp/fluid/framework/var_type_traits_test.cc diff --git a/paddle/fluid/framework/variable_test.cc b/test/cpp/fluid/framework/variable_test.cc similarity index 100% rename from paddle/fluid/framework/variable_test.cc rename to test/cpp/fluid/framework/variable_test.cc diff --git a/paddle/fluid/framework/version_test.cc b/test/cpp/fluid/framework/version_test.cc similarity index 100% rename from paddle/fluid/framework/version_test.cc rename to test/cpp/fluid/framework/version_test.cc diff --git a/test/cpp/fluid/mkldnn/CMakeLists.txt b/test/cpp/fluid/mkldnn/CMakeLists.txt index 3d5883dabfbf8..f83fd91963be2 100644 --- a/test/cpp/fluid/mkldnn/CMakeLists.txt +++ b/test/cpp/fluid/mkldnn/CMakeLists.txt @@ -83,3 +83,18 @@ else() cc_test_old(test_mkldnn_op_nhwc SRCS test_mkldnn_op_nhwc.cc DEPS ${paddle_lib} python) endif() + +cc_test( + test_mkldnn_pool_adaptive_op + SRCS test_mkldnn_pool_adaptive_op.cc + DEPS fleet_executor + conditional_block_op + standalone_executor + executor + op_registry + generated_static_op + generated_op + phi + scope + device_context + enforce) diff --git a/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc b/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc index ecc5ce726b2d8..4dfc4a731bff2 100644 --- a/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc +++ b/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc @@ -19,7 +19,6 @@ #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" @@ -109,3 +108,92 @@ TEST(test_conv2d_output, int8) { op->Run(scope, cpu_place); } +TEST(test_conv2d_output, ic1) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + paddle::framework::OpDesc conv2d_op(nullptr); + conv2d_op.SetType("conv2d"); + conv2d_op.SetInput("Input", {"conv2d-X"}); + conv2d_op.SetInput("Filter", {"conv2d-Y"}); + conv2d_op.SetOutput("Output", {"conv2d-Out"}); + + AddVarToScope("conv2d-X", &scope, {1, 1, 224, 224}); + AddVarToScope("conv2d-Y", &scope, {64, 1, 7, 7}); + AddVarToScope("conv2d-Out", &scope, {1, 64, 218, 218}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + conv2d_op.SetAttr("strides", strides); + conv2d_op.SetAttr("paddings", paddings); + conv2d_op.SetAttr("dilations", dilations); + conv2d_op.SetAttr("groups", groups); + conv2d_op.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(conv2d_op); + + op->Run(scope, cpu_place); +} + +TEST(test_conv2d_output, ic2) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + paddle::framework::OpDesc conv2d_op(nullptr); + conv2d_op.SetType("conv2d"); + conv2d_op.SetInput("Input", {"conv2d-X"}); + conv2d_op.SetInput("Filter", {"conv2d-Y"}); + conv2d_op.SetOutput("Output", {"conv2d-Out"}); + + AddVarToScope("conv2d-X", &scope, {1, 2, 224, 224}); + AddVarToScope("conv2d-Y", &scope, {64, 2, 7, 7}); + AddVarToScope("conv2d-Out", &scope, {1, 64, 218, 218}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + conv2d_op.SetAttr("strides", strides); + conv2d_op.SetAttr("paddings", paddings); + conv2d_op.SetAttr("dilations", dilations); + conv2d_op.SetAttr("groups", groups); + conv2d_op.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(conv2d_op); + + op->Run(scope, cpu_place); +} + +TEST(test_conv2d_output, ic4) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + paddle::framework::OpDesc conv2d_op(nullptr); + conv2d_op.SetType("conv2d"); + conv2d_op.SetInput("Input", {"conv2d-X"}); + conv2d_op.SetInput("Filter", {"conv2d-Y"}); + conv2d_op.SetOutput("Output", {"conv2d-Out"}); + + AddVarToScope("conv2d-X", &scope, {1, 4, 224, 224}); + AddVarToScope("conv2d-Y", &scope, {64, 4, 7, 7}); + AddVarToScope("conv2d-Out", &scope, {1, 64, 218, 218}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + conv2d_op.SetAttr("strides", strides); + conv2d_op.SetAttr("paddings", paddings); + conv2d_op.SetAttr("dilations", dilations); + conv2d_op.SetAttr("groups", groups); + conv2d_op.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(conv2d_op); + + op->Run(scope, cpu_place); +} diff --git a/test/cpp/fluid/mkldnn/test_mkldnn_pool_adaptive_op.cc b/test/cpp/fluid/mkldnn/test_mkldnn_pool_adaptive_op.cc new file mode 100644 index 0000000000000..3e1a9230ec231 --- /dev/null +++ b/test/cpp/fluid/mkldnn/test_mkldnn_pool_adaptive_op.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include + +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +template +void AddVarToScope(const std::string var_name, + paddle::framework::Scope* scope, + const paddle::framework::DDim& dims) { + std::random_device seed; + std::default_random_engine engine(seed()); + std::uniform_real_distribution dist(0, 100); + + phi::DenseTensor tmp_tensor; + auto* tmp_data = + tmp_tensor.mutable_data(dims, paddle::platform::CPUPlace()); + auto* tensor = scope->Var(var_name)->GetMutable(); + tensor->mutable_data(dims, paddle::platform::CPUPlace()); + for (auto i = 0; i < tensor->numel(); ++i) { + tmp_data[i] = static_cast(dist(engine)); + } + paddle::framework::TensorCopySync( + tmp_tensor, paddle::platform::CPUPlace(), tensor); +} +void test_pool2d(bool adaptive, bool ceil_mode, std::string pool_type = "max") { + framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("pool2d"); + desc.SetInput("X", {"pool2d-X"}); + desc.SetOutput("Out", {"pool2d-Out"}); + AddVarToScope("pool2d-X", &scope, {1, 3, 9, 12}); + AddVarToScope("pool2d-Out", &scope, {1, 3, 2, 2}); + std::vector ksize({2, 2}); + std::vector strides({1, 1}); + std::vector paddings({0, 0}); + std::string pooling_t = pool_type; + + desc.SetAttr("pooling_type", pooling_t); + desc.SetAttr("ksize", ksize); + desc.SetAttr("strides", strides); + desc.SetAttr("paddings", paddings); + desc.SetAttr("adaptive", adaptive); + desc.SetAttr("ceil_mode", ceil_mode); + desc.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(desc); + + op->Run(scope, cpu_place); +} + +TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); } +TEST(Pool2dOpConverter, adaptive) { test_pool2d(true, false); } + +TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); } +TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(true, true, "avg"); } + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP_ITSELF(pool2d); +PD_DECLARE_KERNEL(pool2d, OneDNN, ONEDNN); +PD_DECLARE_KERNEL(pool2d, CPU, ALL_LAYOUT); diff --git a/test/cpp/new_executor/CMakeLists.txt b/test/cpp/new_executor/CMakeLists.txt index 435124d87049a..c5906fc0f263e 100644 --- a/test/cpp/new_executor/CMakeLists.txt +++ b/test/cpp/new_executor/CMakeLists.txt @@ -1,8 +1,7 @@ # skip win32 since wget is not installed by default on windows machine. if(NOT WIN32) - paddle_test(standalone_executor_new_ir_test SRCS - standalone_executor_new_ir_test.cc) + paddle_test(standalone_executor_pir_test SRCS standalone_executor_pir_test.cc) endif() set(OPS diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_pir_test.cc similarity index 89% rename from test/cpp/new_executor/standalone_executor_new_ir_test.cc rename to test/cpp/new_executor/standalone_executor_pir_test.cc index 28a425dbd4ebe..e83b763428855 100644 --- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc +++ b/test/cpp/new_executor/standalone_executor_pir_test.cc @@ -22,7 +22,7 @@ #include "paddle/phi/core/kernel_registry.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -35,7 +35,7 @@ #include "paddle/fluid/platform/init_phi.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" DECLARE_FILE_SYMBOLS(kernel_dialect); @@ -65,20 +65,19 @@ TEST(StandaloneExecutor, run) { paddle::dialect::FullOp op2 = builder.Build( std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); - builder.Build(op1->result(0), op2->result(0)); + auto add_op = + builder.Build(op1->result(0), op2->result(0)); + + std::string out_name = "add_out"; + builder.Build(add_op->result(0), out_name); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_2"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); @@ -136,8 +135,10 @@ TEST(StandaloneExecutor, run_feed_tensor) { pir::Operation::Create({}, attr_map2, {dense_tensor_dtype}, feed_op_info); program.block()->push_back(feed_op2); - builder.Build(feed_op1->result(0), - feed_op2->result(0)); + auto add_op = builder.Build(feed_op1->result(0), + feed_op2->result(0)); + std::string out_name = "add_out"; + builder.Build(add_op->result(0), out_name); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -145,10 +146,6 @@ TEST(StandaloneExecutor, run_feed_tensor) { Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_2"; test_core.SetSkipGcVars({out_name}); phi::DenseTensorMeta meta( @@ -191,16 +188,15 @@ TEST(StandaloneExecutor, run_inplace_sqrt) { builder.Build(full->result(0)); + std::string out_name = "full_out"; + builder.Build(full->result(0), out_name); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_0"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); @@ -254,16 +250,16 @@ TEST(StandaloneExecutor, if_op) { std::vector{3}, true, phi::DataType::BOOL); builder.Build(std::vector{full_op_2.out()}); + std::string out_name = "if_out"; + builder.SetInsertionPointToEnd(block); + builder.Build(if_op->result(0), out_name); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_1"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); @@ -325,16 +321,15 @@ TEST(StandaloneExecutor, while_op) { builder.SetInsertionPointAfter(while_op); + std::string out_name = "while_out"; + builder.Build(while_op->result(0), out_name); + auto kernel_program = PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_3"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); diff --git a/test/cpp/new_executor/standalone_executor_test.cc b/test/cpp/new_executor/standalone_executor_test.cc index 51abf47617985..e25f8e0aec99d 100644 --- a/test/cpp/new_executor/standalone_executor_test.cc +++ b/test/cpp/new_executor/standalone_executor_test.cc @@ -147,25 +147,28 @@ ProgramDesc GetLmMainProgram() { TEST(StandaloneExecutor, run) { auto place = platform::CUDAPlace(0); - ProgramDesc startup_prog = load_from_file("lm_startup_program"); - ProgramDesc main_prog = GetLmMainProgram(); + std::shared_ptr p_startup_prog = + std::make_shared(load_from_file("lm_startup_program")); + std::shared_ptr p_main_prog = + std::make_shared(GetLmMainProgram()); Scope scope; std::shared_ptr startup_job = std::make_shared(Job("startup")); StandaloneExecutor startup_exec( place, Plan(std::vector>({startup_job}), - std::unordered_map( - {{startup_job->Type(), &startup_prog}})), + std::unordered_map>( + {{startup_job->Type(), p_startup_prog}})), &scope); startup_exec.Run({}); std::shared_ptr main_job = std::make_shared(Job("main")); - StandaloneExecutor exec(place, - Plan(std::vector>({main_job}), - std::unordered_map( - {{main_job->Type(), &main_prog}})), - &scope); + StandaloneExecutor exec( + place, + Plan(std::vector>({main_job}), + std::unordered_map>( + {{main_job->Type(), p_main_prog}})), + &scope); exec.Run({}); auto start = std::chrono::steady_clock::now(); diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index 1440805e7fc98..5e5b8bd81ca73 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -1,27 +1,38 @@ if(WITH_TESTING AND WITH_CINN) + paddle_test(test_pir_compiler SRCS pir_compiler_test.cc DEPS pir_compiler + cinn_runtime_dialect) + set_tests_properties(test_pir_compiler PROPERTIES LABELS "RUN_TYPE=CINN") + + paddle_test(test_jit_instruction SRCS jit_instruction_test.cc DEPS + cinn_runtime_dialect pir_compiler) + set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN") + cc_test_old( - test_new_ir_compiler + dialect_convert_test SRCS - new_ir_compiler_test.cc + dialect_convert_test.cc DEPS - new_ir_compiler - convert_to_dialect - cinn_runtime_dialect - pir - phi + drr gtest - glog) - set_tests_properties(test_new_ir_compiler PROPERTIES LABELS "RUN_TYPE=CINN") + pd_to_cinn_pass + op_dialect_vjp + cinn_op_dialect + pir) + set_tests_properties(dialect_convert_test PROPERTIES LABELS "RUN_TYPE=CINN") cc_test_old( - test_jit_instruction + add_broadcast_to_elementwise_test SRCS - jit_instruction_test.cc + add_broadcast_to_elementwise_test.cc DEPS - interpreter - new_ir_compiler - convert_to_dialect) - set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN") + drr + gtest + pd_to_cinn_pass + op_dialect_vjp + cinn_op_dialect + add_broadcast_to_elementwise_pass + pir) + set_tests_properties(dialect_convert_test PROPERTIES LABELS "RUN_TYPE=CINN") cc_test_old( ir_op_fusion_test @@ -29,17 +40,39 @@ if(WITH_TESTING AND WITH_CINN) ir_op_fusion_test.cc DEPS op_with_group_merge_pass - pd_op_dialect + op_dialect_vjp cinn_op_dialect pir gtest glog) + set_tests_properties(ir_op_fusion_test PROPERTIES LABELS "RUN_TYPE=CINN") - paddle_test(test_group_op SRCS group_op_test.cc DEPS cinn_op_dialect) + paddle_test( + test_pir_all_path + SRCS + pir_all_path_test.cc + DEPS + op_with_group_merge_pass + transform + cinn_op_dialect + pd_to_cinn_pass + add_broadcast_to_elementwise_pass) + set_tests_properties(test_pir_all_path PROPERTIES LABELS "RUN_TYPE=CINN") + + paddle_test( + test_group_op + SRCS + group_op_test.cc + DEPS + pd_to_cinn_pass + add_broadcast_to_elementwise_pass + op_with_group_merge_pass + cinn_op_dialect + transform) set_tests_properties(test_group_op PROPERTIES LABELS "RUN_TYPE=CINN") paddle_test(test_pir_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS - pd_build_cinn_pass) + transform pir) set_tests_properties(test_pir_build_cinn_pass PROPERTIES LABELS "RUN_TYPE=CINN") endif() diff --git a/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc b/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc new file mode 100644 index 0000000000000..801d5644930e1 --- /dev/null +++ b/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_x = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullOp full_input_y = builder.Build( + std::vector{16}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + auto add_op = builder.Build(full_input_x.result(0), + full_input_y.result(0)); + + auto relu_op = builder.Build(add_op.result(0)); +} + +void BuildProgramBoth(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_x = + builder.Build(std::vector{10, 1}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullOp full_input_y = + builder.Build(std::vector{1, 10}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + auto add_op = builder.Build(full_input_x.result(0), + full_input_y.result(0)); + + auto relu_op = builder.Build(add_op.result(0)); +} + +void BuildProgramSubBoth(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_x = + builder.Build(std::vector{10, 1}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullOp full_input_y = + builder.Build(std::vector{1, 10}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + auto sub_op = builder.Build( + full_input_x.result(0), full_input_y.result(0)); + + auto relu_op = builder.Build(sub_op.result(0)); +} + +TEST(PatternRewrite, broadcast_elementwise) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + + pm.Run(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} + +TEST(PatternRewrite, broadcast_elementwise_both) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgramBoth(builder); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + + pm.Run(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} + +TEST(PatternRewrite, broadcast_elementwise_sub_both) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgramSubBoth(builder); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + + pm.Run(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} diff --git a/test/cpp/pir/cinn/build_cinn_pass_test.cc b/test/cpp/pir/cinn/build_cinn_pass_test.cc index 40fefeb3d2173..e80e88242e0b1 100644 --- a/test/cpp/pir/cinn/build_cinn_pass_test.cc +++ b/test/cpp/pir/cinn/build_cinn_pass_test.cc @@ -23,7 +23,7 @@ limitations under the License. */ #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" #include "paddle/fluid/pir/transforms/build_cinn_pass.h" #include "paddle/pir/pass/pass.h" @@ -78,3 +78,172 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { CHECK_EQ(iter->name(), op_names[index++]); } } + +std::shared_ptr<::pir::Program> BuildNoOpSupportCinnGraph() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + // ones -> hardswish -> square -> unsqueeze + const std::vector shape = {64, 128}; + const std::vector axis = {0}; + auto ones_op_x = builder.Build( + shape, phi::DataType::FLOAT32, phi::GPUPlace()); + auto hardswish_op_y = + builder.Build(ones_op_x->result(0)); + auto square_op_y = + builder.Build(hardswish_op_y->result(0)); + auto unsqueeze_op_x = + builder.Build(square_op_y->result(0), axis); + + return program; +} + +TEST(BuildCinnPassTest, NoOpSupportCinn) { + auto origin_program = BuildNoOpSupportCinnGraph(); + pir::IrContext* ctx = pir::IrContext::Instance(); + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateBuildCinnPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + CHECK_EQ(pm.Run(origin_program.get()), true); + LOG(INFO) << "after pass: " << *origin_program; + + CHECK_EQ(origin_program->block()->size(), 5u); // Because of `FullIntArrayOp` + + std::vector op_names = { + paddle::dialect::OnesOp::name(), + paddle::dialect::HardswishOp::name(), + paddle::dialect::SquareOp::name(), + paddle::dialect::FullIntArrayOp::name(), + paddle::dialect::UnsqueezeOp::name(), + }; + int index = 0; + for (auto iter : *origin_program->block()) { + CHECK_EQ(iter->name(), op_names[index++]); + } +} + +std::shared_ptr<::pir::Program> BuildOneCinnSubgraph() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + // full -> acosh -> relu -> square -> unsqueeze + const std::vector axis = {0}; + + const float value_one = 1.0; + const std::vector shape = {64, 128}; + auto full_op_x = builder.Build( + shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto acosh_op_x = + builder.Build(full_op_x->result(0)); + auto relu_op_y = + builder.Build(acosh_op_x->result(0)); + auto square_op_y = + builder.Build(relu_op_y->result(0)); + auto unsqueeze_op_x = + builder.Build(square_op_y->result(0), axis); + return program; +} + +TEST(BuildCinnPassTest, OneCinnSubgraph) { + auto origin_program = BuildOneCinnSubgraph(); + pir::IrContext* ctx = pir::IrContext::Instance(); + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateBuildCinnPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + CHECK_EQ(pm.Run(origin_program.get()), true); + LOG(INFO) << "after pass: " << *origin_program; + + CHECK_EQ(origin_program->block()->size(), 4u); + pir::Operation* group_op = origin_program->block()->front(); + pir::Block* group_block = + group_op->dyn_cast().block(); + CHECK_EQ(group_block->size(), 4u); + + std::vector op_names = { + paddle::dialect::FullOp::name(), + paddle::dialect::AcoshOp::name(), + paddle::dialect::ReluOp::name(), + pir::YieldOp::name(), + }; + int index = 0; + for (auto iter : *group_block) { + CHECK_EQ(iter->name(), op_names[index++]); + } +} + +std::shared_ptr<::pir::Program> BuildMultiCinnSubgraph() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + // full -> acosh -> hardswish -> square -> unsqueeze -> relu + const std::vector axis = {0}; + + const float value_one = 1.0; + const std::vector shape = {64, 128}; + auto full_op_x = builder.Build( + shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto acosh_op_x = + builder.Build(full_op_x->result(0)); + auto hardswish_op_y = + builder.Build(acosh_op_x->result(0)); + auto square_op_y = + builder.Build(hardswish_op_y->result(0)); + auto unsqueeze_op_x = + builder.Build(square_op_y->result(0), axis); + auto relu_op_y = + builder.Build(unsqueeze_op_x->result(0)); + return program; +} + +TEST(BuildCinnPassTest, MultiCinnSubgraph) { + auto origin_program = BuildMultiCinnSubgraph(); + pir::IrContext* ctx = pir::IrContext::Instance(); + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateBuildCinnPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + CHECK_EQ(pm.Run(origin_program.get()), true); + LOG(INFO) << "after pass: " << *origin_program; + + CHECK_EQ(origin_program->block()->size(), 6u); + pir::Operation* group_op = origin_program->block()->front(); + pir::Block* group_block = + group_op->dyn_cast().block(); + CHECK_EQ(group_block->size(), 3u); + + std::vector op_names_front = { + paddle::dialect::FullOp::name(), + paddle::dialect::AcoshOp::name(), + pir::YieldOp::name(), + }; + int index = 0; + for (auto iter : *group_block) { + CHECK_EQ(iter->name(), op_names_front[index++]); + } + + group_op = origin_program->block()->back(); + group_block = group_op->dyn_cast().block(); + CHECK_EQ(group_block->size(), 2u); + + std::vector op_names_back = { + paddle::dialect::ReluOp::name(), + pir::YieldOp::name(), + }; + index = 0; + for (auto iter : *group_block) { + CHECK_EQ(iter->name(), op_names_back[index++]); + } +} diff --git a/test/cpp/pir/cinn/dialect_convert_test.cc b/test/cpp/pir/cinn/dialect_convert_test.cc new file mode 100644 index 0000000000000..91d52f2adc8bd --- /dev/null +++ b/test/cpp/pir/cinn/dialect_convert_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + auto sum_op = + builder.Build(full_input_op.result(0), + std::vector({-1}), + phi::DataType::FLOAT32, + true); + auto relu_op = builder.Build(sum_op.result(0)); + auto exp_op = builder.Build(sum_op.result(0)); +} + +void BuildProgramMax(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + auto max_op = builder.Build( + full_input_op.result(0), std::vector({-1}), true); + auto relu_op = builder.Build(max_op.result(0)); + auto exp_op = builder.Build(max_op.result(0)); +} + +TEST(DrrTest, reduce_sum) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + cinn::dialect::ir::PdOp2CinnOpConverter(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} + +TEST(DrrTest, reduce_max) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgramMax(builder); + + cinn::dialect::ir::PdOp2CinnOpConverter(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} diff --git a/test/cpp/pir/cinn/group_op_test.cc b/test/cpp/pir/cinn/group_op_test.cc index 6e0f05a8cb244..24ebf47a7c84f 100644 --- a/test/cpp/pir/cinn/group_op_test.cc +++ b/test/cpp/pir/cinn/group_op_test.cc @@ -19,13 +19,18 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" +#include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" + +bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } std::vector<::pir::Type> CreateDenseTensorTypes(const phi::DDim& dims) { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); @@ -89,3 +94,148 @@ TEST(GroupOp, TestBuild) { ++i; } } + +std::shared_ptr<::pir::Program> BuildGroupProgramByBlock() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + // ------- Group op 1 --------- + const float value_one = 1.0; + const std::vector shape = {64, 128}; + std::unique_ptr<::pir::Block> block1(new ::pir::Block()); + builder.SetInsertionPointToEnd(block1.get()); + auto full_op_x = builder.Build( + shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace()); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{full_op_x.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op1 = builder.Build(std::move(block1)); + + // ------- Group op 2 --------- + std::unique_ptr<::pir::Block> block2(new ::pir::Block()); + builder.SetInsertionPointToEnd(block2.get()); + auto tan_op_x = builder.Build(group_op1->result(0)); + auto relu_op_x = builder.Build(tan_op_x->result(0)); + auto tan_op_y = builder.Build(relu_op_x->result(0)); + auto relu_op_y = builder.Build(tan_op_y->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op2 = builder.Build(std::move(block2)); + + return program; +} + +TEST(GroupOp, TestBuildByBlock) { + // Step 1: Construct pir::Program + std::shared_ptr<::pir::Program> program = BuildGroupProgramByBlock(); + std::stringstream ss; + program->Print(ss); + LOG(INFO) << ss.str(); + + EXPECT_EQ(program->block()->size(), 2u); + LOG(INFO) << program->block()->size(); + std::vector op_num = {2, 5}; + int i = 0; + for (auto* sub_op : *(program->block())) { + EXPECT_TRUE(sub_op->isa()); + EXPECT_EQ(sub_op->dyn_cast().ops().size(), + op_num[i]); + ++i; + } +} + +std::shared_ptr<::pir::Program> BuildGroupProgramForLowering() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); + + auto program = std::make_shared<::pir::Program>(ctx); + const std::vector shape = {2, 2}; + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + const float value = 0.5; + auto full_x = builder.Build( + shape, value, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto full_y = builder.Build( + shape, value, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto group_op1 = builder.Build( + CreateDenseTensorTypes(phi::make_ddim(shape))); + pir::Block* block1 = group_op1.block(); + builder.SetInsertionPointToEnd(block1); + auto sin = builder.Build(full_x->result(0)); + + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{ + sin.out(), + }); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op2 = builder.Build( + CreateDenseTensorTypes(phi::make_ddim(shape))); + pir::Block* block2 = group_op2.block(); + builder.SetInsertionPointToEnd(block2); + auto cos_op = builder.Build(full_y->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{cos_op.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op3 = builder.Build( + CreateDenseTensorTypes(phi::make_ddim(shape))); + pir::Block* block3 = group_op3.block(); + builder.SetInsertionPointToEnd(block3); + auto add = builder.Build(group_op1->result(0), + group_op2->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{add.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto exp = builder.Build(group_op3->result(0)); + + builder.Build(exp.out(), "out", 0); + return program; +} + +TEST(GroupOp, CINNLowering) { + // Step 1: Construct pir::Program + std::shared_ptr<::pir::Program> program = BuildGroupProgramForLowering(); + + auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(res.get(), place); + + paddle::framework::Scope exe_scope; + + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + std::set out_names; + out_names.insert("out@fetch"); + auto local_names = exe_scope.LocalVarNames(); + for (size_t i = 0; i < local_names.size(); ++i) { + out_names.insert(local_names[i]); + } + + executor.SetSkipGcVars(out_names); + executor.Run({}, true); + + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 3.88455); + bool res1 = simple_cmp(out_tensor.data()[1], 3.88455); + bool res2 = simple_cmp(out_tensor.data()[2], 3.88455); + bool res3 = simple_cmp(out_tensor.data()[3], 3.88455); + + EXPECT_EQ(res0, true); + EXPECT_EQ(res1, true); + EXPECT_EQ(res2, true); + EXPECT_EQ(res3, true); +} diff --git a/test/cpp/pir/cinn/ir_op_fusion_test.cc b/test/cpp/pir/cinn/ir_op_fusion_test.cc index a392373358b2a..0233dd1a2ae7f 100644 --- a/test/cpp/pir/cinn/ir_op_fusion_test.cc +++ b/test/cpp/pir/cinn/ir_op_fusion_test.cc @@ -52,9 +52,13 @@ TEST(IROpFusionPass, demo) { auto add = builder.Build(inputs[0], inputs[1]); builder.Build(add.result(0)); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); ASSERT_EQ(res.size(), 1u); + + ASSERT_EQ(res[0]->ops.size(), program.block()->size()); } TEST(IROpFusionPass, ElementWise_Fusion_0) { @@ -75,12 +79,15 @@ TEST(IROpFusionPass, ElementWise_Fusion_0) { auto f = builder.Build(e, inputs[2]).result(0); builder.Build(f, inputs[2]); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(res.size(), 1u); + + ASSERT_EQ(res[0]->ops.size(), program.block()->size()); } // Real test 0 @@ -107,12 +114,15 @@ TEST(IROpFusionPass, Broadcast_Test_0) { builder.Build(e, axes, out_shape).result(0); builder.Build(e1, f); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); + + ASSERT_EQ(res.size(), 1u); - // ASSERT_EQ(res.size(), 1u); + ASSERT_EQ(res[0]->ops.size(), program.block()->size()); } // Real test 1 @@ -138,45 +148,50 @@ TEST(IROpFusionPass, Broadcast_Test_1) { builder.Build(e, axes, out_shape).result(0); builder.Build(inputs[3], e1); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); + ASSERT_EQ(new_group[0]->ops.size(), 2u); + ASSERT_EQ(new_group[1]->ops.size(), 3u); } -// Real test 2 -TEST(IROpFusionPass, Broadcast_Test_2) { - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Program program_base(ctx); - ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); +// FIXME(Aurelius84): Real test 2 +// TEST(IROpFusionPass, Broadcast_Test_2) { +// ::pir::IrContext* ctx = ::pir::IrContext::Instance(); +// ctx->GetOrRegisterDialect(); +// ctx->GetOrRegisterDialect(); +// ::pir::Program program_base(ctx); +// ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); - int h = 32, w = 32; - auto inputs = BuildInput(&builder_base, {{w}, {w}, {w}, {h, w}, {h, w}}); +// int h = 32, w = 32; +// auto inputs = BuildInput(&builder_base, {{w}, {w}, {w}, {h, w}, {h, w}}); - ::pir::Program program(ctx); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); +// ::pir::Program program(ctx); +// ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - auto f = - builder.Build(inputs[0], inputs[1]).result(0); - builder.Build(inputs[2], f).result(0); - std::vector axes{1}; - std::vector out_shape{h, w}; - auto f1 = - builder.Build(f, axes, out_shape).result(0); - builder.Build(inputs[3], f1); - builder.Build(inputs[4], f1); +// auto f = +// builder.Build(inputs[0], inputs[1]).result(0); +// builder.Build(inputs[2], f).result(0); +// std::vector axes{1}; +// std::vector out_shape{h, w}; +// auto f1 = +// builder.Build(f, axes, +// out_shape).result(0); +// builder.Build(inputs[3], f1); +// builder.Build(inputs[4], f1); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); +// auto res = +// cinn::dialect::ir::OpFusionPassInternal(std::vector( +// program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); +// auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); - ASSERT_EQ(new_group.size(), 2u); -} +// ASSERT_EQ(new_group.size(), 2u); +// } // Real reduce 0 TEST(IROpFusionPass, reduce_test_0) { @@ -199,12 +214,15 @@ TEST(IROpFusionPass, reduce_test_0) { builder.Build(c, axes, true).result(0); builder.Build(c, axes, true).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); + + ASSERT_EQ(new_group[0]->ops.size(), program.block()->size()); } // Real reduce 1 @@ -228,12 +246,15 @@ TEST(IROpFusionPass, reduce_test_1) { builder.Build(c, axes, true).result(0); builder.Build(c, axes1, true).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); + ASSERT_EQ(new_group[0]->ops.size(), 2u); + ASSERT_EQ(new_group[1]->ops.size(), 2u); } // Real reduce 2 @@ -259,12 +280,15 @@ TEST(IROpFusionPass, reduce_test_2) { builder.Build(inputs[2], e).result(0); builder.Build(inputs[2], f).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); + ASSERT_EQ(new_group[0]->ops.size(), 3u); + ASSERT_EQ(new_group[1]->ops.size(), 3u); } // Real reduce 3 @@ -294,51 +318,57 @@ TEST(IROpFusionPass, reduce_test_3) { builder.Build(f, axes1, out_shape).result(0); builder.Build(inputs[2], f1).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); + ASSERT_EQ(new_group[0]->ops.size(), program.block()->size()); } -// Real reduce 4 -TEST(IROpFusionPass, reduce_test_4) { - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Program program_base(ctx); - ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); +// FIXME(Aurelius84): Real reduce 4 +// TEST(IROpFusionPass, reduce_test_4) { +// ::pir::IrContext* ctx = ::pir::IrContext::Instance(); +// ctx->GetOrRegisterDialect(); +// ctx->GetOrRegisterDialect(); +// ::pir::Program program_base(ctx); +// ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); - int h = 32, w = 32; - auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}, {w}, {h, w}}); +// int h = 32, w = 32; +// auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}, {w}, {h, w}}); - ::pir::Program program(ctx); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); +// ::pir::Program program(ctx); +// ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - std::vector axes{0}; - std::vector axes1{1}; - auto e = - builder.Build(inputs[0], inputs[1]).result(0); - auto f = builder.Build(e, axes, false).result(0); +// std::vector axes{0}; +// std::vector axes1{1}; +// auto e = +// builder.Build(inputs[0], inputs[1]).result(0); +// auto f = builder.Build(e, axes, +// false).result(0); - builder.Build(inputs[2], f).result(0); +// builder.Build(inputs[2], f).result(0); - std::vector out_shape{h, w}; - auto f1 = - builder.Build(f, axes1, out_shape).result(0); - builder.Build(inputs[3], f1).result(0); - auto f2 = - builder.Build(f, axes1, out_shape).result(0); - builder.Build(inputs[3], f2).result(0); +// std::vector out_shape{h, w}; +// auto f1 = +// builder.Build(f, axes1, +// out_shape).result(0); +// builder.Build(inputs[3], f1).result(0); +// auto f2 = +// builder.Build(f, axes1, +// out_shape).result(0); +// builder.Build(inputs[3], f2).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); +// auto res = +// cinn::dialect::ir::OpFusionPassInternal(std::vector( +// program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); +// auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); - ASSERT_EQ(new_group.size(), 1u); -} +// ASSERT_EQ(new_group.size(), 1u); +// } // Real reduce 5 TEST(IROpFusionPass, reduce_test_5) { @@ -362,12 +392,15 @@ TEST(IROpFusionPass, reduce_test_5) { builder.Build(inputs[1], axes, false).result(0); builder.Build(c, axes, false).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); + + ASSERT_EQ(new_group[0]->ops.size(), program.block()->size()); } TEST(IROpFusionPass, layer_norm) { @@ -435,10 +468,61 @@ TEST(IROpFusionPass, layer_norm) { auto t5 = builder.Build(t3, scale).result(0); builder.Build(t5, bias).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); + + ASSERT_EQ(new_group[0]->ops.size(), program.block()->size()); +} + +TEST(IROpFusionPass, softmax) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + auto inputs = BuildInput(&builder_base, {{128, 128, 768}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector axes{-1}; + + auto x = inputs[0]; + auto max = builder.Build(x, axes, true).result(0); + auto broadcast_1 = builder + .Build( + max, + std::vector({0, 1, 2}), + std::vector({128, 128, 768})) + .result(0); + auto sub = + builder.Build(x, broadcast_1).result(0); + auto exp = builder.Build(sub).result(0); + auto sum = + builder.Build(exp, axes, true).result(0); + + auto broadcast_2 = builder + .Build( + sum, + std::vector({0, 1, 2}), + std::vector({128, 128, 768})) + .result(0); + auto divide = + builder.Build(exp, broadcast_2).result(0); + + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); + + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); + + ASSERT_EQ(new_group.size(), 1u); + + ASSERT_EQ(new_group[0]->ops.size(), program.block()->size()); } diff --git a/test/cpp/pir/cinn/jit_instruction_test.cc b/test/cpp/pir/cinn/jit_instruction_test.cc index 2996bf17c962a..5291ce0582e2f 100644 --- a/test/cpp/pir/cinn/jit_instruction_test.cc +++ b/test/cpp/pir/cinn/jit_instruction_test.cc @@ -27,11 +27,18 @@ #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/convert_to_dialect.h" -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/cinn/utils/data_util.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/phi/backends/gpu/gpu_context.h" + +bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } std::unique_ptr<::pir::Program> BuildProgram() { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); @@ -39,18 +46,29 @@ std::unique_ptr<::pir::Program> BuildProgram() { auto program = std::make_unique<::pir::Program>(ctx); ::pir::Builder builder = ::pir::Builder(ctx, program->block()); - const float value = 2.0; + const float value = 0.5; auto full_op_x = - builder.Build(std::vector{64, 128}, + builder.Build(std::vector{2, 2}, value, phi::DataType::FLOAT32, phi::GPUPlace()); auto full_op_y = - builder.Build(std::vector{128, 64}, + builder.Build(std::vector{2, 2}, + value, + phi::DataType::FLOAT32, + phi::GPUPlace()); + auto full_op_z = + builder.Build(std::vector{2, 2}, value, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto sin = builder.Build(full_op_x.result(0)); + auto cos = builder.Build(full_op_y.result(0)); + auto add = + builder.Build(sin.result(0), cos.result(0)); + builder.Build(add.out(), "out", 0); return std::move(program); } @@ -60,43 +78,105 @@ namespace framework { TEST(CinnJitInstruction, Run) { // Step 1: Construct pir::Program std::unique_ptr<::pir::Program> program = BuildProgram(); - EXPECT_EQ(program->block()->size(), 2u); + EXPECT_EQ(program->block()->size(), 7u); // Step 2: Compiler New pir::Program into Runtime Program auto target = cinn::common::DefaultNVGPUTarget(); auto scope = cinn::hlir::framework::BuildScope(target, *program); - ASSERT_EQ(scope->var_names().size(), 2); - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); - auto runtime_program = ir_compiler.Build(); + std::vector compiler_list; - // Step 3: Convert into cinn::dialect::RuntimeDialect - std::unique_ptr<::pir::Program> ir_runtime_program = - cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program); + std::set checking_cinn_ops = {"pd_op.sin", "pd_op.cos"}; - std::set out_names; - for (auto& var_name : scope->var_names()) { - std::string name = {var_name.begin(), var_name.end()}; - out_names.insert(name); + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto ir_program = std::make_unique<::pir::Program>(ctx); + std::string jit_op_name = cinn::dialect::JitKernelOp::name(); + ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); + + std::unordered_map value_map; + for (auto it = program->block()->begin(); it != program->block()->end(); + ++it) { + if (checking_cinn_ops.count((*it)->name())) { + auto ir_compiler = + new cinn::hlir::framework::PirCompiler(*program, target, scope); + + std::vector<::pir::Operation*> ops = {*it}; + auto group = std::make_shared(ops); + auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group}); + compiler_list.push_back(ir_compiler); + std::unordered_map op_attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, + }; + + auto out_type = (*it)->result(0).type(); + + std::vector vec_ins; + + for (size_t i = 0; i < (*it)->num_operands(); ++i) { + vec_ins.push_back(value_map.at((*it)->operand_source(i))); + } + + ::pir::Operation* cinn_op = + ::pir::Operation::Create(vec_ins, op_attrs, {out_type}, op_info); + + value_map[(*it)->result(0)] = cinn_op->result(0); + + ir_program->block()->push_back(cinn_op); + } else { + std::vector vec_ins; + + for (size_t i = 0; i < (*it)->num_operands(); ++i) { + vec_ins.push_back(value_map.at((*it)->operand_source(i))); + } + + auto type1 = (*it)->result(0).type(); + ::pir::OpInfo info1 = ctx->GetRegisteredOpInfo((*it)->name()); + ::pir::Operation* op = ::pir::Operation::Create( + vec_ins, (*it)->attributes(), {type1}, info1); + + ir_program->block()->push_back(op); + + value_map[(*it)->result(0)] = op->result(0); + } } platform::Place place = platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(ir_program.get(), place); + Scope exe_scope; - InterpreterCore executor(place, {}, ir_runtime_program->block(), &exe_scope); - executor.SetSkipGcVars(out_names); - executor.Run({}); - - // TODO(Aurelius84): Need to replace check with framework::Scope. - const float value = 2.0; - for (auto& name : out_names) { - std::vector data = - cinn::GetTensorData(scope->GetTensor(name), target); - for (int i = 0; i < data.size(); ++i) { - LOG_FIRST_N(INFO, 3) << "data: " << data[i]; - ASSERT_NEAR(data[i], value, 1e-5); - } + paddle::framework::interpreter::ExecutionConfig exe_conf; + exe_conf.create_local_scope = false; + InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + std::set out_names; + out_names.insert("out@fetch"); + auto local_names = exe_scope.LocalVarNames(); + for (size_t i = 0; i < local_names.size(); ++i) { + out_names.insert(local_names[i]); } + + executor.SetSkipGcVars(out_names); + executor.Run({}, true); + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 1.35701); + bool res1 = simple_cmp(out_tensor.data()[1], 1.35701); + bool res2 = simple_cmp(out_tensor.data()[2], 1.35701); + bool res3 = simple_cmp(out_tensor.data()[3], 1.35701); + + EXPECT_EQ(res0, true); + EXPECT_EQ(res1, true); + EXPECT_EQ(res2, true); + EXPECT_EQ(res3, true); } } // namespace framework diff --git a/test/cpp/pir/cinn/new_ir_compiler_test.cc b/test/cpp/pir/cinn/new_ir_compiler_test.cc deleted file mode 100644 index 4b680b1ac8904..0000000000000 --- a/test/cpp/pir/cinn/new_ir_compiler_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include - -#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" -#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/convert_to_dialect.h" -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" -#include "paddle/cinn/utils/data_util.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/pir/core/ir_context.h" -#include "paddle/pir/core/program.h" - -using cinn::hlir::framework::newir::Group; -using cinn::hlir::framework::newir::GroupPtr; - -using ProgramInfo = - std::tuple, std::vector>; -ProgramInfo BuildProgram() { - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - auto program = std::make_shared<::pir::Program>(ctx); - ::pir::Builder builder = ::pir::Builder(ctx, program->block()); - - const float value_one = 1.0; // relu(tan(1.)) = 1.5; - const float value_two = 2.0; // relu(tan(2.)) = 0. - auto full_op_x = - builder.Build(std::vector{64, 128}, - value_one, - phi::DataType::FLOAT32, - phi::GPUPlace()); - - auto full_op_y = - builder.Build(std::vector{64, 128}, - value_two, - phi::DataType::FLOAT32, - phi::GPUPlace()); - - auto tan_op_x = builder.Build(full_op_x->result(0)); - auto relu_op_x = builder.Build(tan_op_x->result(0)); - auto tan_op_y = builder.Build(relu_op_x->result(0)); - auto relu_op_y = builder.Build(tan_op_y->result(0)); - - std::vector groups; - groups.emplace_back( - std::make_shared(std::initializer_list<::pir::Operation*>( - {full_op_x.operation()}))); // For coverage - groups.emplace_back(std::make_shared( - std::initializer_list<::pir::Operation*>({full_op_y.operation()}))); - groups.emplace_back(std::make_shared( - std::vector<::pir::Operation*>({tan_op_x.operation(), - relu_op_x.operation(), - tan_op_y.operation(), - relu_op_y.operation()}))); - - return {program, groups}; -} - -TEST(NewIRCompier, CompilerAndRun) { - // Step 1: Construct pir::Program - auto prog_info = BuildProgram(); - std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); - EXPECT_EQ(program->block()->size(), 6u); - LOG(INFO) << program->block()->size(); - - std::stringstream ss; - program->Print(ss); - LOG(INFO) << ss.str(); - - // Step 2: Compiler New pir::Program into Runtime Program - auto target = cinn::common::DefaultNVGPUTarget(); - auto scope = cinn::hlir::framework::BuildScope(target, *program); - ASSERT_EQ(scope->var_names().size(), 6); - - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); - auto runtime_program = ir_compiler.Build(); - - // Step 3: Execute Runtime Instruction and check Scope. - ASSERT_NO_THROW(runtime_program->Execute()); - for (auto& var_name : scope->var_names()) { - std::string name = {var_name.begin(), var_name.end()}; - std::vector data = - cinn::GetTensorData(scope->GetTensor(name), target); - for (int i = 0; i < 1; ++i) { - LOG_FIRST_N(INFO, 10) << "data: " << data[i]; - } - } -} - -TEST(NewIRCompier, CompileGroupOps) { - // Step 1: Construct pir::Program - auto prog_info = BuildProgram(); - std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); - std::vector groups = std::get<1>(prog_info); - EXPECT_EQ(program->block()->size(), 6u); - LOG(INFO) << program->block()->size(); - - std::stringstream ss; - program->Print(ss); - LOG(INFO) << ss.str(); - - // Step 2: Compiler New pir::Program into Runtime Program - auto target = cinn::common::DefaultNVGPUTarget(); - auto scope = cinn::hlir::framework::BuildScope(target, *program); - ASSERT_EQ(scope->var_names().size(), 6); - - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); - auto runtime_program = ir_compiler.Build(groups); - - // Step 3: Execute Runtime Instruction and check Scope. - ASSERT_NO_THROW(runtime_program->Execute()); - for (auto& var_name : scope->var_names()) { - std::string name = {var_name.begin(), var_name.end()}; - std::vector data = - cinn::GetTensorData(scope->GetTensor(name), target); - for (int i = 0; i < 1; ++i) { - LOG_FIRST_N(INFO, 10) << "data: " << data[i]; - } - } -} - -TEST(RuntimeDialect, CompilerAndRun) { - // Step 1: Construct pir::Program - auto prog_info = BuildProgram(); - std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); - EXPECT_EQ(program->block()->size(), 6u); - - // Step 2: Compiler New pir::Program into Runtime Program - auto target = cinn::common::DefaultNVGPUTarget(); - auto scope = cinn::hlir::framework::BuildScope(target, *program); - ASSERT_EQ(scope->var_names().size(), 6u); - - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); - auto runtime_program = ir_compiler.Build(); - - // Step 3: Convert into cinn::dialect::RuntimeDialect - std::shared_ptr<::pir::Program> ir_runtime_program = - cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program); - - // Step 4: Run cinn::dialect::RuntimeDialect - for (auto iter = ir_runtime_program->block()->begin(); - iter != ir_runtime_program->block()->end(); - ++iter) { - auto op = (*iter)->dyn_cast(); - auto* instr = op.instruction(); - instr->Run(/*name2podargs=*/nullptr, - false, - /*stream=*/nullptr, - /*use_cache=*/true); - } -#ifdef CINN_WITH_CUDA - CUDA_CALL(cudaDeviceSynchronize()); -#endif - - // Step 5: Check Scope Tensor Value. - for (auto& var_name : scope->var_names()) { - std::string name = {var_name.begin(), var_name.end()}; - std::vector data = - cinn::GetTensorData(scope->GetTensor(name), target); - for (int i = 0; i < 1; ++i) { - LOG_FIRST_N(INFO, 10) << "data: " << data[i]; - } - } -} diff --git a/test/cpp/pir/cinn/pir_all_path_test.cc b/test/cpp/pir/cinn/pir_all_path_test.cc new file mode 100644 index 0000000000000..6d72dc06c12a4 --- /dev/null +++ b/test/cpp/pir/cinn/pir_all_path_test.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" +#include "paddle/fluid/framework/new_executor/interpretercore.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/build_cinn_pass.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" + +bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } + +std::vector<::pir::Type> CreateDenseTensorTypes(const phi::DDim& dims) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ::pir::Type fp32_dtype = ::pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {}; + size_t offset = 0; + std::vector<::pir::Type> op_output_types = {::pir::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + return op_output_types; +} + +std::shared_ptr<::pir::Program> BuildGroupProgram() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + // full -> softmax(max -> subtract -> exp -> sum -> divide) + const float value_one = 1.0; + const std::vector shape = {128, 128, 768}; + auto x = builder + .Build( + shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace()) + .result(0); + + auto max = + builder.Build(x, std::vector{-1}, true) + .result(0); + auto sub = builder.Build(x, max).result(0); + auto exp = builder.Build(sub).result(0); + auto sum = + builder + .Build( + exp, std::vector{-1}, phi::DataType::FLOAT32, true) + .result(0); + auto out = builder.Build(exp, sum).result(0); + + builder.Build(out, "out", 0); + return program; +} + +TEST(GroupOp, TestBuild) { + // Step 1: Construct pir::Program + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + std::shared_ptr<::pir::Program> program = BuildGroupProgram(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + pm.AddPass(pir::CreateBuildCinnPass()); + CHECK_EQ(pm.Run(program.get()), true); + + auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(res.get(), place); + + paddle::framework::Scope exe_scope; + + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + executor.Run({}, true); + + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 1.0 / 768); + EXPECT_EQ(res0, true); +} diff --git a/test/cpp/pir/cinn/pir_compiler_test.cc b/test/cpp/pir/cinn/pir_compiler_test.cc new file mode 100644 index 0000000000000..c3f8b409260c0 --- /dev/null +++ b/test/cpp/pir/cinn/pir_compiler_test.cc @@ -0,0 +1,285 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/cinn/utils/data_util.h" +#include "paddle/fluid/framework/new_executor/interpretercore.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" + +using cinn::hlir::framework::pir::Group; +using cinn::hlir::framework::pir::GroupPtr; + +bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } +using ProgramInfo = + std::tuple, std::vector>; +ProgramInfo BuildProgram() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + const float value_one = 1.0; // relu(tan(1.)) = 1.5; + const float value_two = 2.0; // relu(tan(2.)) = 0. + auto full_op_x = + builder.Build(std::vector{64, 128}, + value_one, + phi::DataType::FLOAT32, + phi::GPUPlace()); + + auto full_op_y = + builder.Build(std::vector{64, 128}, + value_two, + phi::DataType::FLOAT32, + phi::GPUPlace()); + + auto tan_op_x = builder.Build(full_op_x->result(0)); + auto relu_op_x = builder.Build(tan_op_x->result(0)); + auto tan_op_y = builder.Build(relu_op_x->result(0)); + auto relu_op_y = builder.Build(tan_op_y->result(0)); + + std::vector groups; + groups.emplace_back( + std::make_shared(std::initializer_list<::pir::Operation*>( + {full_op_x.operation()}))); // For coverage + groups.emplace_back(std::make_shared( + std::initializer_list<::pir::Operation*>({full_op_y.operation()}))); + groups.emplace_back(std::make_shared( + std::vector<::pir::Operation*>({tan_op_x.operation(), + relu_op_x.operation(), + tan_op_y.operation(), + relu_op_y.operation()}))); + + return {program, groups}; +} + +ProgramInfo BuildSoftmax() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + std::vector axes{-1}; + + auto x = builder + .Build(std::vector({16, 16}), + 1.0, + phi::DataType::FLOAT32, + phi::GPUPlace(0)) + .result(0); + auto max = builder.Build(x, axes, true).result(0); + auto broadcast_1 = + builder + .Build( + max, std::vector({0, 1}), std::vector({16, 16})) + .result(0); + auto sub = + builder.Build(x, broadcast_1).result(0); + auto exp = builder.Build(sub).result(0); + auto sum = + builder.Build(exp, axes, true).result(0); + + auto broadcast_2 = + builder + .Build( + sum, std::vector({0, 1}), std::vector({16, 16})) + .result(0); + auto divide = + builder.Build(exp, broadcast_2).result(0); + + std::vector groups; + groups.emplace_back(std::make_shared( + std::initializer_list<::pir::Operation*>({max.owner(), + broadcast_1.owner(), + sub.owner(), + exp.owner(), + sum.owner(), + broadcast_2.owner(), + divide.owner()}))); + + return {program, groups}; +} + +TEST(PirCompier, CompileSoftmax) { + // Step 1: Construct pir::Program + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto new_program = std::make_shared<::pir::Program>(ctx); + + auto prog_info = BuildSoftmax(); + std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); + std::vector groups = std::get<1>(prog_info); + EXPECT_EQ(program->block()->size(), 8u); + LOG(INFO) << program->block()->size(); + + std::stringstream ss; + program->Print(ss); + LOG(INFO) << ss.str(); + + // Step 2: Compiler New pir::Program into Runtime Program + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + LOG(INFO) << scope->var_names().size(); + ASSERT_EQ(scope->var_names().size(), 8); + + cinn::hlir::framework::PirCompiler ir_compiler(*program, target, scope); + auto fn_ptr_res = ir_compiler.BuildCUDAJITInfo(groups); + + ::pir::Builder builder = ::pir::Builder(ctx, new_program->block()); + auto x = builder + .Build(std::vector({16, 16}), + 1.0, + phi::DataType::FLOAT32, + phi::GPUPlace(0)) + .result(0); + + std::unordered_map op_attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, + }; + + std::vector vec_types; + + vec_types.push_back(groups[0]->ops.back()->result(0).type()); + + std::string jit_op_name = cinn::dialect::JitKernelOp::name(); + ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); + ::pir::Operation* cinn_op = + ::pir::Operation::Create({x}, op_attrs, vec_types, op_info); + + new_program->block()->push_back(cinn_op); + + builder.SetInsertionPointToEnd(new_program->block()); + builder.Build( + cinn_op->result(cinn_op->num_results() - 1), "out", 0); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(new_program.get(), place); + + paddle::framework::Scope exe_scope; + + paddle::framework::interpreter::ExecutionConfig exe_conf; + exe_conf.create_local_scope = false; + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + executor.Run({}, true); + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 1.0 / 16); + EXPECT_EQ(res0, true); +} + +TEST(PirCompier, CompilerAndRun) { + // Step 1: Construct pir::Program + auto prog_info = BuildProgram(); + std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); + EXPECT_EQ(program->block()->size(), 6u); + LOG(INFO) << program->block()->size(); + + std::stringstream ss; + program->Print(ss); + LOG(INFO) << ss.str(); + + // Step 2: Compiler New pir::Program into Runtime Program + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + ASSERT_EQ(scope->var_names().size(), 6); + + cinn::hlir::framework::PirCompiler ir_compiler(*program, target, scope); + auto runtime_program = ir_compiler.Build(); + + // Step 3: Execute Runtime Instruction and check Scope. + ASSERT_NO_THROW(runtime_program->Execute()); + for (auto& var_name : scope->var_names()) { + std::string name = {var_name.begin(), var_name.end()}; + std::vector data = + cinn::GetTensorData(scope->GetTensor(name), target); + for (int i = 0; i < 1; ++i) { + LOG_FIRST_N(INFO, 10) << "data: " << data[i]; + } + } +} + +TEST(PirCompier, CompileGroupOps) { + // Step 1: Construct pir::Program + auto prog_info = BuildProgram(); + std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); + std::vector groups = std::get<1>(prog_info); + EXPECT_EQ(program->block()->size(), 6u); + LOG(INFO) << program->block()->size(); + + std::stringstream ss; + program->Print(ss); + LOG(INFO) << ss.str(); + + // Step 2: Compiler New pir::Program into Runtime Program + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + ASSERT_EQ(scope->var_names().size(), 6); + + cinn::hlir::framework::PirCompiler ir_compiler(*program, target, scope); + auto runtime_program = ir_compiler.Build(groups); + + // Step 3: Execute Runtime Instruction and check Scope. + ASSERT_NO_THROW(runtime_program->Execute()); + for (auto& var_name : scope->var_names()) { + std::string name = {var_name.begin(), var_name.end()}; + std::vector data = + cinn::GetTensorData(scope->GetTensor(name), target); + for (int i = 0; i < 1; ++i) { + LOG_FIRST_N(INFO, 10) << "data: " << data[i]; + } + } +} + +TEST(RuntimeDialect, CompilerAndRun) { + // Step 1: Construct pir::Program + auto prog_info = BuildProgram(); + std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); + EXPECT_EQ(program->block()->size(), 6u); + + // Step 2: Compiler New pir::Program into Runtime Program + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + ASSERT_EQ(scope->var_names().size(), 6u); + + cinn::hlir::framework::PirCompiler ir_compiler(*program, target, scope); + auto runtime_program = ir_compiler.Build(); +} diff --git a/test/cpp/pir/control_flow_dialect/CMakeLists.txt b/test/cpp/pir/control_flow_dialect/CMakeLists.txt index 64af30a54d0ee..d295fa37b6b73 100644 --- a/test/cpp/pir/control_flow_dialect/CMakeLists.txt +++ b/test/cpp/pir/control_flow_dialect/CMakeLists.txt @@ -4,7 +4,7 @@ cc_test_old( if_op_test.cc DEPS pir - pd_op_dialect + op_dialect_vjp gtest) cc_test_old( @@ -13,5 +13,5 @@ cc_test_old( while_op_test.cc DEPS pir - pd_op_dialect + op_dialect_vjp gtest) diff --git a/test/cpp/pir/control_flow_dialect/if_op_test.cc b/test/cpp/pir/control_flow_dialect/if_op_test.cc index 02d4061a0d5f8..475d446027f4f 100644 --- a/test/cpp/pir/control_flow_dialect/if_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/if_op_test.cc @@ -21,7 +21,9 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" + +using namespace paddle::dialect; // NOLINT TEST(if_op_test, base) { pir::IrContext* ctx = pir::IrContext::Instance(); @@ -99,3 +101,82 @@ TEST(if_op_test, build_by_block) { LOG(INFO) << ss.str(); } + +TEST(if_op_test, network_with_backward) { + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + pir::Block* block = program.block(); + pir::Builder builder(ctx, block); + auto x = builder.Build(std::vector{2, 2}, 1.0f).out(); + auto y = builder.Build(std::vector{2, 2}, 2.0f).out(); + auto cond = builder.Build(x, y).out(); + auto [stack_0, inlet_0, outlet_0] = builder.Build().out(); + auto [stack_1, inlet_1, outlet_1] = builder.Build().out(); + (void)(stack_0); + (void)(stack_1); + + auto if_op = builder.Build(cond, std::vector{x.type()}); + + builder.SetInsertionPointToStart(if_op.true_block()); + auto local1_z = builder.Build(x, y).out(); + auto local1_w = builder.Build(local1_z, y).out(); + builder.Build(inlet_0, + std::initializer_list{local1_z}); + builder.Build(std::vector{local1_w}); + + builder.SetInsertionPointToStart(if_op.false_block()); + auto local2_z = builder.Build(x, y).out(); + auto local2_w = builder.Build(local2_z, y).out(); + builder.Build(inlet_1, + std::initializer_list{local2_z}); + builder.Build(std::vector{local2_w}); + + builder.SetInsertionPointToEnd(block); + + // build backward network + auto out_grad = builder.Build(std::vector{2, 2}, 1.0f).out(); + // the output of if_grad op is {x_grad, y_grad} + auto if_grad = + builder.Build(cond, std::vector{x.type(), y.type()}); + + // construct the true block of if_grad + builder.SetInsertionPointToStart(if_grad.true_block()); + auto pop_local1_z = builder.Build(outlet_0).outlet_element(0); + auto local1_add_grad_op = builder.Build(pop_local1_z, y, out_grad); + auto pop_local1_z_grad = local1_add_grad_op.x_grad(), + local1_y_grad_0 = local1_add_grad_op.y_grad(); + auto local1_add_grad_op_1 = builder.Build(x, y, pop_local1_z_grad); + auto local1_x_grad = local1_add_grad_op_1.x_grad(), + local1_y_grad_1 = local1_add_grad_op_1.y_grad(); + auto local1_y_grad = + builder.Build(local1_y_grad_0, local1_y_grad_1).out(); + builder.Build( + std::vector{local1_x_grad, local1_y_grad}); + + // construct the false block of if_grad + builder.SetInsertionPointToStart(if_grad.false_block()); + auto pop_local2_z = builder.Build(outlet_1).outlet_element(0); + auto local2_matmul_grad_op = + builder.Build(pop_local2_z, y, out_grad); + auto pop_local2_z_grad = local2_matmul_grad_op.x_grad(), + local2_y_grad_0 = local2_matmul_grad_op.y_grad(); + auto local2_matmul_grad_op_1 = + builder.Build(x, y, pop_local2_z_grad); + auto local2_x_grad = local2_matmul_grad_op_1.x_grad(), + local2_y_grad_1 = local2_matmul_grad_op_1.y_grad(); + + auto local2_y_grad = + builder.Build(local2_y_grad_0, local2_y_grad_1).out(); + builder.Build( + std::vector{local2_x_grad, local2_y_grad}); + + builder.SetInsertionPointToEnd(block); + + std::stringstream ss; + program.Print(ss); + + LOG(INFO) << ss.str(); +} diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc index 7536ea2014fe0..27893b1917c0a 100644 --- a/test/cpp/pir/control_flow_dialect/while_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -21,7 +21,7 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" using namespace paddle::dialect; // NOLINT diff --git a/test/cpp/pir/core/CMakeLists.txt b/test/cpp/pir/core/CMakeLists.txt index ca71cb8fe9eef..231f6a64cef6a 100644 --- a/test/cpp/pir/core/CMakeLists.txt +++ b/test/cpp/pir/core/CMakeLists.txt @@ -5,7 +5,7 @@ cc_test_old( DEPS pir gtest - pd_op_dialect) + op_dialect_vjp) cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS pir gtest) cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS pir gtest) paddle_test( @@ -16,7 +16,7 @@ paddle_test( pir gtest test_dialect - pd_op_dialect) + op_dialect_vjp) cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS pir gtest) cc_test_old(ir_builder_test SRCS ir_builder_test.cc DEPS pir gtest) cc_test_old( @@ -24,7 +24,7 @@ cc_test_old( SRCS ir_program_test.cc DEPS - pd_op_dialect + op_dialect_vjp pir phi gtest) @@ -34,7 +34,7 @@ cc_test_old( SRCS ir_infershape_test.cc DEPS - pd_op_dialect + op_dialect_vjp pir phi gtest) @@ -44,7 +44,7 @@ cc_test_old( SRCS scalar_attribute_test.cc DEPS - pd_op_dialect + op_dialect_vjp pir gtest) @@ -80,7 +80,7 @@ cc_test_old( DEPS program_translator gtest - pd_op_dialect + op_dialect_vjp pir) cc_test_old( @@ -89,13 +89,13 @@ cc_test_old( add_dialect_parser_test.cc DEPS gtest - pd_op_dialect + op_dialect_vjp pir) cc_test( ir_parser_test SRCS ir_parser_test.cc - DEPS gtest pd_op_dialect pir) + DEPS gtest op_dialect_vjp pir) cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest pir) cc_test_old( @@ -104,8 +104,8 @@ cc_test_old( op_yaml_info_parser_test.cc DEPS gtest - pd_op_dialect - pd_interface + op_dialect + op_dialect_vjp pir) cc_test_old( @@ -115,7 +115,7 @@ cc_test_old( DEPS gtest program_translator - pd_op_dialect + op_dialect_vjp pir) cc_test_old( @@ -126,7 +126,7 @@ cc_test_old( pir test_dialect gtest - pd_op_dialect) + op_dialect_vjp) cc_test_old( block_operand_test diff --git a/test/cpp/pir/core/ir_builder_test.cc b/test/cpp/pir/core/ir_builder_test.cc index e3705d08c7ef9..84e7d271bce47 100644 --- a/test/cpp/pir/core/ir_builder_test.cc +++ b/test/cpp/pir/core/ir_builder_test.cc @@ -47,6 +47,7 @@ TEST(builder_test, attribute_api) { EXPECT_EQ(pir::DoubleAttribute::get(&ctx, 2.0), builder.double_attr(2.0)); EXPECT_EQ(pir::Int32Attribute::get(&ctx, 2), builder.int32_attr(2)); EXPECT_EQ(pir::Int64Attribute::get(&ctx, 2), builder.int64_attr(2)); + EXPECT_EQ(pir::IndexAttribute::get(&ctx, 2), builder.index_attr(2)); EXPECT_EQ(pir::ArrayAttribute::get(&ctx, std::vector()), builder.array_attr({})); EXPECT_EQ(pir::PointerAttribute::get(&ctx, nullptr), diff --git a/test/cpp/pir/core/ir_op_test.cc b/test/cpp/pir/core/ir_op_test.cc index 596519ba57d4c..1631c8198d3e5 100644 --- a/test/cpp/pir/core/ir_op_test.cc +++ b/test/cpp/pir/core/ir_op_test.cc @@ -21,7 +21,6 @@ #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/enforce.h" #include "paddle/pir/core/ir_context.h" @@ -32,39 +31,7 @@ #include "test/cpp/pir/tools/test_dialect.h" #include "test/cpp/pir/tools/test_op.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} - -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes, - const pir::Type &dtype = - pir::Float32Type::get(pir::IrContext::Instance())) { - std::vector op_inputs = {}; - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - pir::DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} +#include "test/cpp/pir/tools/test_pir_utils.h" TEST(op_test, region_test) { // (1) Register Dialect, Operation1, Operation2 into IrContext. @@ -76,12 +43,12 @@ TEST(op_test, region_test) { pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(test::Operation1::name()); pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(test::Operation2::name()); - pir::Operation *op1 = - pir::Operation::Create({}, - CreateAttributeMap({"op1_attr1", "op1_attr2"}, - {"op1_attr1", "op1_attr2"}), - {pir::Float32Type::get(ctx)}, - op1_info); + pir::Operation *op1 = pir::Operation::Create( + {}, + test::CreateAttributeMap({"op1_attr1", "op1_attr2"}, + {"op1_attr1", "op1_attr2"}), + {pir::Float32Type::get(ctx)}, + op1_info); pir::Operation *op_2 = pir::Operation::Create({}, {}, {pir::Float32Type::get(ctx)}, op2_info); @@ -169,9 +136,9 @@ TEST(op_test, op_traits_test) { pir::DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); auto op3 = builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype); @@ -220,9 +187,9 @@ TEST(op_test, same_operands_shape_trait_test2) { pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype), @@ -255,9 +222,9 @@ TEST(op_test, same_operands_and_result_shape_trait_test2) { phi::DDim dims = {2, 2, 2}; pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0)), @@ -287,9 +254,9 @@ TEST(op_test, same_operands_and_result_shape_trait_test3) { pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype), @@ -330,9 +297,9 @@ TEST(op_test, same_operands_element_type_trait_test2) { pir::DenseTensorType::get(ctx, dtype1, dims, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype), @@ -365,9 +332,9 @@ TEST(op_test, same_operands_and_result_element_type_trait_test2) { phi::DDim dims = {2, 2}; pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0)), @@ -399,9 +366,9 @@ TEST(op_test, same_operands_and_result_element_type_trait_test3) { pir::DenseTensorType::get(ctx, dtype2, dims2, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), @@ -443,9 +410,9 @@ TEST(op_test, same_operands_and_result_type_trait_test2) { phi::DDim dims = {2, 2}; pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0)), @@ -481,9 +448,9 @@ TEST(op_test, same_operands_and_result_type_trait_test3) { pir::DenseTensorType::get(ctx, dtype1, dims2, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype2); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype1); EXPECT_THROW(builder.Build( op1->result(0), diff --git a/test/cpp/pir/core/ir_program_test.cc b/test/cpp/pir/core/ir_program_test.cc index 7ae348d004f53..b4221cf5518d7 100644 --- a/test/cpp/pir/core/ir_program_test.cc +++ b/test/cpp/pir/core/ir_program_test.cc @@ -110,10 +110,10 @@ TEST(program_test, program) { auto op1 = builder.Build("a", dense_tensor_dtype); EXPECT_EQ(&program, op1->GetParentProgram()); - EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id()); + EXPECT_EQ(op1->result_type(0).dialect().id(), paddle_dialect->id()); using Interface = paddle::dialect::ParameterConvertInterface; Interface *a_interface = - op1->result(0).type().dialect().GetRegisteredInterface(); + op1->result_type(0).dialect().GetRegisteredInterface(); std::shared_ptr a_var = a_interface->ParameterToVariable(program.GetParameter("a")); const phi::DenseTensor &a_tensor = a_var->Get(); @@ -130,9 +130,9 @@ TEST(program_test, program) { // (5) Def b = GetParameterOp("b"), and create DenseTensor for b. auto op2 = builder.Build("b", dense_tensor_dtype); - EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id()); + EXPECT_EQ(op2->result_type(0).dialect().id(), paddle_dialect->id()); Interface *b_interface = - op2->result(0).type().dialect().GetRegisteredInterface(); + op2->result_type(0).dialect().GetRegisteredInterface(); std::shared_ptr b_var = b_interface->ParameterToVariable(program.GetParameter("b")); const phi::DenseTensor &b_tensor = b_var->Get(); @@ -269,7 +269,7 @@ TEST(program_test, builder) { paddle::dialect::FullOp full_op = builder.Build( std::vector{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); - pir::Type full_op_output = full_op->result(0).type(); + pir::Type full_op_output = full_op->result_type(0); EXPECT_EQ(program.block()->size(), 1u); EXPECT_EQ(program.block()->back(), full_op.operation()); EXPECT_EQ(full_op.num_operands(), 0u); diff --git a/test/cpp/pir/core/ir_value_test.cc b/test/cpp/pir/core/ir_value_test.cc index d4a7d14322a66..dba46b72c08a0 100644 --- a/test/cpp/pir/core/ir_value_test.cc +++ b/test/cpp/pir/core/ir_value_test.cc @@ -16,55 +16,47 @@ #include "paddle/pir/core/attribute.h" #include "paddle/pir/core/builtin_attribute.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/operation.h" +#include "test/cpp/pir/tools/test_pir_utils.h" + // This unittest is used to test the construction interfaces of value class and // operation. The constructed test scenario is: a = OP1(); b = OP2(); c = OP3(a, // b); d, e, f, g, h, i, j = OP4(a, c); -pir::AttributeMap CreateAttributeMap(std::string attribute_name, - std::string attribute) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attribute); - pir::AttributeMap attr_map; - attr_map.insert( - std::pair(attribute_name, attr_value)); - return attr_map; -} TEST(value_test, value_test) { pir::IrContext *ctx = pir::IrContext::Instance(); // 1. Construct OP1: a = OP1() std::vector op1_inputs = {}; std::vector op1_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op1 = - pir::Operation::Create(op1_inputs, - CreateAttributeMap("op1_name", "op1_attr"), - op1_output_types, - pir::OpInfo()); + pir::Operation *op1 = pir::Operation::Create( + op1_inputs, + test::CreateAttributeMap({"op1_name"}, {"op1_attr"}), + op1_output_types, + pir::OpInfo()); op1->Print(std::cout); pir::OpResult a = op1->result(0); EXPECT_TRUE(a.use_empty()); // 2. Construct OP2: b = OP2(); std::vector op2_inputs = {}; std::vector op2_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op2 = - pir::Operation::Create(op2_inputs, - CreateAttributeMap("op2_name", "op2_attr"), - op2_output_types, - pir::OpInfo()); + pir::Operation *op2 = pir::Operation::Create( + op2_inputs, + test::CreateAttributeMap({"op2_name"}, {"op2_attr"}), + op2_output_types, + pir::OpInfo()); op2->Print(std::cout); pir::OpResult b = op2->result(0); EXPECT_TRUE(b.use_empty()); // 3. Construct OP3: c = OP3(a, b); std::vector op3_inputs{a, b}; std::vector op3_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op3 = - pir::Operation::Create(op3_inputs, - CreateAttributeMap("op3_name", "op3_attr"), - op3_output_types, - pir::OpInfo()); + pir::Operation *op3 = pir::Operation::Create( + op3_inputs, + test::CreateAttributeMap({"op3_name"}, {"op3_attr"}), + op3_output_types, + pir::OpInfo()); EXPECT_TRUE(op1->result(0).HasOneUse()); EXPECT_TRUE(op2->result(0).HasOneUse()); @@ -76,11 +68,11 @@ TEST(value_test, value_test) { for (size_t i = 0; i < 7; i++) { op4_output_types.push_back(pir::Float32Type::get(ctx)); } - pir::Operation *op4 = - pir::Operation::Create(op4_inputs, - CreateAttributeMap("op4_name", "op4_attr"), - op4_output_types, - pir::OpInfo()); + pir::Operation *op4 = pir::Operation::Create( + op4_inputs, + test::CreateAttributeMap({"op4_name"}, {"op4_attr"}), + op4_output_types, + pir::OpInfo()); op4->Print(std::cout); // Test 1: diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index ba85e396d41b7..d20bf912a4b0d 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -37,7 +37,7 @@ #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/parser/ir_parser.h" #include "paddle/pir/core/program.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" using OperatorDialect = paddle::dialect::OperatorDialect; using ProgramDesc = paddle::framework::ProgramDesc; diff --git a/test/cpp/pir/core/scalar_attribute_test.cc b/test/cpp/pir/core/scalar_attribute_test.cc index e15ebfad84585..5d547c58c3a92 100644 --- a/test/cpp/pir/core/scalar_attribute_test.cc +++ b/test/cpp/pir/core/scalar_attribute_test.cc @@ -50,6 +50,9 @@ TEST(ScalarTest, test_classof) { pir::Attribute int32_scalar = pir::Int32Attribute::get(ctx, 1); EXPECT_TRUE(int32_scalar.isa()); + pir::Attribute index_scalar = pir::IndexAttribute::get(ctx, 1l); + EXPECT_TRUE(index_scalar.isa()); + pir::Attribute int64_scalar = pir::Int64Attribute::get(ctx, 1l); EXPECT_TRUE(int64_scalar.isa()); } diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index 0f3581732784f..2ec503dd20a95 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -95,6 +95,7 @@ TEST(type_test, built_in_type) { pir::Type index_1 = pir::IndexType::get(ctx); pir::Type index_2 = pir::IndexType::get(ctx); + EXPECT_TRUE(index_1.IsIndex()); EXPECT_EQ(index_1, index_2); EXPECT_EQ(index_1.type_id(), index_2.type_id()); EXPECT_EQ(&index_1.abstract_type(), diff --git a/test/cpp/pir/kernel_dialect/CMakeLists.txt b/test/cpp/pir/kernel_dialect/CMakeLists.txt index 16cad10461745..89b129935c1e6 100644 --- a/test/cpp/pir/kernel_dialect/CMakeLists.txt +++ b/test/cpp/pir/kernel_dialect/CMakeLists.txt @@ -3,10 +3,9 @@ cc_test_old( SRCS ir_kernel_dialect_pass_test.cc DEPS - pd_op_to_kernel_pass + transform program_translator - pd_kernel_dialect - pd_trait + op_dialect pir phi gtest) diff --git a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc index 6812e7a9ed194..bd71aab4b304e 100644 --- a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc +++ b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc @@ -46,7 +46,7 @@ #include "paddle/pir/core/program.h" #include "paddle/pir/core/utils.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" -#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT); diff --git a/test/cpp/pir/pass/CMakeLists.txt b/test/cpp/pir/pass/CMakeLists.txt index be68cdab344e7..fb9f37e080f38 100644 --- a/test/cpp/pir/pass/CMakeLists.txt +++ b/test/cpp/pir/pass/CMakeLists.txt @@ -4,6 +4,6 @@ cc_test_old( pass_manager_test.cc DEPS pir - pd_op_dialect + op_dialect_vjp phi gtest) diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 7edd32531be34..209e33c1eca42 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -1,5 +1,4 @@ -set(PATTERN_REWRITE_TEST_DEPS - _constant_folding_pass transform_general_functions gtest pd_op_dialect pir) +set(PATTERN_REWRITE_TEST_DEPS transform gtest op_dialect_vjp pir) if(WITH_DISTRIBUTE) set(PATTERN_REWRITE_TEST_DEPS ${PATTERN_REWRITE_TEST_DEPS} fleet_executor @@ -8,3 +7,31 @@ endif() cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ${PATTERN_REWRITE_TEST_DEPS}) + +cc_test( + drr_test + SRCS drr_test.cc + DEPS drr transform) + +cc_test( + drr_same_type_binding_test + SRCS drr_same_type_binding_test.cc + DEPS drr gtest op_dialect_vjp pir transform) + +cc_test( + drr_fuse_linear_test + SRCS drr_fuse_linear_test.cc + DEPS transform drr gtest op_dialect_vjp pir) + +cc_test( + drr_fuse_linear_param_grad_add_test + SRCS drr_fuse_linear_param_grad_add_test.cc + DEPS transform drr gtest op_dialect_vjp pir) + +cc_test( + drr_attention_fuse_test + SRCS drr_attention_fuse_test.cc + DEPS transform drr gtest op_dialect_vjp pir) + +set_tests_properties(pattern_rewrite_test + PROPERTIES ENVIRONMENT "FLAGS_enable_pir_in_executor=true") diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc new file mode 100644 index 0000000000000..8ac00044146f5 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass_manager.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp matmul_1_in_1 = + builder.Build(std::vector{1, 300, 256}, + 0.9, + phi::DataType::FLOAT32, + phi::CPUPlace()); + // The first path to matmul with scale (q). + paddle::dialect::FullOp matmul_1_in_2 = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_1 = builder.Build( + matmul_1_in_1.out(), matmul_1_in_2.out(), false, false); + + paddle::dialect::FullOp add_1_in_2 = builder.Build( + std::vector{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::AddOp add_1 = + builder.Build(matmul_1.out(), add_1_in_2.out()); + + paddle::dialect::ReshapeOp reshape_1 = + builder.Build( + add_1.out(), std::vector{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_1 = + builder.Build(reshape_1.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::ScaleOp scale_op = builder.Build( + transpose_1.out(), 0.1767766922712326, 0.0, true); + + // The second path to matmul (k). + paddle::dialect::FullOp matmul_2_in_2 = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_2 = builder.Build( + matmul_1_in_1.out(), matmul_2_in_2.out(), false, false); + + paddle::dialect::FullOp add_2_in_2 = builder.Build( + std::vector{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + paddle::dialect::AddOp add_op2 = + builder.Build(matmul_2.out(), add_2_in_2.out()); + + paddle::dialect::ReshapeOp reshape_2 = + builder.Build( + add_op2.out(), std::vector{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_2 = + builder.Build(reshape_2.out(), + std::vector{0, 2, 1, 3}); + + // The third path to matmul (v). + paddle::dialect::FullOp matmul_3_in_2 = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::MatmulOp matmul_3 = builder.Build( + matmul_1_in_1.out(), matmul_3_in_2.out(), false, false); + + paddle::dialect::FullOp add_3_in_2 = builder.Build( + std::vector{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::AddOp add_3 = + builder.Build(matmul_3.out(), add_3_in_2.out()); + + paddle::dialect::ReshapeOp reshape_3 = + builder.Build( + add_3.out(), std::vector{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_3 = + builder.Build(reshape_3.out(), + std::vector{0, 2, 1, 3}); + + // softmax(qk)v + paddle::dialect::MatmulOp matmul_4 = builder.Build( + scale_op.out(), transpose_2.out(), false, true); + + paddle::dialect::FullOp add_4_in_2 = builder.Build( + std::vector{1, 8, 300, 300}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_4 = + builder.Build(matmul_4.out(), add_4_in_2.out()); + + paddle::dialect::SoftmaxOp softmax_op = + builder.Build(add_4.out(), -1); + paddle::dialect::MatmulOp matmul_5 = builder.Build( + softmax_op.out(), transpose_3.out(), false, false); + + paddle::dialect::TransposeOp transpose_4 = + builder.Build(matmul_5.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::ReshapeOp reshape_4 = + builder.Build( + transpose_4.out(), std::vector{0, 0, 256}); + + builder.Build(reshape_4.out(), "out", 0); +} + +TEST(DrrTest, AttentionFuse) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + EXPECT_EQ(program.block()->size(), 33u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateAttentionFusePass()); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 20u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_param_grad_add_test.cc b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_param_grad_add_test.cc new file mode 100644 index 0000000000000..da496e6d940c0 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_param_grad_add_test.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram0(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build(std::vector{32}, 1.0); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_bias_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::AddGradOp add_grad_op1 = + builder.Build( + matmul_op1.out(), full_bias_op1.out(), full_d_out_op1.out()); + + paddle::dialect::MatmulGradOp matmul_grad_op1 = + builder.Build( + full_input_op1.out(), full_weight_op1.out(), add_grad_op1.x_grad()); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_grad_op1.y_grad()); + + builder.Build(add_op1.out(), "out", 0); + builder.Build(add_grad_op1.y_grad(), "dbias", 1); + builder.Build(add__op1.out(), "dweight", 2); +} + +void BuildProgram1(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulGradOp matmul_grad_op1 = + builder.Build( + full_input_op1.out(), full_weight_op1.out(), full_d_out_op1.out()); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_grad_op1.y_grad()); + + builder.Build(matmul_op1.out(), "out", 0); + builder.Build(add__op1.out(), "dweight", 1); +} + +void BuildProgram2(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build( + full_input_op1.out(), full_d_out_op1.out(), true, false); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_op2.out()); + + builder.Build(matmul_op1.out(), "out", 0); + builder.Build(add__op1.out(), "dweight", 1); +} + +void BuildProgram3(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build(std::vector{32}, 1.0); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_bias_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::AddGradOp add_grad_op1 = + builder.Build( + matmul_op1.out(), full_bias_op1.out(), full_d_out_op1.out()); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build( + add_grad_op1.x_grad(), full_weight_op1.out(), false, true); + + paddle::dialect::MatmulOp matmul_op3 = + builder.Build( + full_input_op1.out(), add_grad_op1.x_grad(), true, false); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_op3.out()); + + builder.Build(add_op1.out(), "out", 0); + builder.Build(add_grad_op1.y_grad(), "dbias", 1); + builder.Build(add__op1.out(), "dweight", 2); + builder.Build(matmul_op2.out(), "dx", 3); +} + +bool verify_pass(const pir::Program &program) { + for (auto op : *(program.block())) { + if (op->name() == paddle::dialect::FusedLinearParamGradAddOp::name()) { + return true; + } + } + return false; +} + +TEST(DrrTest, FusedLinearParamGradAdd0) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram0(builder); + + EXPECT_EQ(program.block()->size(), 13u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} + +TEST(DrrTest, FusedLinearParamGradAdd1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram1(builder); + + EXPECT_EQ(program.block()->size(), 9u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} + +TEST(DrrTest, FusedLinearParamGradAdd2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram2(builder); + + EXPECT_EQ(program.block()->size(), 9u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} + +TEST(DrrTest, FusedLinearParamGradAdd3) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram3(builder); + + EXPECT_EQ(program.block()->size(), 15u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc new file mode 100644 index 0000000000000..3ef77cd1f9665 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc @@ -0,0 +1,144 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{1, 512, 64}, + 1.5); + // linear 1 + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{64, 64}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build(std::vector{64}, 1.0); + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_bias_op1.out()); + // linear 2 + paddle::dialect::FullOp full_weight_op2 = + builder.Build(std::vector{64, 128}, + 1.5); + paddle::dialect::FullOp full_bias_op2 = + builder.Build(std::vector{128}, 1.0); + paddle::dialect::MatmulOp matmul_op2 = + builder.Build(add_op1.out(), + full_weight_op2.out()); + paddle::dialect::AddOp add_op2 = builder.Build( + matmul_op2.out(), full_bias_op2.out()); + paddle::dialect::ReluOp relu_op = + builder.Build(add_op2.out()); + // linear 3 + paddle::dialect::FullOp full_weight_op3 = + builder.Build(std::vector{128, 64}, + 1.5); + paddle::dialect::FullOp full_bias_op3 = + builder.Build(std::vector{64}, 1.0); + paddle::dialect::MatmulOp matmul_op3 = + builder.Build(relu_op.out(), + full_weight_op3.out()); + paddle::dialect::AddOp add_op3 = builder.Build( + matmul_op3.out(), full_bias_op3.out()); + paddle::dialect::GeluOp gelu_op1 = + builder.Build(add_op3.out()); + // linear 4 + paddle::dialect::FullOp full_weight_op4 = + builder.Build(std::vector{64, 64}, 1.5); + paddle::dialect::FullOp full_bias_op4 = + builder.Build(std::vector{64}, 1.0); + paddle::dialect::MatmulOp matmul_op4 = + builder.Build(gelu_op1.out(), + full_weight_op4.out()); + paddle::dialect::AddOp add_op4 = builder.Build( + matmul_op4.out(), full_bias_op4.out()); + paddle::dialect::GeluOp gelu_op2 = + builder.Build(add_op4.out()); + + // backward + paddle::dialect::FullOp full_grad_op = builder.Build( + std::vector{1, 512, 64}, 1.0); + + paddle::dialect::GeluGradOp gelu_op2_grad = + builder.Build( + add_op4.out(), full_grad_op.out(), false); + // backward linear 4 + paddle::dialect::AddGradOp add_op4_grad = + builder.Build( + matmul_op4.out(), full_bias_op4.out(), gelu_op2_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op4_grad = + builder.Build( + gelu_op1.out(), full_weight_op4.out(), add_op4_grad.x_grad()); + + paddle::dialect::GeluGradOp gelu_op1_grad = + builder.Build( + add_op3.out(), matmul_op4_grad.x_grad(), false); + // backward linear 3 + paddle::dialect::AddGradOp add_op3_grad = + builder.Build( + matmul_op3.out(), full_bias_op3.out(), gelu_op1_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op3_grad = + builder.Build( + relu_op.out(), full_weight_op3.out(), add_op3_grad.x_grad()); + + paddle::dialect::ReluGradOp relu_op_grad = + builder.Build(relu_op.out(), + matmul_op3_grad.x_grad()); + // backward linear 2 + paddle::dialect::AddGradOp add_op2_grad = + builder.Build( + matmul_op2.out(), full_bias_op2.out(), relu_op_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op2_grad = + builder.Build( + add_op1.out(), full_weight_op2.out(), add_op2_grad.x_grad()); + // backward linear 1 + paddle::dialect::AddGradOp add_op1_grad = + builder.Build( + matmul_op1.out(), full_bias_op1.out(), matmul_op2_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op1_grad = + builder.Build( + full_input_op1.out(), full_weight_op1.out(), add_op1_grad.x_grad()); + + builder.Build(gelu_op2.out(), "out", 0); + builder.Build(matmul_op1_grad.x_grad(), "dx", 1); +} + +TEST(DrrTest, FusedLinear) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 34u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedGemmEpiloguePass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 22u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc new file mode 100644 index 0000000000000..32b0ed6935515 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -0,0 +1,332 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +/* Source pattern: + input1 + / | \ \ \ + / | \ \ \ + full / | | \ \ full_tmp + / | transpos1 | trans2 trans3 \ / | + / | / | | | | \ / | + softmax1 | / | | | | \ / | + \ | / softmax2 | | | add1 | + \ | / \ | \ / | | + layernorm matmul2 matmul1 \ | + / | \ | | \ | + / | \ \ / \ | + / | \ matmul3 add2 + | | | / | \ | + | | | / | \ | + | | | / | \ | + | | | trans4 trans5 trans6 | + | | | | | | | + | | | relu1 softmax3 softmax4 relu2 + | | | | | | | + output0 output1 output2 output3 output4 output5 output6 +*/ + +class SameTypeBindingTestPattern + // This class is for test cases of the same type of OP. + // (without considering the computational logic between OPs, + // only focusing on the process of matching and replacing) + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern src = ctx->SourcePattern(); + + // path 1 + const auto &transpose_1 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_1")}}); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("input_1")); + const auto &softmax_2 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_2_axis")}}); + src.Tensor("softmax_2_out") = softmax_2(src.Tensor("transpose_1_out")); + const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_tradnspose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("softmax_2_out"), src.Tensor("input_1")); + + // path 2 + const auto &full_1 = src.Op("pd_op.full", + {{"shape", src.Attr("shape_1")}, + {"value", src.Attr("value_1")}, + {"dtype", src.Attr("dtype_1")}, + {"place", src.Attr("place_1")}}); + src.Tensor("full_1_out") = full_1(); + const auto &softmax_1 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_1_axis")}}); + src.Tensor("softmax_1_out") = softmax_1(src.Tensor("full_1_out")); + const auto &layernorm_1 = + src.Op("pd_op.layer_norm", + {{"epsilon", src.Attr("layernorm_epsilon")}, + {"begin_norm_axis", src.Attr("layernorm_begin_norm_axis")}}); + layernorm_1({&src.Tensor("transpose_1_out"), + &src.Tensor("full_1_out"), + &src.Tensor("softmax_1_out")}, + {&src.Tensor("output0"), + &src.Tensor("output1"), + &src.Tensor("output2")}); + + // path 3 + const auto &transpose_2 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_2")}}); + const auto &transpose_3 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_3")}}); + const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + src.Tensor("matmul_1_out") = matmul_1(transpose_2(src.Tensor("input_1")), + transpose_3(src.Tensor("input_1"))); + const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_2_out"), src.Tensor("matmul_1_out")); + const auto &transpose_4 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); + const auto &transpose_5 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_5")}}); + const auto &transpose_6 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_6")}}); + const auto &relu_1 = src.Op("pd_op.relu"); + const auto &softmax_3 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_3_axis")}}); + const auto &softmax_4 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_4_axis")}}); + src.Tensor("output3") = relu_1(transpose_4(src.Tensor("matmul_3_out"))); + src.Tensor("output4") = softmax_3(transpose_5(src.Tensor("matmul_3_out"))); + src.Tensor("output5") = softmax_4(transpose_6(src.Tensor("matmul_3_out"))); + + // path 4 + const auto &full_tmp = src.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + src.Tensor("full_tmp_out") = full_tmp(); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("input_1"), src.Tensor("full_tmp_out")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("add_1_out"), src.Tensor("full_tmp_out")); + const auto &relu_2 = src.Op("pd_op.relu"); + src.Tensor("output6") = relu_2(src.Tensor("add_2_out")); + + pir::drr::ResultPattern res = src.ResultPattern(); + const auto &transpose_7 = + res.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); + res.Tensor("output0") = transpose_7(res.Tensor("input_1")); + const auto &transpose_8 = + res.Op("pd_op.transpose", {{"perm", src.Attr("perm_5")}}); + res.Tensor("output1") = transpose_8(res.Tensor("input_1")); + const auto &full_2 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_3 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_4 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_5 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_6 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + res.Tensor("output2") = full_2(); + res.Tensor("output3") = full_3(); + res.Tensor("output4") = full_4(); + res.Tensor("output5") = full_5(); + res.Tensor("output6") = full_6(); + } +}; + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + // path 1 + paddle::dialect::TransposeOp transpose_op1 = + builder.Build(full_input_op1.out(), + std::vector{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op2 = + builder.Build(transpose_op1.out(), -1); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build(softmax_op2.out(), + full_input_op1.out()); + + // path 2 + paddle::dialect::FullOp full_op_scale = + builder.Build(std::vector{48}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::SoftmaxOp softmax_op_bias = + builder.Build(full_op_scale.out(), -1); + paddle::dialect::LayerNormOp layernorm_op1 = + builder.Build( + transpose_op1.out(), full_op_scale.out(), softmax_op_bias.out()); + + // path 3 + paddle::dialect::TransposeOp transpose_op2 = + builder.Build(full_input_op1.out(), + std::vector{0, 1, 2}); + + paddle::dialect::TransposeOp transpose_op3 = + builder.Build(full_input_op1.out(), + std::vector{0, 1, 2}); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(transpose_op2.out(), + transpose_op3.out()); + + paddle::dialect::MatmulOp matmul_op3 = + builder.Build(matmul_op2.out(), + matmul_op1.out()); + + paddle::dialect::TransposeOp transpose_op4 = + builder.Build(matmul_op3.out(), + std::vector{0, 1, 2}); + + paddle::dialect::ReluOp relu_op1 = + builder.Build(transpose_op4.out()); + + paddle::dialect::TransposeOp transpose_op5 = + builder.Build(matmul_op3.out(), + std::vector{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op3 = + builder.Build(transpose_op5.out(), -1); + + paddle::dialect::TransposeOp transpose_op6 = + builder.Build(matmul_op3.out(), + std::vector{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op4 = + builder.Build(transpose_op6.out(), -1); + + // path 4 + paddle::dialect::FullOp full_input_op2 = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_op1 = builder.Build( + full_input_op1.out(), full_input_op2.out()); + + paddle::dialect::AddOp add_op2 = builder.Build( + add_op1.out(), full_input_op2.out()); + + paddle::dialect::ReluOp relu_op2 = + builder.Build(add_op2.out()); + + // tail + paddle::dialect::MatmulOp matmul_op4 = + builder.Build(layernorm_op1.variance(), + layernorm_op1.mean()); + + paddle::dialect::MatmulOp matmul_op5 = + builder.Build(relu_op1.out(), + softmax_op3.out()); + + paddle::dialect::MatmulOp matmul_op6 = + builder.Build(softmax_op4.out(), + relu_op2.out()); + + builder.Build(matmul_op4.out(), "out1", 0); + builder.Build(matmul_op5.out(), "out2", 1); + builder.Build(matmul_op6.out(), "out3", 2); +} + +class DrrPatternRewritePass : public pir::Pass { + public: + DrrPatternRewritePass() : pir::Pass("DrrPatternRewritePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(SameTypeBindingTestPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +TEST(DrrTest, drr_demo) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 27u); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 13u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc new file mode 100644 index 0000000000000..446de9519fc3b --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -0,0 +1,239 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +class RemoveRedundentReshapePattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source patterns + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &reshape1 = pat.Op("pd_op.reshape"); + const auto &reshape2 = pat.Op("pd_op.reshape"); + + reshape1({&pat.Tensor("arg0"), &pat.Tensor("shape0")}, + {&pat.Tensor("out1"), &pat.Tensor("xshape_0")}); + reshape2({&pat.Tensor("out1"), &pat.Tensor("shape1")}, + {&pat.Tensor("ret"), &pat.Tensor("xshape_1")}); + + // Result patterns + pir::drr::ResultPattern res = pat.ResultPattern(); + res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, + {&res.Tensor("ret"), &res.Tensor("xshape_1")}); + } +}; + +class FoldExpandToConstantPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &full1 = pat.Op("pd_op.full", + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + const auto &full_int_array1 = + pat.Op("pd_op.full_int_array", + {{"value", pat.Attr("expand_shape_value")}, + {"dtype", pat.Attr("dtype_2")}, + {"place", pat.Attr("place_2")}}); + const auto &expand = pat.Op("pd_op.expand"); + pat.Tensor("ret") = expand(full1(), full_int_array1()); + + // Result patterns + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &new_perm_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> phi::IntArray { + auto shape = + match_ctx.Attr>("expand_shape_value"); + + return phi::IntArray(shape); + }); + const auto &full2 = res.Op("pd_op.full", + {{"shape", new_perm_attr}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + res.Tensor("ret") = full2(); + } +}; + +class RemoveRedundentTransposePattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &transpose1 = + pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}}); + const auto &transpose2 = + pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_2")}}); + + pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose"))); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &new_perm_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector { + const auto &perm1 = match_ctx.Attr>("perm_1"); + const auto &perm2 = match_ctx.Attr>("perm_2"); + std::vector new_perm; + for (int v : perm2) { + new_perm.emplace_back(perm1[v]); + } + return new_perm; + }); + const auto &tranpose_continuous = + res.Op("pd_op.transpose", {{"perm", new_perm_attr}}); + + res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); + } +}; + +class RemoveRedundentCastPattern + : public pir::drr::DrrPatternBase { + void operator()(pir::drr::DrrPatternContext *ctx) const override { + auto pat = ctx->SourcePattern(); + pat.Tensor("tmp") = pat.Op( + "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0")); + pat.Tensor("ret") = pat.Op( + "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp")); + auto res = pat.ResultPattern(); + res.Tensor("ret") = res.Op( + "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); + } +}; + +class RemoveUselessCastPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + auto pat = ctx->SourcePattern(); + pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0")); + pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype()); + auto res = pat.ResultPattern(); + res.Tensor("ret").Assign(res.Tensor("arg0")); + } +}; + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullIntArrayOp full_int_array_op = + builder.Build( + std::vector{4, 3, 16, 16}, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::ExpandOp expand_op = + builder.Build(full_input_op.out(), + full_int_array_op.out()); + + paddle::dialect::ReshapeOp reshape_op1 = + builder.Build( + expand_op.out(), std::vector{16, 3, 4, 16}); + + paddle::dialect::ReshapeOp reshape_op2 = + builder.Build( + reshape_op1.out(), std::vector{16, 3, 4, 16}); + + paddle::dialect::ReluOp relu_op = + builder.Build(reshape_op2.out()); + + paddle::dialect::CastOp cast_op1 = builder.Build( + relu_op.out(), phi::DataType::FLOAT64); + + paddle::dialect::CastOp cast_op2 = builder.Build( + cast_op1.out(), phi::DataType::FLOAT32); + + paddle::dialect::TransposeOp transpose_op1 = + builder.Build(cast_op2.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::TransposeOp transpose_op2 = + builder.Build(transpose_op1.out(), + std::vector{1, 0, 2, 3}); + + paddle::dialect::ReluOp relu_op_second = + builder.Build(transpose_op2.out()); + + builder.Build(relu_op_second.out(), "out", 0); +} + +class DrrPatternRewritePass : public pir::Pass { + public: + DrrPatternRewritePass() : pir::Pass("DrrPatternRewritePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(RemoveRedundentReshapePattern().Build(context)); + ps.Add(RemoveRedundentTransposePattern().Build(context)); + ps.Add(RemoveRedundentCastPattern().Build(context)); + ps.Add(RemoveUselessCastPattern().Build(context)); + ps.Add(FoldExpandToConstantPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +TEST(DrrTest, drr_demo) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 14u); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 7u); +} diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 18644c08e21b7..d45a74f6fd0d1 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -19,9 +19,10 @@ #include #include #include -#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/pir/core/builder.h" @@ -42,8 +43,6 @@ #include "paddle/pir/pattern_rewrite/pattern_applicator.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -#include "paddle/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/pir/transforms/reorder_block_ops_pass.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/pir/dialect/CMakeLists.txt. @@ -1108,12 +1107,14 @@ void BuildProgram(pir::Builder &builder) { // NOLINT builder.Build(transpose2_op.out(), "out", 0); } -// TODO(wilber): Add a normal test. TEST(pattern_rewrite, Patterns) { pir::IrContext *ctx = pir::IrContext::Instance(); + + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); auto *test_dialect = ctx->GetOrRegisterDialect(); test_dialect->RegisterOp(); - ctx->GetOrRegisterDialect(); + pir::Program program(ctx); pir::Builder builder = pir::Builder(ctx, program.block()); BuildProgram(builder); @@ -1122,20 +1123,19 @@ TEST(pattern_rewrite, Patterns) { pir::PassManager pm(ctx); pm.AddPass(std::make_unique()); - // pm.AddPass(ir::CreateConstantFoldingPass()); + pm.AddPass(pir::CreateConstantFoldingPass()); pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(pir::CreateReorderBlockOpsPass()); pm.EnablePassTiming(); - pm.EnableIRPrinting(); - // pm.EnableIRPrinting(std::make_unique( - // [](pir::Pass *pass, pir::Operation *op) { - // return pass->name() == "ConstantFoldingPass"; - // }, - // [](pir::Pass *pass, pir::Operation *op) { - // return pass->name() == "ConstantFoldingPass"; - // }, - // true, - // true)); + pm.EnableIRPrinting(std::make_unique( + [](pir::Pass *pass, pir::Operation *op) { + return pass->name() == "constant_folding_pass"; + }, + [](pir::Pass *pass, pir::Operation *op) { + return pass->name() == "constant_folding_pass"; + }, + true, + true)); CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 2u); } diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 349d6a32dfa22..ec80962c1cb3a 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -3,7 +3,7 @@ paddle_test( SRCS shape_op_test.cc DEPS - pd_op_dialect + op_dialect_vjp pir gtest) @@ -12,7 +12,7 @@ paddle_test( SRCS shape_struct_test.cc DEPS - pd_op_dialect + op_dialect_vjp pir gtest) @@ -22,12 +22,11 @@ paddle_test( constraint_pass_test.cc DEPS gtest - pd_op_dialect + op_dialect_vjp pir) -set_tests_properties( - constraint_pass_test PROPERTIES ENVIRONMENT - "FLAGS_enable_new_ir_in_executor=true") +set_tests_properties(constraint_pass_test + PROPERTIES ENVIRONMENT "FLAGS_enable_pir_in_executor=true") if(WITH_ONNXRUNTIME AND WIN32) # Copy onnxruntime for some c++ test in Windows, since the test will diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc index 860bf34a69ac4..4b5e660cf6f3b 100644 --- a/test/cpp/pir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -21,13 +21,11 @@ #include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/cast_utils.h" #include "paddle/pir/core/dialect.h" @@ -40,94 +38,72 @@ #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/dialect/shape/transforms/passes.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} +#include "test/cpp/pir/tools/test_pir_utils.h" -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes, - const pir::Type &dtype = - pir::Float32Type::get(pir::IrContext::Instance())) { - std::vector op_inputs = {}; - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} - -TEST(constraint_pass, materialize_and_build_shape) { +TEST(shape_constraint_pass, materialize_and_build_shape) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - pir::PassManager pm(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - pir::Operation *op0 = CreateDenseTensorOp( - ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op0_attr"}, {"op0_name"}); - program.block()->push_back(op0); + + pir::Operation *op0 = + test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, 2}, + {"op0_attr"}, + {"create_dense_tensor_op0"}); pir::Operation *op1 = - CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, 2, 2}, - {"op1_attr"}, - {"op1_name"}); + test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, 2, 2}, + {"op1_attr"}, + {"create_dense_tensor_op1"}); + program.block()->push_back(op0); program.block()->push_back(op1); - EXPECT_EQ(program.block()->size(), static_cast(2)); + EXPECT_EQ(program.block()->size(), 2u); + + std::stringstream ss1; + program.Print(ss1); + LOG(INFO) << " ================================================ Before Add " + "and Run Pass ================================================ "; + LOG(INFO) << ss1.str(); + + pir::PassManager pm(ctx); pm.AddPass(pir::CreateShapeOptimizationPass()); EXPECT_TRUE(pm.Run(&program)); // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 + 1 funcOp == 15 Ops. - EXPECT_EQ(program.block()->size(), static_cast(15)); - - std::stringstream ss; - program.Print(ss); + EXPECT_EQ(program.block()->size(), 15u); - LOG(INFO) << ss.str(); + std::stringstream ss2; + program.Print(ss2); + LOG(INFO) << " ================================================ After Add " + "and Run Pass ================================================ "; + LOG(INFO) << ss2.str(); } -TEST(constraint_pass, shape_computation_run) { +TEST(shape_constraint_pass, shape_computation_run) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - pir::PassManager pm(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - builder.Build(); - pir::Operation *op0 = - CreateDenseTensorOp(ctx, - {2}, - {"op0_attr"}, - {"op0_name"}, - pir::Int64Type::get(pir::IrContext::Instance())); + pir::Builder builder = ::pir::Builder(ctx, program.block()); + builder.Build(); + pir::Operation *op0 = test::CreateDenseTensorOp( + ctx, + {2}, + {"op0_attr"}, + {"op0_name"}, + pir::Int64Type::get(pir::IrContext::Instance())); program.block()->push_back(op0); - pir::Operation *op1 = CreateDenseTensorOp( + pir::Operation *op1 = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op1_attr"}, {"op1_name"}); program.block()->push_back(op1); + pir::PassManager pm(ctx); pm.AddPass(pir::CreateShapeOptimizationPass()); EXPECT_TRUE(pm.Run(&program)); @@ -135,3 +111,5 @@ TEST(constraint_pass, shape_computation_run) { EXPECT_TRUE(mgr.Load()); EXPECT_TRUE(mgr.Save()); } + +// TODO(zhangbopd): ExpandShapeOfOpPattern etc. diff --git a/test/cpp/pir/shape_dialect/shape_op_test.cc b/test/cpp/pir/shape_dialect/shape_op_test.cc index 9d71e721fe72d..89a728beed9b7 100644 --- a/test/cpp/pir/shape_dialect/shape_op_test.cc +++ b/test/cpp/pir/shape_dialect/shape_op_test.cc @@ -16,119 +16,121 @@ #include #include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/core/block.h" -#include "paddle/pir/core/builder.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/dialect/shape/utils/symbol_table.h" +#include "test/cpp/pir/tools/test_pir_utils.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { +TEST(shape_op, symbolic_dim_op) { pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::shape::SymbolicDimOp sym_dim_op1 = + builder.Build( + "S0", 10, false, false, false, false); + pir::shape::SymbolicDimOp sym_dim_op2 = + builder.Build( + "S1", 10, false, false, false, false); + + EXPECT_EQ(sym_dim_op1.GetDimSize(), 10); + EXPECT_EQ(sym_dim_op1.GetSymName(), "S0"); + EXPECT_FALSE(sym_dim_op1.GetKnownNegativeOne()); + EXPECT_FALSE(sym_dim_op1.GetKnownNonSizeOne()); + EXPECT_FALSE(sym_dim_op1.GetKnownNonSizeZero()); + EXPECT_FALSE(sym_dim_op1.GetKnownNonNegative()); + + EXPECT_FALSE(sym_dim_op1.IsDynamic()); + EXPECT_TRUE(sym_dim_op1.Merge(sym_dim_op2)); -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes) { - std::vector op_inputs = {}; - pir::Type fp32_dtype = pir::Float32Type::get(ctx); - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; + sym_dim_op1.SetDimSize(20); + sym_dim_op1.SetSymName("S2"); + sym_dim_op1.UpdateKnownNegativeOne(true); + sym_dim_op1.UpdateKnownNonSizeOne(true); + sym_dim_op1.UpdateKnownNonSizeZero(true); + sym_dim_op1.UpdateKnownNonNegative(true); + + EXPECT_FALSE(sym_dim_op1.Merge(sym_dim_op2)); + + EXPECT_EQ(sym_dim_op1.GetDimSize(), 20); + EXPECT_EQ(sym_dim_op1.GetSymName(), "S2"); + EXPECT_TRUE(sym_dim_op1.GetKnownNegativeOne()); + EXPECT_TRUE(sym_dim_op1.GetKnownNonSizeOne()); + EXPECT_TRUE(sym_dim_op1.GetKnownNonSizeZero()); + EXPECT_TRUE(sym_dim_op1.GetKnownNonNegative()); } -TEST(shape_op, dim) { +TEST(shape_op, dim_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::DimOp dim_op = builder.Build("S0"); + pir::shape::DimOp dim_op = builder.Build("S0"); pir::OpResult res = dim_op.out(); - EXPECT_EQ(dim_op.getName(), "S0"); - dim_op.setName("S1"); - EXPECT_EQ(dim_op.getName(), "S1"); + EXPECT_EQ(dim_op.GetName(), "S0"); + dim_op.SetName("S1"); + EXPECT_EQ(dim_op.GetName(), "S1"); EXPECT_EQ(res.owner(), dim_op.operation()); EXPECT_EQ(res.type(), pir::IndexType::get(ctx)); } -TEST(shape_op, tie_product_equal) { +TEST(shape_op, tie_product_equal_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); pir::SymbolTable symbolt_table(program.module_op()); - pir::OpResult dim_op0 = builder.Build("S0").out(); - pir::OpResult dim_op1 = builder.Build("S1").out(); - pir::OpResult dim_op2 = builder.Build("S2").out(); - pir::OpResult dim_op3 = builder.Build("S3").out(); - pir::OpResult dim_op4 = builder.Build("S4").out(); + pir::OpResult dim_op0 = builder.Build("S0").out(); + pir::OpResult dim_op1 = builder.Build("S1").out(); + pir::OpResult dim_op2 = builder.Build("S2").out(); + pir::OpResult dim_op3 = builder.Build("S3").out(); + pir::OpResult dim_op4 = builder.Build("S4").out(); - pir::dialect::TieProductEqualOp tie_product_equal = - builder.Build( + pir::shape::TieProductEqualOp tie_product_equal_op = + builder.Build( 2, 3, std::vector{dim_op0, dim_op1, dim_op2, dim_op3, dim_op4}); - std::vector lhs = tie_product_equal.lhs(); - std::vector rhs = tie_product_equal.rhs(); + std::vector lhs = tie_product_equal_op.lhs(); + std::vector rhs = tie_product_equal_op.rhs(); std::vector lhs_ref{dim_op0, dim_op1}; std::vector rhs_ref{dim_op2, dim_op3, dim_op4}; - EXPECT_EQ(symbolt_table.insert(tie_product_equal), "tie_product_equal"); + EXPECT_EQ(symbolt_table.insert(tie_product_equal_op), "tie_product_equal"); EXPECT_EQ( - symbolt_table.Lookup("tie_product_equal") + symbolt_table.Lookup("tie_product_equal") .size(), static_cast(1)); - EXPECT_EQ(symbolt_table.Lookup( + EXPECT_EQ(symbolt_table.Lookup( "tie_product_equal")[0], - tie_product_equal); + tie_product_equal_op); EXPECT_EQ(lhs, lhs_ref); EXPECT_EQ(rhs, rhs_ref); } -TEST(shape_op, tie_shape) { +TEST(shape_op, tie_shape_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - auto op = CreateDenseTensorOp( + auto op = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::OpResult res = op->result(0); - pir::dialect::TieShapeOp tie_shape_op = - builder.Build(res); - pir::Value tie_shape_op_value = tie_shape_op.value(); + pir::shape::TieShapeOp tie_shape_op = + builder.Build(res); + pir::Value tie_shape_op_input = tie_shape_op.input(); pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); @@ -137,28 +139,28 @@ TEST(shape_op, tie_shape) { auto array_attr = pir::ArrayAttribute::get(ctx, new_attrs); tie_shape_op->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr); std::vector arr_attr_vec = tie_shape_op ->attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName()) + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName()) .AsVector(); - EXPECT_EQ(tie_shape_op_value, res); + EXPECT_EQ(tie_shape_op_input, res); EXPECT_EQ(arr_attr_vec.size(), static_cast(2)); EXPECT_EQ(arr_attr_vec[0].dyn_cast(), attr_s0); EXPECT_EQ(arr_attr_vec[1].dyn_cast(), attr_s1); EXPECT_TRUE(tie_shape_op->HasAttribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName())); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName())); } TEST(shape_op, func_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::dialect::FuncOp func_op = builder.Build(); + pir::shape::FuncOp func_op = builder.Build(); auto func_block = func_op.block(); builder.SetInsertionPointToStart(func_block); builder.Build(pir::Int32Attribute::get(ctx, 2), @@ -168,19 +170,20 @@ TEST(shape_op, func_op) { EXPECT_EQ(func_block->size(), static_cast(1)); } -TEST(shape_op, tensor_dim) { +TEST(shape_op, tensor_dim_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::Operation *op = CreateDenseTensorOp( + pir::Operation *op = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::OpResult res_dense_tensor_value = op->result(0); - pir::dialect::TensorDimOp tensor_dim_op0 = - builder.Build(res_dense_tensor_value, 0); + pir::shape::TensorDimOp tensor_dim_op0 = + builder.Build(res_dense_tensor_value, 0); pir::OpResult res0 = tensor_dim_op0.out(); + std::optional index0 = tensor_dim_op0.GetConstantIndex(); pir::OpResult index_value = builder @@ -188,14 +191,117 @@ TEST(shape_op, tensor_dim) { pir::Int64Attribute::get(pir::IrContext::Instance(), 1), pir::IndexType::get(pir::IrContext::Instance())) ->result(0); - pir::dialect::TensorDimOp tensor_dim_op1 = - builder.Build(res_dense_tensor_value, - index_value); + pir::shape::TensorDimOp tensor_dim_op1 = + builder.Build(res_dense_tensor_value, + index_value); pir::OpResult res1 = tensor_dim_op1.out(); EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(*index0, static_cast(0)); EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); EXPECT_EQ(tensor_dim_op0.source(), res_dense_tensor_value); EXPECT_EQ(tensor_dim_op1.source(), res_dense_tensor_value); EXPECT_EQ(tensor_dim_op1.index(), index_value); } + +TEST(shape_op, shape_of_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + auto op = test::CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult res = op->result(0); + + pir::shape::ShapeOfOp shape_of_op = builder.Build(res); + pir::Value shape_of_op_input = shape_of_op.input(); + EXPECT_EQ(shape_of_op_input, res); +} + +TEST(shape_op, from_elements_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::Int32Attribute int32_attr0 = builder.int32_attr(0); + pir::Int32Attribute int32_attr1 = builder.int32_attr(1); + pir::Int32Attribute int32_attr2 = builder.int32_attr(2); + pir::Int32Type int32_type = builder.int32_type(); + + pir::OpResult element0 = + builder.Build(int32_attr0, int32_type).out(); + pir::OpResult element1 = + builder.Build(int32_attr1, int32_type).out(); + pir::OpResult element2 = + builder.Build(int32_attr2, int32_type).out(); + + std::vector elements_in = {element0, element1, element2}; + + pir::shape::FromElementsOp from_elements_op = + builder.Build(elements_in); + + std::vector elements_out = from_elements_op.elements(); + for (size_t i = 0; i < elements_in.size(); i++) { + EXPECT_EQ(elements_in[i], elements_out[i]); + } +} + +TEST(shape_op, extract_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + auto op = test::CreateDenseTensorOp(ctx, {3, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult res = op->result(0); + + pir::Int32Attribute int32_attr = builder.int32_attr(1); + pir::Int32Type int32_type = builder.int32_type(); + pir::OpResult indice = + builder.Build(int32_attr, int32_type).out(); + std::vector indice_in = {indice, indice}; + + pir::shape::ExtractOp extract_op = + builder.Build(res, indice_in); + pir::Value input = extract_op.tensor(); + std::vector indice_out = extract_op.indices(); + + EXPECT_EQ(input, res); + for (size_t i = 0; i < indice_in.size(); i++) { + EXPECT_EQ(indice_in[i], indice_out[i]); + } +} + +TEST(shape_op, constant_index_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::shape::ConstantIndexOp constant_index_op = + builder.Build(1); + + EXPECT_EQ( + constant_index_op.value().dyn_cast().data() == 1, + true); +} + +TEST(shape_op, index_cast_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::IndexAttribute index_attr = builder.index_attr(1); + pir::IndexType index_type = builder.index_type(); + pir::OpResult in = + builder.Build(index_attr, index_type).out(); + + pir::shape::IndexCastOp index_cast_op = + builder.Build(builder.int32_type(), in); + pir::Value index_cast_op_input = index_cast_op.in(); + + EXPECT_EQ(index_cast_op_input, in); +} diff --git a/test/cpp/pir/shape_dialect/shape_struct_test.cc b/test/cpp/pir/shape_dialect/shape_struct_test.cc index 64b58a399a150..a9020f5e31ad9 100644 --- a/test/cpp/pir/shape_dialect/shape_struct_test.cc +++ b/test/cpp/pir/shape_dialect/shape_struct_test.cc @@ -15,97 +15,24 @@ #include #include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builder.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/dialect/shape/utils/symbol_table.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} - -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes) { - std::vector op_inputs = {}; - pir::Type fp32_dtype = pir::Float32Type::get(ctx); - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} - -TEST(shape_struct_test, symbolic_dim) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - - pir::dialect::SymbolicDim sym_dim1 = builder.Build( - "S0", 10, false, false, false, false); - pir::dialect::SymbolicDim sym_dim2 = builder.Build( - "S1", 10, false, false, false, false); - - EXPECT_EQ(sym_dim1.GetDimSize(), 10); - EXPECT_EQ(sym_dim1.GetSymName(), "S0"); - EXPECT_FALSE(sym_dim1.GetKnownNegativeOne()); - EXPECT_FALSE(sym_dim1.GetKnownNonSizeOne()); - EXPECT_FALSE(sym_dim1.GetKnownNonSizeZero()); - EXPECT_FALSE(sym_dim1.GetKnownNonNegative()); - - EXPECT_FALSE(sym_dim1.IsDynamic()); - EXPECT_TRUE(sym_dim1.Merge(sym_dim2)); - - sym_dim1.SetDimSize(20); - sym_dim1.SetSymName("S2"); - sym_dim1.UpdateKnownNegativeOne(true); - sym_dim1.UpdateKnownNonSizeOne(true); - sym_dim1.UpdateKnownNonSizeZero(true); - sym_dim1.UpdateKnownNonNegative(true); - - EXPECT_FALSE(sym_dim1.Merge(sym_dim2)); - - EXPECT_EQ(sym_dim1.GetDimSize(), 20); - EXPECT_EQ(sym_dim1.GetSymName(), "S2"); - EXPECT_TRUE(sym_dim1.GetKnownNegativeOne()); - EXPECT_TRUE(sym_dim1.GetKnownNonSizeOne()); - EXPECT_TRUE(sym_dim1.GetKnownNonSizeZero()); - EXPECT_TRUE(sym_dim1.GetKnownNonNegative()); -} +#include "test/cpp/pir/tools/test_pir_utils.h" TEST(shape_struct_test, symbolic_dim_product) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::SymbolicDim sym_dim = builder.Build( + pir::shape::SymbolicDimOp sym_dim = builder.Build( "S0", pir::ShapedTypeInterface::kDynamic, false, false, false, false); pir::SymbolicDimProduct sym_dim_product1; pir::SymbolicDimProduct sym_dim_product2; @@ -119,39 +46,39 @@ TEST(shape_struct_test, symbolic_dim_product) { TEST(shape_struct_test, symbolic_dim_table) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::SymbolicDim sym_dim = builder.Build( + pir::shape::SymbolicDimOp sym_dim = builder.Build( "S0", 10, false, false, false, false); pir::SymbolTable symbol_table(program.module_op()); EXPECT_EQ(symbol_table.insert(sym_dim), "S0"); - EXPECT_EQ(symbol_table.Lookup("S0"), sym_dim); + EXPECT_EQ(symbol_table.Lookup("S0"), sym_dim); EXPECT_EQ(symbol_table.getOp(), program.module_op()); - EXPECT_FALSE(symbol_table.Lookup("S1")); + EXPECT_FALSE(symbol_table.Lookup("S1")); } TEST(shape_struct_test, symbolic_dim_mgr_simple) { /******************************************************/ - /* Mgr simple version, only SymbolicDim related func. */ + /* Mgr simple version, only SymbolicDimOp related func. */ /******************************************************/ pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); - pir::dialect::SymbolicDim sym_dim_s0 = sym_dim_mgr.NewSymbolicDim(); - pir::dialect::SymbolicDim sym_dim_s1 = sym_dim_mgr.NewSymbolicDim(); - pir::dialect::SymbolicDim sym_dim_c10 = + pir::shape::SymbolicDimOp sym_dim_s0 = sym_dim_mgr.NewSymbolicDim(); + pir::shape::SymbolicDimOp sym_dim_s1 = sym_dim_mgr.NewSymbolicDim(); + pir::shape::SymbolicDimOp sym_dim_c10 = sym_dim_mgr.NewConstantSymbolicDim(10); sym_dim_mgr.MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); - auto op = CreateDenseTensorOp( + auto op = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::Value res = op->result(0); - std::vector sym_dim_vec = + std::vector sym_dim_vec = sym_dim_mgr.CreateSymbolicDimsForRankedValue(res); EXPECT_EQ(sym_dim_s0.GetSymName(), "S0"); @@ -161,9 +88,9 @@ TEST(shape_struct_test, symbolic_dim_mgr_simple) { EXPECT_EQ(sym_dim_c10.GetDimSize(), 10); EXPECT_EQ(sym_dim_vec[0].GetSymName(), "S2"); EXPECT_EQ(sym_dim_vec[1].GetSymName(), "C2"); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("S0"), + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("S0"), sym_dim_s0); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C10"), + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C10"), sym_dim_c10); EXPECT_EQ(sym_dim_mgr.GetRootSymbolicDim(sym_dim_s1), sym_dim_s0); EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s1)); @@ -176,47 +103,47 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { /***************************************************************/ pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); auto func_op = - sym_dim_mgr.symbolTable().getOp()->dyn_cast(); + sym_dim_mgr.symbolTable().getOp()->dyn_cast(); pir::Builder builder = pir::Builder(ctx, func_op.block()); - pir::dialect::SymbolicDim sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); - pir::dialect::SymbolicDim sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); - pir::dialect::SymbolicDim sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); - pir::dialect::SymbolicDim sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); - pir::dialect::SymbolicDim sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); - pir::dialect::SymbolicDim sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); - pir::dialect::SymbolicDim sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); - pir::dialect::SymbolicDim sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); - pir::dialect::SymbolicDim sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); - pir::dialect::SymbolicDim sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); - pir::dialect::SymbolicDim sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); - pir::dialect::SymbolicDim sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); - pir::dialect::SymbolicDim sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); - pir::dialect::SymbolicDim sym_dim_c10 = + pir::shape::SymbolicDimOp sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); + pir::shape::SymbolicDimOp sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); + pir::shape::SymbolicDimOp sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); + pir::shape::SymbolicDimOp sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); + pir::shape::SymbolicDimOp sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); + pir::shape::SymbolicDimOp sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); + pir::shape::SymbolicDimOp sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); + pir::shape::SymbolicDimOp sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); + pir::shape::SymbolicDimOp sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); + pir::shape::SymbolicDimOp sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); + pir::shape::SymbolicDimOp sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); + pir::shape::SymbolicDimOp sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); + pir::shape::SymbolicDimOp sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); + pir::shape::SymbolicDimOp sym_dim_c10 = sym_dim_mgr.NewConstantSymbolicDim(10); - pir::dialect::SymbolicDim sym_dim_c20 = + pir::shape::SymbolicDimOp sym_dim_c20 = sym_dim_mgr.NewConstantSymbolicDim(20); - pir::OpResult dim_op_s0 = builder.Build("S0").out(); - pir::OpResult dim_op_s1 = builder.Build("S1").out(); - pir::OpResult dim_op_s2 = builder.Build("S2").out(); - pir::OpResult dim_op_s3 = builder.Build("S3").out(); - pir::OpResult dim_op_s4 = builder.Build("S4").out(); - pir::OpResult dim_op_s5 = builder.Build("S5").out(); - pir::OpResult dim_op_s6 = builder.Build("S6").out(); - pir::OpResult dim_op_s7 = builder.Build("S7").out(); - pir::OpResult dim_op_s8 = builder.Build("S8").out(); - pir::OpResult dim_op_s9 = builder.Build("S9").out(); - pir::OpResult dim_op_s10 = builder.Build("S10").out(); - pir::OpResult dim_op_s11 = builder.Build("S11").out(); - pir::OpResult dim_op_c10 = builder.Build("C10").out(); - pir::OpResult dim_op_c20 = builder.Build("C20").out(); + pir::OpResult dim_op_s0 = builder.Build("S0").out(); + pir::OpResult dim_op_s1 = builder.Build("S1").out(); + pir::OpResult dim_op_s2 = builder.Build("S2").out(); + pir::OpResult dim_op_s3 = builder.Build("S3").out(); + pir::OpResult dim_op_s4 = builder.Build("S4").out(); + pir::OpResult dim_op_s5 = builder.Build("S5").out(); + pir::OpResult dim_op_s6 = builder.Build("S6").out(); + pir::OpResult dim_op_s7 = builder.Build("S7").out(); + pir::OpResult dim_op_s8 = builder.Build("S8").out(); + pir::OpResult dim_op_s9 = builder.Build("S9").out(); + pir::OpResult dim_op_s10 = builder.Build("S10").out(); + pir::OpResult dim_op_s11 = builder.Build("S11").out(); + pir::OpResult dim_op_c10 = builder.Build("C10").out(); + pir::OpResult dim_op_c20 = builder.Build("C20").out(); pir::OpResult constant = builder .Build(pir::Int32Attribute::get(ctx, 2), @@ -224,62 +151,62 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { ->result(0); // Mark S1 == S2. - builder.Build( + builder.Build( 2, 2, std::vector{constant, dim_op_s1, dim_op_s2, constant}); // Mark S0 * S1 == S2 * S3, For check S0 == S3. - builder.Build( + builder.Build( 2, 2, std::vector{dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3}); // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. - builder.Build( + builder.Build( 3, 3, std::vector{ dim_op_s4, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s5}); // For check S6 == C10 * C20. - builder.Build( + builder.Build( 1, 2, std::vector{dim_op_s6, dim_op_c10, dim_op_c20}); // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. - builder.Build( + builder.Build( 3, 3, std::vector{ dim_op_c10, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s7}); // For unsimplify product case: S8 * S9 == S10 * S11 - builder.Build( + builder.Build( 2, 2, std::vector{dim_op_s8, dim_op_s9, dim_op_s10, dim_op_s11}); - auto op = CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic}, - {"op0_attr"}, - {"op0_name"}); - auto op_ = CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - 10, - 20}, - {"op1_attr"}, - {"op1_name"}); + auto op = test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic}, + {"op0_attr"}, + {"op0_name"}); + auto op_ = test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + 10, + 20}, + {"op1_attr"}, + {"op1_name"}); pir::OpResult res = op->result(0); pir::OpResult res_ = op_->result(0); builder.SetInsertionPointToEnd(program.block()); - pir::dialect::TieShapeOp tie_shape_op1 = - builder.Build(res); - pir::dialect::TieShapeOp tie_shape_op2 = - builder.Build(res_); + pir::shape::TieShapeOp tie_shape_op1 = + builder.Build(res); + pir::shape::TieShapeOp tie_shape_op2 = + builder.Build(res_); pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); @@ -314,9 +241,9 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { auto array_attr_ref = pir::ArrayAttribute::get(ctx, new_attrs_ref); tie_shape_op1->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr1); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr1); tie_shape_op2->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr2); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr2); EXPECT_TRUE(sym_dim_mgr.Load()); @@ -380,7 +307,7 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s3)); EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s4, sym_dim_s5)); EXPECT_EQ(sym_dim_s6.GetDimSize(), 200); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), sym_dim_c20); EXPECT_EQ(sym_dim_s7.GetDimSize(), sym_dim_c10.GetDimSize()); EXPECT_EQ(simplified_product_s7.factor, 10); @@ -402,11 +329,11 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { EXPECT_TRUE(sym_dim_mgr_new.Load()); auto attrs = tie_shape_op1.attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName()); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName()); EXPECT_FALSE( - sym_dim_mgr_new.symbolTable().Lookup("S7")); + sym_dim_mgr_new.symbolTable().Lookup("S7")); EXPECT_EQ(sym_dim_mgr_new.symbolTable() - .Lookup("tie_product_equal") + .Lookup("tie_product_equal") .size(), static_cast(1)); @@ -416,52 +343,56 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { TEST(shape_struct_test, shape_analysis) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::dialect::FuncOp func_op = builder.Build(); + pir::shape::FuncOp func_op = builder.Build(); phi::DDim dims_D_2 = {pir::ShapedTypeInterface::kDynamic, 2}; phi::DDim dims_2_2 = {2, 2}; phi::DDim dims_D = {pir::ShapedTypeInterface::kDynamic}; // same shape with dynamic: value1 == value2 - auto op1 = CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); + auto op1 = + test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); + auto op2 = + test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); pir::OpResult value1 = op1->result(0); pir::OpResult value2 = op2->result(0); // same shape with static: value3 == value4 - auto op3 = CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); + auto op3 = + test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); + auto op4 = + test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); pir::OpResult value3 = op3->result(0); pir::OpResult value4 = op4->result(0); // one dimension with dynamic: value5 != value1 != value3 - auto op5 = CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); + auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); pir::OpResult value5 = op5->result(0); - pir::dialect::TieShapeOp tie_shape_op1 = - builder.Build(value1); - pir::dialect::TieShapeOp tie_shape_op2 = - builder.Build(value2); - pir::dialect::TieShapeOp tie_shape_op3 = - builder.Build(value3); - pir::dialect::TieShapeOp tie_shape_op4 = - builder.Build(value4); - pir::dialect::TieShapeOp tie_shape_op5 = - builder.Build(value5); + pir::shape::TieShapeOp tie_shape_op1 = + builder.Build(value1); + pir::shape::TieShapeOp tie_shape_op2 = + builder.Build(value2); + pir::shape::TieShapeOp tie_shape_op3 = + builder.Build(value3); + pir::shape::TieShapeOp tie_shape_op4 = + builder.Build(value4); + pir::shape::TieShapeOp tie_shape_op5 = + builder.Build(value5); builder.SetInsertionPointToEnd(func_op.block()); - builder.Build("C2", 2, true, false, true, true); - pir::dialect::SymbolicDim sym_dim_s0 = - builder.Build( + builder.Build("C2", 2, true, false, true, true); + pir::shape::SymbolicDimOp sym_dim_s0 = + builder.Build( "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::dialect::SymbolicDim sym_dim_s1 = - builder.Build( + pir::shape::SymbolicDimOp sym_dim_s1 = + builder.Build( "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::dialect::SymbolicDim sym_dim_s2 = - builder.Build( + pir::shape::SymbolicDimOp sym_dim_s2 = + builder.Build( "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); @@ -476,15 +407,15 @@ TEST(shape_struct_test, shape_analysis) { auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); tie_shape_op1->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op1); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1); tie_shape_op2->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op2); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2); tie_shape_op3->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op3); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3); tie_shape_op4->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op4); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4); tie_shape_op5->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op5); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5); pir::ShapeConstraintIRAnalysis shape_analysis(program.module_op()); EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); diff --git a/test/cpp/pir/tools/test_pir_utils.h b/test/cpp/pir/tools/test_pir_utils.h new file mode 100644 index 0000000000000..d71ddb0d2ea95 --- /dev/null +++ b/test/cpp/pir/tools/test_pir_utils.h @@ -0,0 +1,59 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" + +namespace test { + +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes, + const pir::Type &dtype = + pir::Float32Type::get(pir::IrContext::Instance())) { + std::vector op_inputs = {}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, dtype, dims, data_layout, lod, offset)}; + + pir::Builder builder = pir::Builder(ctx); + pir::Operation *op = + builder.Build(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} + +} // namespace test diff --git a/test/cpp/prim/CMakeLists.txt b/test/cpp/prim/CMakeLists.txt index 6499c2fae6c6e..bf3e597de81e2 100644 --- a/test/cpp/prim/CMakeLists.txt +++ b/test/cpp/prim/CMakeLists.txt @@ -17,27 +17,7 @@ set(prim_generated_deps final_dygraph_function final_dygraph_node if(WITH_CINN) set(CINN_DEPS cinn_compiler) endif() -cc_test_old( - test_comp_static - SRCS - test_static_prim.cc - DEPS - fleet_executor - static_utils - static_prim_api - generated_op - prim_utils - operator - elementwise_mul_op - elementwise_sub_op - fill_constant_op - activation_op - phi - static_global_utils - static_tensor_operants - generated_static_op - ${CINN_DEPS} - python) +paddle_test(test_comp_static SRCS test_static_prim.cc) if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(WITH_CINN) @@ -53,9 +33,9 @@ endif() if(NOT WIN32) cc_test( - test_vjp_new_ir + test_vjp_pir SRCS test_vjp.cc - DEPS pir_adaptor pd_op_dialect pir) + DEPS pir_adaptor op_dialect_vjp pir) endif() if(WITH_ONNXRUNTIME AND WIN32) # Copy onnxruntime for some c++ test in Windows, since the test will diff --git a/test/cpp/prim/test_static_prim.cc b/test/cpp/prim/test_static_prim.cc index d4f5dcb8998ae..8fd7d79bacbc3 100644 --- a/test/cpp/prim/test_static_prim.cc +++ b/test/cpp/prim/test_static_prim.cc @@ -31,46 +31,6 @@ PD_DECLARE_bool(prim_enabled); PHI_DECLARE_string(tensor_operants_mode); -PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than_raw, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(not_equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_than, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_not, CPU, ALL_LAYOUT); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than_raw, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(not_equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_than, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_not, KPS, ALL_LAYOUT); -#endif namespace paddle { namespace prim { @@ -569,20 +529,3 @@ TEST(StaticPrim, TestFlags) { } // namespace prim } // namespace paddle -USE_OP_ITSELF(fill_constant); -USE_OP_ITSELF(tanh); -USE_OP_ITSELF(tanh_grad); -USE_OP_ITSELF(elementwise_mul); -USE_OP_ITSELF(elementwise_sub); -USE_OP_ITSELF(elementwise_pow); -USE_OP_ITSELF(scale); -USE_OP_ITSELF(less_equal); -USE_OP_ITSELF(less_than); -USE_OP_ITSELF(equal); -USE_OP_ITSELF(not_equal); -USE_OP_ITSELF(greater_equal); -USE_OP_ITSELF(greater_than); -USE_OP_ITSELF(bitwise_xor); -USE_OP_ITSELF(bitwise_and); -USE_OP_ITSELF(bitwise_not); -USE_OP_ITSELF(bitwise_or); diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index a28e412a9ebba..7e420635ad210 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -14,7 +14,7 @@ #include -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/new_executor/standalone_executor.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" @@ -42,6 +42,16 @@ PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT); namespace paddle { namespace framework { +pir::Operation* GetOpFromProgram(const std::string& op_name, + const pir::Program& program) { + for (auto op : *(program.block())) { + if (op->name() == op_name) { + return op; + } + } + return nullptr; +} + TEST(VJP, TanhBackwardTest) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); @@ -59,38 +69,38 @@ TEST(VJP, TanhBackwardTest) { std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{{op1.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh"); auto tanh_vjp_interface_impl = op2_info.GetInterfaceImpl(); - tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + tanh_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op2->result(0), "tanh_out"); + builder->Build( + GetOpFromProgram("pd_op.tanh_grad", program)->result(0), "tanh_grad_out"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars( - {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"}); + test_core.SetSkipGcVars({"tanh_out", "tanh_grad_out"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_1")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_1") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("tanh_out")->Get() + : test_core.local_scope() + ->FindVar("tanh_out") + ->Get(); auto grad_out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() + ? scope.FindVar("tanh_grad_out")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") + ->FindVar("tanh_grad_out") ->Get(); ASSERT_NEAR(out_tensor.data()[0], 0.76159, 1e-5); @@ -114,38 +124,39 @@ TEST(VJP, Tanh_BackwardTest) { std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{{op1.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh_"); auto tanh_vjp_interface_impl = op2_info.GetInterfaceImpl(); - tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + tanh_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + std::string tanh_out = "tanh_out"; + std::string tanh_grad_out = "tanh_grad_out"; + builder->Build(op2->result(0), tanh_out); + builder->Build( + GetOpFromProgram("pd_op.tanh_grad", program)->result(0), tanh_grad_out); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars( - {prefix_str + "_inner_var_0", prefix_str + "_inner_var_2"}); + test_core.SetSkipGcVars({tanh_out, tanh_grad_out}); test_core.Run({}); auto out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_0")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_0") - ->Get(); + ? scope.FindVar(tanh_out)->Get() + : test_core.local_scope()->FindVar(tanh_out)->Get(); auto grad_out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_2")->Get() + ? scope.FindVar(tanh_grad_out)->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_2") + ->FindVar(tanh_grad_out) ->Get(); ASSERT_NEAR(out_tensor.data()[0], 0.76159, 1e-5); @@ -169,12 +180,19 @@ TEST(VJP, MeanBackwardTest) { std::vector{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{{op1.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.mean"); auto mean_vjp_interface_impl = op2_info.GetInterfaceImpl(); - mean_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + mean_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op2->result(0), "mean_out"); + builder->Build( + GetOpFromProgram("pd_op.mean_grad", program)->result(0), "mean_grad_out"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -183,24 +201,18 @@ TEST(VJP, MeanBackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars( - {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"}); + test_core.SetSkipGcVars({"mean_out", "mean_grad_out"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_1")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_1") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("mean_out")->Get() + : test_core.local_scope() + ->FindVar("mean_out") + ->Get(); auto grad_out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() + ? scope.FindVar("mean_grad_out")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") + ->FindVar("mean_grad_out") ->Get(); ASSERT_EQ(out_tensor.data()[0], 2.0); ASSERT_EQ(grad_out_tensor.data()[0], 0.25); @@ -227,11 +239,22 @@ TEST(VJP, ConcatBackwardTest) { paddle::dialect::FullOp op4 = builder->Build( std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false, false}}; + std::vector> inputs{{op1.out(), op1.out()}, + {op3.axis()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.concat"); auto concat_vjp_interface_impl = op2_info.GetInterfaceImpl(); - concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients); + concat_vjp_interface_impl->vjp_( + op3.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op3->result(0), "concat_out"); + builder->Build( + GetOpFromProgram("builtin.split", program)->result(0), "split_out_0"); + builder->Build( + GetOpFromProgram("builtin.split", program)->result(1), "split_out_1"); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); @@ -239,31 +262,24 @@ TEST(VJP, ConcatBackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_3", - prefix_str + "_inner_var_7", - prefix_str + "_inner_var_8"}); + test_core.SetSkipGcVars({"concat_out", "split_out_0", "split_out_1"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("concat_out")->Get() + : test_core.local_scope() + ->FindVar("concat_out") + ->Get(); auto grad_out_tensor_0 = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_7")->Get() + ? scope.FindVar("split_out_0")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_7") + ->FindVar("split_out_0") ->Get(); auto grad_out_tensor_1 = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_8")->Get() + ? scope.FindVar("split_out_1")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_8") + ->FindVar("split_out_1") ->Get(); ASSERT_EQ(out_tensor.data()[0], 2.0); ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0); @@ -291,12 +307,21 @@ TEST(VJP, AddBackwardTest) { std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}, {false}}; + std::vector> inputs{{op1.out()}, {op2.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add"); auto add_vjp_interface_impl = op3_info.GetInterfaceImpl(); - add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients); + add_vjp_interface_impl->vjp_( + op3.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op3->result(0), "add_out"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(0), "add_grad_out_0"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(1), "add_grad_out_1"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -305,33 +330,24 @@ TEST(VJP, AddBackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_2", - prefix_str + "_inner_var_4", - prefix_str + "_inner_var_5"}); + test_core.SetSkipGcVars({"add_out", "add_grad_out_0", "add_grad_out_1"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_2")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_2") - ->Get(); - auto dx = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_4")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_4") - ->Get(); - - auto dy = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_5")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_5") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("add_out")->Get() + : test_core.local_scope() + ->FindVar("add_out") + ->Get(); + auto dx = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_0")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_0") + ->Get(); + + auto dy = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_1")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_1") + ->Get(); ASSERT_EQ(out_tensor.data()[0], 4.0); ASSERT_EQ(dx.data()[0], 1.0); ASSERT_EQ(dy.data()[0], 1.0); @@ -356,13 +372,21 @@ TEST(VJP, Add_BackwardTest) { std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}, {false}}; + std::vector> inputs{{op1.out()}, {op2.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add_"); auto add_inplace_vjp_interface_impl = op3_info.GetInterfaceImpl(); add_inplace_vjp_interface_impl->vjp_( - op3.operation(), out_grads, stop_gradients); + op3.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op1->result(0), "full_op1_out"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(0), "add_grad_out_0"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(1), "add_grad_out_1"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -371,33 +395,25 @@ TEST(VJP, Add_BackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_0", - prefix_str + "_inner_var_3", - prefix_str + "_inner_var_4"}); - test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_0")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_0") - ->Get(); - auto dx = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") - ->Get(); - auto dy = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_4")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_4") - ->Get(); + test_core.SetSkipGcVars({"full_op1_out", "add_grad_out_0", "add_grad_out_1"}); + test_core.Run({}); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("full_op1_out")->Get() + : test_core.local_scope() + ->FindVar("full_op1_out") + ->Get(); + auto dx = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_0")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_0") + ->Get(); + + auto dy = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_1")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_1") + ->Get(); ASSERT_EQ(out_tensor.data()[0], 4.0); ASSERT_EQ(dx.data()[0], 1.0); ASSERT_EQ(dy.data()[0], 1.0); @@ -405,6 +421,7 @@ TEST(VJP, Add_BackwardTest) { TEST(VJP, SplitBackwardTest) { pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); pir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -422,44 +439,51 @@ TEST(VJP, SplitBackwardTest) { std::vector{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{ + {op2.x()}, {op2.sections()}, {op2.axis()}}; + std::vector> outputs{{op3.outputs()}}; std::vector> out_grads{{op3.result(0), op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.split"); auto concat_vjp_interface_impl = op2_info.GetInterfaceImpl(); - concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + concat_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + std::string split_out1 = "split_out1"; + std::string split_out2 = "split_out2"; + std::string concat_out = "concat_out"; + + builder->Build(op3->result(0), split_out1); + builder->Build(op3->result(1), split_out2); + builder->Build( + GetOpFromProgram("pd_op.concat", program)->result(0), concat_out); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_4", - prefix_str + "_inner_var_5", - prefix_str + "_inner_var_8"}); + + test_core.SetSkipGcVars({split_out1, split_out2, concat_out}); test_core.Run({}); - auto out_tensor_0 = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_4")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_4") - ->Get(); - auto out_tensor_1 = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_5")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_5") - ->Get(); + auto out_tensor_0 = test_core.local_scope() == nullptr + ? scope.FindVar(split_out1)->Get() + : test_core.local_scope() + ->FindVar(split_out1) + ->Get(); + auto out_tensor_1 = test_core.local_scope() == nullptr + ? scope.FindVar(split_out2)->Get() + : test_core.local_scope() + ->FindVar(split_out2) + ->Get(); auto grad_out_tensor_0 = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_8")->Get() + ? scope.FindVar(concat_out)->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_8") + ->FindVar(concat_out) ->Get(); ASSERT_EQ(out_tensor_0.data()[0], 2.0); ASSERT_EQ(out_tensor_0.data()[1], 2.0); diff --git a/test/distributed_passes/test_auto_parallel_recompute_pass.py b/test/distributed_passes/test_auto_parallel_recompute_pass.py index 2b77ccdf34a63..6e98d9faaf1bb 100644 --- a/test/distributed_passes/test_auto_parallel_recompute_pass.py +++ b/test/distributed_passes/test_auto_parallel_recompute_pass.py @@ -37,7 +37,17 @@ def init(self): def apply_passes(self): dist_strategy = fleet.DistributedStrategy() dist_strategy.recompute = True - dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]} + dist_strategy.recompute_configs = { + "checkpoints": ["tmp_3", "tmp_6"], + "refined_ops_patterns": [ + { + "main_ops": ["matmul_v2", "elementwise_add"], + "num": -1, + "pre_ops": [], + "suf_ops": [], + } + ], + } dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 1beadd642a66e..e11c43a13de0b 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -3,7 +3,8 @@ file( RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -set(SOT_ENVS SOT_LOG_LEVEL=0 COST_MODEL=False MIN_GRAPH_SIZE=0 STRICT_MODE=0) +set(SOT_ENVS SOT_LOG_LEVEL=0 COST_MODEL=False MIN_GRAPH_SIZE=0 + STRICT_MODE=False) set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) list(REMOVE_ITEM TEST_OPS test_lac) diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py deleted file mode 100644 index 9a5b9bf22d92a..0000000000000 --- a/test/dygraph_to_static/dygraph_to_static_util.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import os -from functools import wraps - -import numpy as np - -from paddle import set_flags, static -from paddle.base import core - - -@contextlib.contextmanager -def enable_fallback_guard(enable): - flag = os.environ.get("ENABLE_FALL_BACK", None) - os.environ["ENABLE_FALL_BACK"] = enable - yield - if flag is not None: - os.environ["ENABLE_FALL_BACK"] = flag - else: - del os.environ["ENABLE_FALL_BACK"] - - -def to_ast(func): - """ - convert run fall_back to ast - """ - - def impl(*args, **kwargs): - with enable_fallback_guard("False"): - func(*args, **kwargs) - - return impl - - -def to_sot(func): - """ - convert run fall_back to ast - """ - # TODO(SigureMo): ENABLE_SOT should always be True, remove this - enable_sot = os.environ.get("ENABLE_SOT", "True") == "True" - - def impl(*args, **kwargs): - if enable_sot: - with enable_fallback_guard("True"): - func(*args, **kwargs) - else: - return - - return impl - - -def dy2static_unittest(cls): - """ - dy2static unittest must be decorated to each Dy2static Unittests. - run both in Fallback and Ast mode. - - Examples: - - >>> @dy2static_unittest - ... class TestA(unittest.TestCase): - ... ... - """ - for key in dir(cls): - if key.startswith("test"): - if not key.endswith("_ast"): - test_func = getattr(cls, key) - setattr(cls, key + "_ast", to_ast(test_func)) - test_func = getattr(cls, key) - setattr(cls, key, to_sot(test_func)) - return cls - - -def ast_only_test(func): - """ - run this test function in ast only mode. - - Examples: - - >>> @dy2static_unittest - ... class TestA(unittest.TestCase): - ... @ast_only_test - ... def test_ast_only(self): - ... pass - """ - - def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "False") == "False": - func(*args, **kwargs) - - return impl - - -def sot_only_test(func): - """ - run this test function in ast only mode. - - Examples: - - >>> @dy2static_unittest - ... class TestA(unittest.TestCase): - ... @sot_only_test - ... def test_sot_only(self): - ... pass - """ - - def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "False") == "True": - func(*args, **kwargs) - - return impl - - -def test_with_new_ir(func): - @wraps(func) - def impl(*args, **kwargs): - ir_outs = None - if os.environ.get('FLAGS_use_stride_kernel', False): - return - with static.scope_guard(static.Scope()): - with static.program_guard(static.Program()): - try: - new_ir_flag = 'FLAGS_enable_new_ir_in_executor' - os.environ[new_ir_flag] = 'True' - set_flags({new_ir_flag: True}) - ir_outs = func(*args, **kwargs) - finally: - del os.environ[new_ir_flag] - set_flags({new_ir_flag: False}) - return ir_outs - - return impl - - -def test_and_compare_with_new_ir(need_check_output: bool = True): - def decorator(func): - @wraps(func) - def impl(*args, **kwargs): - outs = func(*args, **kwargs) - if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled(): - return outs - ir_outs = test_with_new_ir(func)(*args, **kwargs) - if not need_check_output: - return outs - np.testing.assert_equal( - outs, - ir_outs, - err_msg='Dy2St Unittest Check (' - + func.__name__ - + ') has diff ' - + '\nExpect ' - + str(outs) - + '\n' - + 'But Got' - + str(ir_outs), - ) - return outs - - return impl - - return decorator diff --git a/test/dygraph_to_static/dygraph_to_static_utils_new.py b/test/dygraph_to_static/dygraph_to_static_utils_new.py index 5e0ebacd8e1e3..e0af2406f77e4 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils_new.py +++ b/test/dygraph_to_static/dygraph_to_static_utils_new.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import inspect import logging import os @@ -22,16 +21,18 @@ import numpy as np +import paddle from paddle import set_flags, static from paddle.base import core +from paddle.jit.api import sot_mode_guard """ # Usage: class MyTest(Dy2StTestBase): @set_to_static_mode( - ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST + ToStaticMode.AST | ToStaticMode.SOT ) - @set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR) + @set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE | IrMode.PIR_API) def test_case1(self): raise ValueError("MyTest 1") @@ -49,8 +50,7 @@ def test_case1(self): class ToStaticMode(Flag): - LEGACY_AST = auto() - PIR_AST = auto() + AST = auto() SOT = auto() def lower_case_name(self): @@ -58,30 +58,18 @@ def lower_case_name(self): class IrMode(Flag): - LEGACY_PROGRAM = auto() - PIR = auto() + LEGACY_IR = auto() + # pir translator mode, Reference link: https://github.com/PaddlePaddle/community/blob/master/pfcc/paddle-code-reading/IR_Dialect/program_translator.md + PIR_EXE = auto() + # using native pir api mode + PIR_API = auto() def lower_case_name(self): return self.name.lower() -DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT -DEFAULT_IR_MODE = IrMode.LEGACY_PROGRAM - - -def in_sot_mode(): - return os.getenv("ENABLE_FALL_BACK", "False") == "True" - - -@contextlib.contextmanager -def enable_fallback_guard(enable): - flag = os.environ.get("ENABLE_FALL_BACK", None) - os.environ["ENABLE_FALL_BACK"] = enable - yield - if flag is not None: - os.environ["ENABLE_FALL_BACK"] = flag - else: - del os.environ["ENABLE_FALL_BACK"] +DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT +DEFAULT_IR_MODE = IrMode.LEGACY_IR def to_legacy_ast_test(fn): @@ -92,7 +80,7 @@ def to_legacy_ast_test(fn): @wraps(fn) def impl(*args, **kwargs): logger.info("[AST] running AST") - with enable_fallback_guard("False"): + with sot_mode_guard(False): fn(*args, **kwargs) return impl @@ -106,41 +94,50 @@ def to_sot_test(fn): @wraps(fn) def impl(*args, **kwargs): logger.info("[SOT] running SOT") - with enable_fallback_guard("True"): + with sot_mode_guard(True): fn(*args, **kwargs) return impl -def to_pir_ast_test(fn): - raise TypeError("Don't enable PIR AST mode now!") - - -def to_legacy_program_test(fn): +def to_legacy_ir_test(fn): def impl(*args, **kwargs): - logger.info("[Program] running legacy program") + logger.info("[Program] running legacy ir") return fn(*args, **kwargs) return impl -def to_pir_test(fn): +def to_pir_exe_test(fn): @wraps(fn) def impl(*args, **kwargs): - logger.info("[PIR] running pir") + logger.info("[PIR_EXE] running pir exe") ir_outs = None if os.environ.get('FLAGS_use_stride_kernel', False): return with static.scope_guard(static.Scope()): with static.program_guard(static.Program()): try: - new_ir_flag = 'FLAGS_enable_new_ir_in_executor' - os.environ[new_ir_flag] = 'True' - set_flags({new_ir_flag: True}) + pir_flag = 'FLAGS_enable_pir_in_executor' + os.environ[pir_flag] = 'True' + set_flags({pir_flag: True}) ir_outs = fn(*args, **kwargs) finally: - del os.environ[new_ir_flag] - set_flags({new_ir_flag: False}) + del os.environ[pir_flag] + set_flags({pir_flag: False}) + return ir_outs + + return impl + + +def to_pir_api_test(fn): + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[PIR_API] running pir api") + ir_outs = None + with paddle.pir_utils.IrGuard(): + paddle.disable_static() + ir_outs = fn(*args, **kwargs) return ir_outs return impl @@ -150,13 +147,13 @@ def impl(*args, **kwargs): class Dy2StTestMeta(type): TO_STATIC_HANDLER_MAP = { ToStaticMode.SOT: to_sot_test, - ToStaticMode.LEGACY_AST: to_legacy_ast_test, - ToStaticMode.PIR_AST: to_pir_ast_test, + ToStaticMode.AST: to_legacy_ast_test, } IR_HANDLER_MAP = { - IrMode.LEGACY_PROGRAM: to_legacy_program_test, - IrMode.PIR: to_pir_test, + IrMode.LEGACY_IR: to_legacy_ir_test, + IrMode.PIR_EXE: to_pir_exe_test, + IrMode.PIR_API: to_pir_api_test, } def __new__(cls, name, bases, attrs): @@ -179,7 +176,7 @@ def __new__(cls, name, bases, attrs): # Disable inherited test cases for base in bases: for attr in dir(base): - if attr.startswith(fn_name): + if attr.startswith(f"{fn_name}__"): new_attrs[attr] = None fn_to_static_modes = getattr( fn, "to_static_mode", DEFAULT_TO_STATIC_MODE @@ -205,11 +202,11 @@ def __new__(cls, name, bases, attrs): ) # Generate all test cases for to_static_mode, ir_mode in to_static_with_ir_modes: + # NOTE(gouzil): Temporarily not supported SOT + PIR, link: https://github.com/PaddlePaddle/Paddle/pull/58630 if ( - to_static_mode == ToStaticMode.PIR_AST - and ir_mode == IrMode.LEGACY_PROGRAM + to_static_mode == ToStaticMode.SOT + and ir_mode == IrMode.PIR_API ): - # PIR with LEGACY_PROGRAM is not a valid combination continue new_attrs[ Dy2StTestMeta.test_case_name( @@ -263,31 +260,43 @@ def decorator(fn): # Suger decorators # These decorators can be simply composed by base decorators -def ast_only_test(fn): - fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn) +def test_ast_only(fn): + fn = set_to_static_mode(ToStaticMode.AST)(fn) return fn -def sot_only_test(fn): +def test_sot_only(fn): fn = set_to_static_mode(ToStaticMode.SOT)(fn) return fn -def test_with_new_ir(fn): - fn = set_ir_mode(IrMode.PIR)(fn) +def test_pir_only(fn): + fn = set_ir_mode(IrMode.PIR_EXE)(fn) + return fn + + +def test_legacy_and_pir(fn): + fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE)(fn) + return fn + + +def test_legacy_and_pir_api(fn): + fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API)(fn) + return fn + + +def test_legacy_and_pir_exe_and_pir_api(fn): + fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API | IrMode.PIR_EXE)(fn) return fn -def _test_and_compare_with_new_ir(fn): +def compare_legacy_with_pir(fn): @wraps(fn) def impl(*args, **kwargs): outs = fn(*args, **kwargs) if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled(): return outs - # Disable SOT + PIR test temprorily - if in_sot_mode(): - return outs - ir_outs = to_pir_test(fn)(*args, **kwargs) + ir_outs = to_pir_exe_test(fn)(*args, **kwargs) np.testing.assert_equal( outs, ir_outs, @@ -300,17 +309,6 @@ def impl(*args, **kwargs): return impl -def test_and_compare_with_new_ir(need_check_output: bool = True): - def decorator(fn): - fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn) - if need_check_output: - logger.info(f"[need_check_output] {fn.__name__}") - fn = _test_and_compare_with_new_ir(fn) - return fn - - return decorator - - # For debug def show_all_test_cases(test_class): logger.info(f"[showing {test_class.__name__}]") diff --git a/test/dygraph_to_static/test_assert.py b/test/dygraph_to_static/test_assert.py index 210e904454fd9..2e5066b801e52 100644 --- a/test/dygraph_to_static/test_assert.py +++ b/test/dygraph_to_static/test_assert.py @@ -17,8 +17,8 @@ import numpy from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -37,12 +37,11 @@ def dyfunc_assert_non_variable(x=True): assert x -# @dy2static_unittest class TestAssertVariable(Dy2StTestBase): def _run(self, func, x, with_exception, to_static): paddle.jit.enable_to_static(to_static) if with_exception: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 with base.dygraph.guard(): func(x) else: @@ -53,8 +52,8 @@ def _run_dy_static(self, func, x, with_exception): self._run(func, x, with_exception, True) self._run(func, x, with_exception, False) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_non_variable(self): self._run_dy_static( dyfunc_assert_non_variable, x=False, with_exception=True @@ -63,8 +62,8 @@ def test_non_variable(self): dyfunc_assert_non_variable, x=True, with_exception=False ) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_bool_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([False]), with_exception=True @@ -73,8 +72,8 @@ def test_bool_variable(self): dyfunc_assert_variable, x=numpy.array([True]), with_exception=False ) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_int_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([0]), with_exception=True diff --git a/test/dygraph_to_static/test_ast_util.py b/test/dygraph_to_static/test_ast_util.py index c2468765e3438..a6421e4cc60ba 100644 --- a/test/dygraph_to_static/test_ast_util.py +++ b/test/dygraph_to_static/test_ast_util.py @@ -19,8 +19,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from ifelse_simple_func import ( dyfunc_with_if_else, @@ -35,7 +35,6 @@ from paddle.utils import gast -# @dy2static_unittest class TestAST2Func(Dy2StTestBase): """ TestCase for the transformation from ast.AST into python callable function. @@ -48,7 +47,7 @@ def _ast2func(self, func): transformed_func, _ = ast_to_func(ast_root, func) return transformed_func - @ast_only_test + @test_ast_only def test_ast2func(self): def func(x, y): return x + y @@ -56,7 +55,7 @@ def func(x, y): x, y = 10, 20 self.assertEqual(func(x, y), self._ast2func(func)(x, y)) - @ast_only_test + @test_ast_only def test_ast2func_dygraph(self): paddle.disable_static() funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else] @@ -68,8 +67,8 @@ def test_ast2func_dygraph(self): test_ret = self._ast2func(func)(x_v).numpy() self.assertTrue((true_ret == test_ret).all()) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_ast2func_static(self): paddle.enable_static() @@ -88,7 +87,7 @@ def func(x): ret = exe.run(main_program, fetch_list=[true_ret, test_ret]) self.assertTrue((ret[0] == ret[1]).all()) - @ast_only_test + @test_ast_only def test_ast2func_error(self): with self.assertRaises(Exception) as e: self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo')) diff --git a/test/dygraph_to_static/test_backward_without_params.py b/test/dygraph_to_static/test_backward_without_params.py index e233259dc514e..e11ee387ec69c 100644 --- a/test/dygraph_to_static/test_backward_without_params.py +++ b/test/dygraph_to_static/test_backward_without_params.py @@ -15,13 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - IrMode, - ToStaticMode, - disable_test_case, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -35,10 +29,8 @@ def forward(self, x): return out -# @dy2static_unittest class TestBackwardWithoutParams(Dy2StTestBase): - @test_and_compare_with_new_ir(False) - @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) + @test_legacy_and_pir def test_run(self): net = paddle.jit.to_static(Net()) @@ -61,10 +53,8 @@ def forward(self, x): return y, out -# @dy2static_unittest class TestZeroSizeNet(Dy2StTestBase): - @test_and_compare_with_new_ir(False) - @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) + @test_legacy_and_pir def test_run(self): net = paddle.jit.to_static(ZeroSizeNet()) x = paddle.ones([2, 2]) diff --git a/test/dygraph_to_static/test_basic_api_transformation.py b/test/dygraph_to_static/test_basic_api_transformation.py index e0998b8fe1e67..51ddbe6e11a1c 100644 --- a/test/dygraph_to_static/test_basic_api_transformation.py +++ b/test/dygraph_to_static/test_basic_api_transformation.py @@ -16,10 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base, to_tensor @@ -72,8 +69,7 @@ def dyfunc_bool_to_tensor(x): return paddle.to_tensor(True) -@dy2static_unittest -class TestDygraphBasicApi_ToVariable(unittest.TestCase): +class TestDygraphBasicApi_ToVariable(Dy2StTestBase): def setUp(self): self.input = np.ones(5).astype("int32") self.test_funcs = [ @@ -96,7 +92,7 @@ def get_dygraph_output(self): res = self.dygraph_func(self.input).numpy() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): main_program = base.Program() main_program.random_seed = SEED @@ -234,8 +230,7 @@ def dyfunc_Prelu(input): return res -@dy2static_unittest -class TestDygraphBasicApi(unittest.TestCase): +class TestDygraphBasicApi(Dy2StTestBase): # Compare results of dynamic graph and transformed static graph function which only # includes basic Api. @@ -252,7 +247,7 @@ def get_dygraph_output(self): return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -286,7 +281,7 @@ def get_dygraph_output(self): res = self.dygraph_func(self.input1, self.input2).numpy() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -401,8 +396,7 @@ def dyfunc_PolynomialDecay(): return paddle.to_tensor(lr) -@dy2static_unittest -class TestDygraphBasicApi_CosineDecay(unittest.TestCase): +class TestDygraphBasicApi_CosineDecay(Dy2StTestBase): def setUp(self): self.dygraph_func = dyfunc_CosineDecay @@ -413,7 +407,7 @@ def get_dygraph_output(self): res = self.dygraph_func().numpy() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -444,7 +438,7 @@ def get_dygraph_output(self): res = self.dygraph_func() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -471,7 +465,7 @@ def get_dygraph_output(self): res = self.dygraph_func() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -498,7 +492,7 @@ def get_dygraph_output(self): res = self.dygraph_func() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -545,8 +539,7 @@ def _dygraph_fn(): np.random.random(1) -@dy2static_unittest -class TestDygraphApiRecognition(unittest.TestCase): +class TestDygraphApiRecognition(Dy2StTestBase): def setUp(self): self.src = inspect.getsource(_dygraph_fn) self.root = gast.parse(self.src) diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py index ba8e2350794aa..b2e853b5755bb 100644 --- a/test/dygraph_to_static/test_bert.py +++ b/test/dygraph_to_static/test_bert.py @@ -20,10 +20,10 @@ import numpy as np from bert_dygraph_model import PretrainModelLayer from bert_utils import get_bert_config, get_feed_data_reader -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_pir_only, ) from predictor_utils import PredictorTools @@ -78,8 +78,7 @@ def __len__(self): return len(self.src_ids) -@dy2static_unittest -class TestBert(unittest.TestCase): +class TestBert(Dy2StTestBase): def setUp(self): self.bert_config = get_bert_config() self.data_reader = get_feed_data_reader(self.bert_config) @@ -266,8 +265,8 @@ def predict_analysis_inference(self, data): out = output() return out - @test_with_new_ir - def test_train_new_ir(self): + @test_pir_only + def test_train_pir(self): static_loss, static_ppl = self.train_static( self.bert_config, self.data_reader ) @@ -277,7 +276,7 @@ def test_train_new_ir(self): np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05) np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05) - @ast_only_test + @test_ast_only def test_train(self): static_loss, static_ppl = self.train_static( self.bert_config, self.data_reader diff --git a/test/dygraph_to_static/test_bmn.py b/test/dygraph_to_static/test_bmn.py index f5f8d35759869..a72170d5ef491 100644 --- a/test/dygraph_to_static/test_bmn.py +++ b/test/dygraph_to_static/test_bmn.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest, test_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only from predictor_utils import PredictorTools import paddle @@ -637,8 +637,7 @@ def val_bmn(model, args): return loss_data -@dy2static_unittest -class TestTrain(unittest.TestCase): +class TestTrain(Dy2StTestBase): def setUp(self): self.args = Args() self.place = ( @@ -751,8 +750,8 @@ def train_bmn(self, args, place, to_static): break return np.array(loss_data) - @test_with_new_ir - def test_train_new_ir(self): + @test_pir_only + def test_train_pir(self): static_res = self.train_bmn(self.args, self.place, to_static=True) dygraph_res = self.train_bmn(self.args, self.place, to_static=False) np.testing.assert_allclose( diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py index a803c1d4bf49e..e1df868435e8f 100644 --- a/test/dygraph_to_static/test_break_continue.py +++ b/test/dygraph_to_static/test_break_continue.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle import base @@ -26,14 +26,13 @@ np.random.seed(SEED) -@dy2static_unittest -class TestDy2staticException(unittest.TestCase): +class TestDy2staticException(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = None self.error = "Your if/else have different number of return value." - @ast_only_test + @test_ast_only def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -205,8 +204,7 @@ def test_optim_break_in_while(x): return x -@dy2static_unittest -class TestContinueInFor(unittest.TestCase): +class TestContinueInFor(Dy2StTestBase): def setUp(self): self.input = np.zeros(1).astype('int64') self.place = ( diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py index 85e934afb020b..ee19dad5842f9 100644 --- a/test/dygraph_to_static/test_build_strategy.py +++ b/test/dygraph_to_static/test_build_strategy.py @@ -15,14 +15,13 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only from test_resnet import ResNetHelper import paddle -@dy2static_unittest -class TestResnetWithPass(unittest.TestCase): +class TestResnetWithPass(Dy2StTestBase): def setUp(self): self.build_strategy = paddle.static.BuildStrategy() self.build_strategy.fuse_elewise_add_act_ops = True @@ -62,7 +61,7 @@ def verify_predict(self): err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.', ) - @ast_only_test + @test_ast_only def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) @@ -74,7 +73,7 @@ def test_resnet(self): ) self.verify_predict() - @ast_only_test + @test_ast_only def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': True}) try: @@ -84,8 +83,7 @@ def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': False}) -@dy2static_unittest -class TestError(unittest.TestCase): +class TestError(Dy2StTestBase): def test_type_error(self): def foo(x): out = x + 1 diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py index 199c3e980e20c..9683afb05bdda 100644 --- a/test/dygraph_to_static/test_cache_program.py +++ b/test/dygraph_to_static/test_cache_program.py @@ -16,7 +16,7 @@ from collections import Counter import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase from test_fetch_feed import Linear, Pool2D import paddle @@ -25,8 +25,7 @@ from paddle.jit.dy2static import convert_to_static -@dy2static_unittest -class TestCacheProgram(unittest.TestCase): +class TestCacheProgram(Dy2StTestBase): def setUp(self): self.batch_num = 5 self.dygraph_class = Pool2D @@ -76,8 +75,7 @@ def setUp(self): self.data = np.random.random((4, 10)).astype('float32') -@dy2static_unittest -class TestCacheProgramWithOptimizer(unittest.TestCase): +class TestCacheProgramWithOptimizer(Dy2StTestBase): def setUp(self): self.dygraph_class = Linear self.data = np.random.random((4, 10)).astype('float32') @@ -126,8 +124,7 @@ def simple_func(x): return mean -@dy2static_unittest -class TestConvertWithCache(unittest.TestCase): +class TestConvertWithCache(Dy2StTestBase): def test_cache(self): static_func = convert_to_static(simple_func) # Get transformed function from cache. @@ -157,8 +154,7 @@ def sum_under_while(limit): return ret_sum -@dy2static_unittest -class TestToOutputWithCache(unittest.TestCase): +class TestToOutputWithCache(Dy2StTestBase): def test_output(self): with base.dygraph.guard(): ret = sum_even_until_limit(80, 10) diff --git a/test/dygraph_to_static/test_cast.py b/test/dygraph_to_static/test_cast.py index a01f2712cc764..48564e2776395 100644 --- a/test/dygraph_to_static/test_cast.py +++ b/test/dygraph_to_static/test_cast.py @@ -17,8 +17,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from paddle import base @@ -28,14 +28,12 @@ np.random.seed(SEED) -@to_static def test_bool_cast(x): x = base.dygraph.to_variable(x) x = bool(x) return x -@to_static def test_int_cast(x): x = base.dygraph.to_variable(x) x = int(x) @@ -48,13 +46,11 @@ def test_float_cast(x): return x -@to_static def test_not_var_cast(x): x = int(x) return x -@to_static def test_mix_cast(x): x = base.dygraph.to_variable(x) x = int(x) @@ -64,7 +60,6 @@ def test_mix_cast(x): return x -# @dy2static_unittest class TestCastBase(Dy2StTestBase): def setUp(self): self.place = ( @@ -86,16 +81,15 @@ def prepare(self): self.cast_dtype = 'bool' def set_func(self): - self.func = test_bool_cast + self.func = to_static(full_graph=True)(test_bool_cast) def do_test(self): with base.dygraph.guard(): res = self.func(self.input) return res - @ast_only_test # TODO: add new symbolic only test. - @test_and_compare_with_new_ir(False) - # @set_to_static_mode(ToStaticMode.LEGACY_AST) + @test_ast_only # TODO: add new sot only test. + @test_legacy_and_pir def test_cast_result(self): res = self.do_test().numpy() self.assertTrue( @@ -125,7 +119,7 @@ def prepare(self): self.cast_dtype = 'int32' def set_func(self): - self.func = test_int_cast + self.func = to_static(full_graph=True)(test_int_cast) class TestFloatCast(TestCastBase): @@ -140,7 +134,7 @@ def prepare(self): self.cast_dtype = 'float32' def set_func(self): - self.func = to_static(test_float_cast) + self.func = to_static(full_graph=True)(test_float_cast) class TestMixCast(TestCastBase): @@ -158,10 +152,10 @@ def prepare(self): self.cast_dtype = 'float32' def set_func(self): - self.func = test_mix_cast + self.func = to_static(full_graph=True)(test_mix_cast) - @ast_only_test # TODO: add new symbolic only test. - @test_and_compare_with_new_ir(False) + @test_ast_only # TODO: add new symbolic only test. + @test_legacy_and_pir def test_cast_result(self): res = self.do_test().numpy() self.assertTrue( @@ -190,10 +184,10 @@ def prepare(self): self.cast_dtype = 'int' def set_func(self): - self.func = test_not_var_cast + self.func = to_static(full_graph=True)(test_not_var_cast) - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_cast_result(self): # breakpoint() # print("run once!!!") diff --git a/test/dygraph_to_static/test_cinn.py b/test/dygraph_to_static/test_cinn.py index 84e619149c800..0f8f5c962934c 100644 --- a/test/dygraph_to_static/test_cinn.py +++ b/test/dygraph_to_static/test_cinn.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -45,8 +42,7 @@ def apply_to_static(net, use_cinn): return paddle.jit.to_static(net, build_strategy=build_strategy) -@dy2static_unittest -class TestCINN(unittest.TestCase): +class TestCINN(Dy2StTestBase): def setUp(self): self.x = paddle.randn([2, 4]) self.x.stop_gradient = False @@ -83,7 +79,7 @@ def train(self, use_cinn): return res - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_cinn(self): dy_res = self.train(use_cinn=False) cinn_res = self.train(use_cinn=True) diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py index 2ed5326f7b9d0..95df5d498c6fb 100644 --- a/test/dygraph_to_static/test_cinn_prim.py +++ b/test/dygraph_to_static/test_cinn_prim.py @@ -15,10 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -43,8 +43,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestPrimForward(unittest.TestCase): +class TestPrimForward(Dy2StTestBase): """ This case only tests prim_forward + to_static + cinn. Thus we need to set this flag as False to avoid prim_backward. @@ -94,7 +93,7 @@ def check_prim(self, net, use_prim): # Ensure that softmax is splitted into small ops self.assertTrue('softmax' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim_forward(self): dy_res = self.train(use_prim=False) cinn_res = self.train(use_prim=True) @@ -105,8 +104,7 @@ def test_cinn_prim_forward(self): ) -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -161,7 +159,7 @@ def check_prim(self, net, use_prim): if op != "matmul_v2_grad": self.assertTrue("_grad" not in op) - @ast_only_test + @test_ast_only def test_cinn_prim(self): dy_res = self.train(use_prim=False) cinn_res = self.train(use_prim=True) @@ -172,9 +170,8 @@ def test_cinn_prim(self): ) -@dy2static_unittest -class TestBackend(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestBackend(Dy2StTestBase): + @test_legacy_and_pir def test_backend(self): x = paddle.randn([2, 4]) out1 = self.forward(x, 'CINN') diff --git a/test/dygraph_to_static/test_cinn_prim_gelu.py b/test/dygraph_to_static/test_cinn_prim_gelu.py index be2e8f67c1e98..ab9b3697eba62 100644 --- a/test/dygraph_to_static/test_cinn_prim_gelu.py +++ b/test/dygraph_to_static/test_cinn_prim_gelu.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle import paddle.nn.functional as F @@ -53,8 +53,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -106,7 +105,7 @@ def check_prim(self, net, use_prim): # Ensure that gelu is splitted into small ops self.assertTrue('gelu' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim(self): for shape in self.shapes: for dtype in self.dtypes: diff --git a/test/dygraph_to_static/test_cinn_prim_layer_norm.py b/test/dygraph_to_static/test_cinn_prim_layer_norm.py index 42bf36d731eca..94186bb1bff39 100644 --- a/test/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/test/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle import paddle.nn.functional as F @@ -52,8 +52,7 @@ def forward(self, x, w, b): return out[0] -@dy2static_unittest -class TestPrimForward(unittest.TestCase): +class TestPrimForward(Dy2StTestBase): """ This case only tests prim_forward + to_static + cinn. Thus we need to set this flag as False to avoid prim_backward. @@ -103,7 +102,7 @@ def check_prim(self, net, use_prim): # Ensure that layer_norm is splitted into small ops self.assertTrue('layer_norm' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim_forward(self): for dtype in self.dtypes: if paddle.device.get_device() == "cpu": @@ -125,8 +124,7 @@ def test_cinn_prim_forward(self): ) -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -172,7 +170,7 @@ def check_prim(self, net, use_prim): # Ensure that layer_norm is splitted into small ops self.assertTrue('layer_norm' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim(self): for dtype in self.dtypes: if paddle.device.get_device() == "cpu": diff --git a/test/dygraph_to_static/test_cinn_prim_mean.py b/test/dygraph_to_static/test_cinn_prim_mean.py index cb32f5b466035..fe82e9cfe0a5b 100644 --- a/test/dygraph_to_static/test_cinn_prim_mean.py +++ b/test/dygraph_to_static/test_cinn_prim_mean.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle import tensor @@ -55,8 +55,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestPrimForward(unittest.TestCase): +class TestPrimForward(Dy2StTestBase): """ This case only tests prim_forward + to_static + cinn. Thus we need to set this flag as False to avoid prim_backward. @@ -112,7 +111,7 @@ def check_prim(self, net, use_prim): # Ensure that reduce_mean is splitted into small ops self.assertTrue('reduce_mean' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim_forward(self): for shape in self.shapes: for dtype in self.dtypes: @@ -134,8 +133,7 @@ def test_cinn_prim_forward(self): ) -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -187,7 +185,7 @@ def check_prim(self, net, use_prim): # Ensure that reduce_mean is splitted into small ops self.assertTrue('reduce_mean' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim(self): for shape in self.shapes: for dtype in self.dtypes: diff --git a/test/dygraph_to_static/test_closure_analysis.py b/test/dygraph_to_static/test_closure_analysis.py index de1d1e12d6502..fe390108ed7d5 100644 --- a/test/dygraph_to_static/test_closure_analysis.py +++ b/test/dygraph_to_static/test_closure_analysis.py @@ -15,10 +15,7 @@ import inspect import unittest -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from numpy import append import paddle @@ -263,7 +260,7 @@ def init_dygraph_func(self): class TestPushPopTrans(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test(self): def vlist_of_dict(x): ma = {'a': []} @@ -274,7 +271,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test2(self): import numpy as np @@ -287,7 +284,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test3(self): import numpy as np @@ -300,7 +297,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test4(self): import numpy as np @@ -313,7 +310,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test5(self): import numpy as np diff --git a/test/dygraph_to_static/test_container.py b/test/dygraph_to_static/test_container.py index 412362ba725c5..964bc270b59a4 100644 --- a/test/dygraph_to_static/test_container.py +++ b/test/dygraph_to_static/test_container.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle @@ -70,8 +70,7 @@ def forward(self, x): return self.layers(x) -@dy2static_unittest -class TestSequential(unittest.TestCase): +class TestSequential(Dy2StTestBase): def setUp(self): paddle.set_device('cpu') self.seed = 2021 diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 723d3f910debd..bd21698579d93 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle import paddle.jit.dy2static as _jst @@ -77,8 +77,7 @@ def dyfunc_with_staticmethod(x_v): return a.add(x_v, x_v) -@dy2static_unittest -class TestRecursiveCall1(unittest.TestCase): +class TestRecursiveCall1(Dy2StTestBase): def setUp(self): self.input = np.random.random([10, 16]).astype('float32') self.place = ( @@ -169,8 +168,7 @@ def forward(self, inputs): return self.act(out) -@dy2static_unittest -class TestRecursiveCall2(unittest.TestCase): +class TestRecursiveCall2(Dy2StTestBase): def setUp(self): self.input = np.random.random((1, 3, 3, 5)).astype('float32') self.place = ( @@ -253,7 +251,6 @@ def test_code(self): ) -@dy2static_unittest class TestNotToConvert2(TestRecursiveCall2): def set_func(self): self.net = NotToStaticHelper() @@ -266,7 +263,7 @@ def test_conversion_options(self): self.assertIsNotNone(options) self.assertTrue(options.not_convert) - @ast_only_test + @test_ast_only def test_code(self): self.dygraph_func = paddle.jit.to_static(self.net.sum) # check 'if statement' is not converted @@ -281,23 +278,22 @@ def forward(self, x): return x -@dy2static_unittest -class TestConvertPaddleAPI(unittest.TestCase): - @ast_only_test +class TestConvertPaddleAPI(Dy2StTestBase): + @test_ast_only def test_functional_api(self): func = paddle.nn.functional.relu func = paddle.jit.to_static(func) self.assertNotIn("_jst.IfElse", func.code) self.assertIn("if in_dynamic_or_pir_mode()", func.code) - @ast_only_test + @test_ast_only def test_class_api(self): bn = paddle.nn.SyncBatchNorm(2) paddle.jit.to_static(bn) self.assertNotIn("_jst.IfElse", bn.forward.code) self.assertIn("if in_dynamic_mode()", bn.forward.code) - @ast_only_test + @test_ast_only def test_class_patch_api(self): paddle.nn.SyncBatchNorm.forward = forward bn = paddle.nn.SyncBatchNorm(2) diff --git a/test/dygraph_to_static/test_convert_call_generator.py b/test/dygraph_to_static/test_convert_call_generator.py index dd9d93c907c55..b3793fa22d289 100644 --- a/test/dygraph_to_static/test_convert_call_generator.py +++ b/test/dygraph_to_static/test_convert_call_generator.py @@ -14,10 +14,10 @@ import unittest -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -36,11 +36,10 @@ def main_func(): print(i) -@dy2static_unittest -class TestConvertGenerator(unittest.TestCase): +class TestConvertGenerator(Dy2StTestBase): # fallback will ok. - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_raise_error(self): translator_logger.verbosity_level = 1 with self.assertLogs( diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py index 02d0c09a70857..05a6d4de9c7d9 100644 --- a/test/dygraph_to_static/test_convert_operators.py +++ b/test/dygraph_to_static/test_convert_operators.py @@ -15,10 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -44,10 +44,9 @@ def forward(self): net.forward = "A string so that convert forward will fail" -@dy2static_unittest -class TestConvertCall(unittest.TestCase): +class TestConvertCall(Dy2StTestBase): # fallback mode will raise a InnerError, it's ok. - @ast_only_test + @test_ast_only def test_class_exception(self): @paddle.jit.to_static def call_not_exist(): @@ -73,8 +72,7 @@ def callable_list(x, y): self.assertEqual(callable_list(1, 2), 3) -@dy2static_unittest -class TestConvertShapeCompare(unittest.TestCase): +class TestConvertShapeCompare(Dy2StTestBase): def test_non_variable(self): self.assertEqual( paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True @@ -136,7 +134,7 @@ def error_func(): False, ) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_variable(self): paddle.enable_static() with paddle.static.program_guard( @@ -210,9 +208,8 @@ def forward(self, x): return out -@dy2static_unittest -class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestChooseShapeAttrOrApiWithLayer(Dy2StTestBase): + @test_legacy_and_pir def test_tensor_shape(self): x = paddle.zeros(shape=[4, 1], dtype='float32') net = ShapeLayer() @@ -221,9 +218,8 @@ def test_tensor_shape(self): np.testing.assert_array_equal(out.numpy(), x.numpy()) -@dy2static_unittest -class TestIfElseNoValue(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestIfElseNoValue(Dy2StTestBase): + @test_legacy_and_pir def test_else_ret_none(self): input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) @@ -253,7 +249,7 @@ def without_common_value(x, use_cache=False): out = without_common_value(input_x, False) self.assertIsNone(out) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_else_ret_c(self): input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) @@ -286,7 +282,7 @@ def without_common_value(x, use_cache=False): self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1)) self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_else_ret_cz(self): input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) diff --git a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py index b6e55b8900c1e..1d199dc8138df 100644 --- a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py +++ b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py @@ -15,18 +15,16 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - sot_only_test, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle -@dy2static_unittest -class TestCpuCuda(unittest.TestCase): +class TestCpuCuda(Dy2StTestBase): def test_cpu_cuda(self): def func(x): x = paddle.to_tensor([1, 2, 3, 4]) @@ -39,9 +37,8 @@ def func(x): # print(paddle.jit.to_static(func)(x)) -@dy2static_unittest -class TestToTensor(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestToTensor(Dy2StTestBase): + @test_legacy_and_pir def test_to_tensor_with_variable_list(self): def func(x): ones = paddle.to_tensor(1) @@ -58,10 +55,9 @@ def func(x): ) -@dy2static_unittest -class TestToTensor1(unittest.TestCase): - @ast_only_test - @test_and_compare_with_new_ir(False) +class TestToTensor1(Dy2StTestBase): + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list(self): def func(x): ones = paddle.to_tensor([1]) @@ -79,8 +75,8 @@ def func(x): rtol=1e-05, ) - @sot_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list_sot(self): def func(x): ones = paddle.to_tensor([1]) @@ -99,10 +95,9 @@ def func(x): ) -@dy2static_unittest -class TestToTensor2(unittest.TestCase): - @ast_only_test - @test_and_compare_with_new_ir(False) +class TestToTensor2(Dy2StTestBase): + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list(self): def func(x): x = paddle.to_tensor([[1], [2], [3], [4]]) @@ -115,8 +110,8 @@ def func(x): rtol=1e-05, ) - @sot_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list_sot(self): def func(x): x = paddle.to_tensor([[1], [2], [3], [4]]) diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py index fb06a52407ec6..58560286b3020 100644 --- a/test/dygraph_to_static/test_cycle_gan.py +++ b/test/dygraph_to_static/test_cycle_gan.py @@ -26,14 +26,10 @@ # Use GPU:0 to elimate the influence of other tasks. os.environ["CUDA_VISIBLE_DEVICES"] = "1" -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle from paddle.base.dygraph import to_variable -from paddle.jit.api import to_static from paddle.nn import BatchNorm # Note: Set True to eliminate randomness. @@ -74,7 +70,6 @@ def __init__(self, input_channel, istrain=True): input_channel ) - @to_static def forward(self, input_A, input_B): """ Generator of GAN model. @@ -125,7 +120,6 @@ def forward(self, input_A, input_B): g_loss, ) - @to_static def discriminatorA(self, input_A, input_B): """ Discriminator A of GAN model. @@ -135,7 +129,6 @@ def discriminatorA(self, input_A, input_B): return rec_B, fake_pool_rec_B - @to_static def discriminatorB(self, input_A, input_B): """ Discriminator B of GAN model. @@ -547,7 +540,6 @@ def train(args, to_static): ) paddle.jit.enable_to_static(to_static) - with base.dygraph.guard(place): max_images_num = args.max_images_num data_shape = [-1] + args.image_shape @@ -561,7 +553,9 @@ def train(args, to_static): B_pool = ImagePool() A_reader = paddle.batch(reader_creater(), args.batch_size)() B_reader = paddle.batch(reader_creater(), args.batch_size)() - cycle_gan = Cycle_Gan(input_channel=data_shape[1], istrain=True) + cycle_gan = paddle.jit.to_static( + Cycle_Gan(input_channel=data_shape[1], istrain=True) + ) t_time = 0 vars_G = ( @@ -623,9 +617,9 @@ def train(args, to_static): fake_pool_A = to_variable(fake_pool_A) # optimize the d_A network - rec_B, fake_pool_rec_B = cycle_gan.discriminatorA( - data_B, fake_pool_B - ) + rec_B, fake_pool_rec_B = paddle.jit.to_static( + cycle_gan.discriminatorA + )(data_B, fake_pool_B) d_loss_A = ( paddle.square(fake_pool_rec_B) + paddle.square(rec_B - 1) ) / 2.0 @@ -636,9 +630,9 @@ def train(args, to_static): cycle_gan.clear_gradients() # optimize the d_B network - rec_A, fake_pool_rec_A = cycle_gan.discriminatorB( - data_A, fake_pool_A - ) + rec_A, fake_pool_rec_A = paddle.jit.to_static( + cycle_gan.discriminatorB + )(data_A, fake_pool_A) d_loss_B = ( paddle.square(fake_pool_rec_A) + paddle.square(rec_A - 1) ) / 2.0 @@ -679,8 +673,7 @@ def train(args, to_static): return np.array(loss_data) -@dy2static_unittest -class TestCycleGANModel(unittest.TestCase): +class TestCycleGANModel(Dy2StTestBase): def setUp(self): self.args = Args() @@ -688,7 +681,7 @@ def train(self, to_static): out = train(self.args, to_static) return out - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_train(self): st_out = self.train(to_static=True) dy_out = self.train(to_static=False) diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py index 12b098cc10ac5..7c6eac567641f 100644 --- a/test/dygraph_to_static/test_declarative.py +++ b/test/dygraph_to_static/test_declarative.py @@ -19,8 +19,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from test_basic_api_transformation import dyfunc_to_variable @@ -41,12 +41,15 @@ def __init__(self): super().__init__() self.linear = paddle.nn.Linear(10, 3) - @to_static(input_spec=[InputSpec(shape=[None, 10], dtype='float32')]) + @to_static( + input_spec=[InputSpec(shape=[None, 10], dtype='float32')], + full_graph=True, + ) def forward(self, x, a=1, b=2): y = self.inner_function(x) return y - @to_static + @to_static(full_graph=True) def inner_function(self, x): y = self.linear(x) return y @@ -55,7 +58,10 @@ def add_func(self, x, y): z = x + y return z - @to_static(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]]) + @to_static( + input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]], + full_graph=True, + ) def func_with_list(self, l, int_val=1): x, y = l z = x + y @@ -63,7 +69,8 @@ def func_with_list(self, l, int_val=1): return z @to_static( - input_spec=[{'x': InputSpec([None, 10]), 'y': InputSpec([None, 10])}] + input_spec=[{'x': InputSpec([None, 10]), 'y': InputSpec([None, 10])}], + full_graph=True, ) def func_with_dict(self, d): x = d['x'] @@ -78,7 +85,8 @@ def func_with_dict(self, d): InputSpec([None]), {'x': InputSpec([None, 10]), 'y': InputSpec([None, 10])}, ] - ] + ], + full_graph=True, ) def func_with_list_dict(self, dl): bias = dl[0] @@ -116,8 +124,8 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_with_input_spec(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -178,7 +186,7 @@ def test_with_error(self): ) net.add_func(x, y) - @ast_only_test + @test_ast_only def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -218,8 +226,8 @@ class TestDifferentInputSpecCacheProgram(Dy2StTestBase): def setUp(self): paddle.jit.enable_to_static(True) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_with_different_input(self): with base.dygraph.guard(base.CPUPlace()): x_data = np.ones([16, 10]).astype('float32') @@ -265,7 +273,7 @@ def test_with_different_input(self): recent_program = foo.program_cache.last() self.assertTrue(first_program == recent_program) - @ast_only_test + @test_ast_only def test_get_concrete_program(self): foo = to_static(foo_func) @@ -306,8 +314,8 @@ def test_get_concrete_program(self): InputSpec([10]), InputSpec([10]), e=4 ) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): # usage 1 @@ -356,7 +364,7 @@ def test_nest_input(self): class TestDeclarativeAPI(Dy2StTestBase): - @ast_only_test + @test_ast_only def test_error(self): func = to_static(dyfunc_to_variable) @@ -380,15 +388,15 @@ def setUp(self): paddle.jit.enable_to_static(True) self.x = to_variable(np.ones([4, 10]).astype('float32')) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_fake_input(self): net = SimpleNet() net = to_static(net) y = net(self.x) self.assertTrue(len(net.forward.program_cache) == 1) - @ast_only_test + @test_ast_only def test_input_spec(self): net = SimpleNet() net = to_static(net, input_spec=[InputSpec([None, 8, 10])]) @@ -430,7 +438,7 @@ def __init__(self): super().__init__() self.sub = CallNonForwardFuncSubNet() - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self): return self.sub.func() @@ -446,7 +454,7 @@ def func(self): class TestCallNonForwardFunc(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_call_non_forward(self): paddle.disable_static() net = CallNonForwardFuncNet() @@ -460,7 +468,7 @@ def __init__(self): super().__init__() self.a = paddle.to_tensor([1]) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self): self.a = self.a + 1 return self.a @@ -471,7 +479,7 @@ def __init__(self): super().__init__() self.b = paddle.to_tensor([2]) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self): self.b = None self.b = paddle.to_tensor([3]) @@ -486,7 +494,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_set_buffers1(self): paddle.disable_static() net = SetBuffersNet1() @@ -495,7 +503,7 @@ def test_set_buffers1(self): paddle.jit.save(net, self.model_path) paddle.enable_static() - @ast_only_test + @test_ast_only def test_set_buffers2(self): paddle.disable_static() net = SetBuffersNet2() diff --git a/test/dygraph_to_static/test_decorator_transform.py b/test/dygraph_to_static/test_decorator_transform.py index 4f4096d607dc8..4ab416cceaa10 100644 --- a/test/dygraph_to_static/test_decorator_transform.py +++ b/test/dygraph_to_static/test_decorator_transform.py @@ -21,8 +21,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -186,7 +186,7 @@ def deco_with_paddle_api(): class TestDecoratorTransform(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_deco_transform(self): outs = paddle.jit.to_static(forward)() np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05) @@ -198,7 +198,7 @@ def test_deco_transform(self): np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05) np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05) - @ast_only_test + @test_ast_only def test_contextmanager_warning(self): paddle.disable_static() with warnings.catch_warnings(record=True) as w: @@ -215,7 +215,7 @@ def test_contextmanager_warning(self): break self.assertTrue(flag) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_deco_with_paddle_api(self): self.assertTrue(deco_with_paddle_api()) diff --git a/test/dygraph_to_static/test_deepcopy.py b/test/dygraph_to_static/test_deepcopy.py index d291927b73ddd..5d281ba8ea213 100644 --- a/test/dygraph_to_static/test_deepcopy.py +++ b/test/dygraph_to_static/test_deepcopy.py @@ -16,23 +16,15 @@ from copy import deepcopy import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - IrMode, - ToStaticMode, - disable_test_case, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from test_rollback import Net, foo import paddle from paddle.jit.dy2static.program_translator import StaticFunction -# @dy2static_unittest class TestDeepCopy(Dy2StTestBase): - @test_and_compare_with_new_ir(False) - @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) + @test_legacy_and_pir def test_net(self): net = Net() net = paddle.jit.to_static(net) @@ -48,7 +40,7 @@ def test_net(self): self.assertTrue(id(copy_net), id(copy_net.forward.__self__)) np.testing.assert_array_equal(src_out.numpy(), copy_out.numpy()) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_func(self): st_foo = paddle.jit.to_static(foo) x = paddle.randn([3, 4]) diff --git a/test/dygraph_to_static/test_dict.py b/test/dygraph_to_static/test_dict.py index 99364c1343a7d..c88496fd86b3e 100644 --- a/test/dygraph_to_static/test_dict.py +++ b/test/dygraph_to_static/test_dict.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base @@ -119,8 +116,7 @@ def update_cache(cache): return cache -@dy2static_unittest -class TestNetWithDict(unittest.TestCase): +class TestNetWithDict(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. @@ -130,7 +126,7 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.batch_size = self.x.shape[0] - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run_static(self): return self.train(to_static=True) @@ -173,8 +169,7 @@ def test_dic_pop_2(x): return out -@dy2static_unittest -class TestDictPop(unittest.TestCase): +class TestDictPop(Dy2StTestBase): def setUp(self): self.input = np.random.random(3).astype('int32') self.place = ( @@ -187,7 +182,7 @@ def setUp(self): def _set_test_func(self): self.dygraph_func = test_dic_pop - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run_static(self): return self._run(to_static=True) @@ -254,8 +249,7 @@ def test_ast_to_func(self): ) -@dy2static_unittest -class TestDictCmpInFor(unittest.TestCase): +class TestDictCmpInFor(Dy2StTestBase): def test_with_for(self): def func(): pos = [1, 3] diff --git a/test/dygraph_to_static/test_drop_path.py b/test/dygraph_to_static/test_drop_path.py index aad752007ceb0..d559ce7f55ac2 100644 --- a/test/dygraph_to_static/test_drop_path.py +++ b/test/dygraph_to_static/test_drop_path.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -39,15 +36,14 @@ def forward(self, x): return drop_path(x, self.training) -@dy2static_unittest -class TestTrainEval(unittest.TestCase): +class TestTrainEval(Dy2StTestBase): def setUp(self): self.model = DropPath() def tearDown(self): pass - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_train_and_eval(self): x = paddle.to_tensor([1, 2, 3]).astype("int64") eval_out = x.numpy() diff --git a/test/dygraph_to_static/test_duplicate_output.py b/test/dygraph_to_static/test_duplicate_output.py index add3a7262446a..70637729671f0 100644 --- a/test/dygraph_to_static/test_duplicate_output.py +++ b/test/dygraph_to_static/test_duplicate_output.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -41,8 +38,7 @@ def forward(self, x): return x, x -@dy2static_unittest -class TestDuplicateOutput(unittest.TestCase): +class TestDuplicateOutput(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. @@ -52,11 +48,14 @@ def setUp(self): self.net = paddle.jit.to_static(SimpleNet()) self.x = paddle.to_tensor([1.0]) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def _run_static(self): + param = self.net.parameters() + param[0].clear_grad() + loss0, loss1 = self.net(self.x) loss0.backward() - param = self.net.parameters() + self.assertEqual(param[0].grad.numpy(), 1.0) def test_ast_to_func(self): diff --git a/test/dygraph_to_static/test_error.py b/test/dygraph_to_static/test_error.py index c12dc3887f23d..9bb23945970c6 100644 --- a/test/dygraph_to_static/test_error.py +++ b/test/dygraph_to_static/test_error.py @@ -29,7 +29,7 @@ def inner_func(): return # noqa: PLR1711 -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def func_error_in_compile_time(x): x = base.dygraph.to_variable(x) inner_func() @@ -40,14 +40,14 @@ def func_error_in_compile_time(x): return x_v -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def func_error_in_compile_time_2(x): x = base.dygraph.to_variable(x) x = paddle.reshape(x, shape=[1, 2]) return x -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def func_error_in_runtime(x): x = base.dygraph.to_variable(x) two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32") @@ -56,12 +56,12 @@ def func_error_in_runtime(x): @unwrap -@paddle.jit.to_static() +@paddle.jit.to_static(full_graph=True) def func_decorated_by_other_1(): return 1 -@paddle.jit.to_static() +@paddle.jit.to_static(full_graph=True) @unwrap def func_decorated_by_other_2(): return 1 @@ -73,7 +73,8 @@ def __init__(self, fc_size=20): self._linear = paddle.nn.Linear(fc_size, fc_size) @paddle.jit.to_static( - input_spec=[paddle.static.InputSpec(shape=[20, 20], dtype='float32')] + input_spec=[paddle.static.InputSpec(shape=[20, 20], dtype='float32')], + full_graph=True, ) def forward(self, x): y = self._linear(x) @@ -86,7 +87,7 @@ class LayerErrorInCompiletime2(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self): self.test_func() @@ -98,7 +99,7 @@ def test_func(self): return # noqa: PLR1711 -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def func_error_in_runtime_with_empty_line(x): x = base.dygraph.to_variable(x) two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32") @@ -113,7 +114,7 @@ def __init__(self): super().__init__() self.inner_net = SuggestionErrorTestNet2() - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, x): return self.inner_net.forward(x) @@ -255,9 +256,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 35, in func_error_in_compile_time', 'inner_func()', - f'File "{self.filepath}", line 28, in inner_func', 'def inner_func():', 'paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', '<--- HERE', @@ -284,7 +283,6 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 46, in func_error_in_compile_time_2', 'def func_error_in_compile_time_2(x):', 'x = base.dygraph.to_variable(x)', 'x = paddle.reshape(x, shape=[1, 2])', @@ -308,7 +306,6 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 91, in forward', '@paddle.jit.to_static', 'def forward(self):', 'self.test_func()', @@ -332,7 +329,6 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 54, in func_error_in_runtime', 'x = base.dygraph.to_variable(x)', 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', 'x = paddle.reshape(x, shape=[1, two])', @@ -347,9 +343,6 @@ def set_func(self): def set_message(self): self.expected_message = [ - 'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format( - self.filepath - ), 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', 'x = paddle.reshape(x, shape=[1, two])', '<--- HERE', @@ -370,7 +363,6 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 80, in forward', 'def forward(self, x):', 'y = self._linear(x)', 'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', @@ -389,7 +381,7 @@ def test_error(self): self._test_raise_new_exception() -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def func_ker_error(x): d = {'x': x} y = d['y'] + x @@ -404,7 +396,7 @@ def test_key_error(self): func_ker_error(x) -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def NpApiErr(): a = paddle.to_tensor([1, 2]) b = np.sum(a.numpy()) @@ -434,7 +426,7 @@ def __init__(self): super().__init__() self.linear = paddle.nn.Linear(5, 2) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, x): old_dict = self.state_dict() wgt = old_dict['linear.weight'] diff --git a/test/dygraph_to_static/test_fallback.py b/test/dygraph_to_static/test_fallback.py index 58394feda2a68..9cfcc66b9fdc9 100644 --- a/test/dygraph_to_static/test_fallback.py +++ b/test/dygraph_to_static/test_fallback.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle @@ -51,8 +51,7 @@ def forward(self, x): return unsupport_func(x - 1) -@dy2static_unittest -class TestFallback(unittest.TestCase): +class TestFallback(Dy2StTestBase): def setUp(self): self.x = paddle.to_tensor([2]).astype('int') @@ -86,7 +85,7 @@ def test_case_net_fallback(self): u_net(self.x).numpy(), ) - @ast_only_test + @test_ast_only def test_case_net_error(self): s_net = SuppportNet() u_net = UnsuppportNet() diff --git a/test/dygraph_to_static/test_fetch_feed.py b/test/dygraph_to_static/test_fetch_feed.py index b44578fad2c9e..0df5e766df317 100644 --- a/test/dygraph_to_static/test_fetch_feed.py +++ b/test/dygraph_to_static/test_fetch_feed.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base @@ -65,8 +62,7 @@ def forward(self, x): return pre, loss -@dy2static_unittest -class TestPool2D(unittest.TestCase): +class TestPool2D(Dy2StTestBase): def setUp(self): self.dygraph_class = Pool2D self.data = np.random.random((1, 2, 4, 4)).astype('float32') @@ -83,7 +79,7 @@ def train(self, to_static=False): return prediction.numpy() - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def train_static(self): return self.train(to_static=True) diff --git a/test/dygraph_to_static/test_for_enumerate.py b/test/dygraph_to_static/test_for_enumerate.py index dc9505a5cf6fc..2c686678a41b2 100644 --- a/test/dygraph_to_static/test_for_enumerate.py +++ b/test/dygraph_to_static/test_for_enumerate.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import base @@ -354,8 +354,7 @@ def tensor_array_slice_in_enumerate(): return feat_n2 -@dy2static_unittest -class TestTransformBase(unittest.TestCase): +class TestTransformBase(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -558,8 +557,7 @@ def test_transformed_result_compare(self): self.transformed_result_compare() -@dy2static_unittest -class TestForZip(unittest.TestCase): +class TestForZip(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_full_name_usage.py b/test/dygraph_to_static/test_full_name_usage.py index 39a80acb566ea..db15692b6fb5e 100644 --- a/test/dygraph_to_static/test_full_name_usage.py +++ b/test/dygraph_to_static/test_full_name_usage.py @@ -15,13 +15,13 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle import base -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def dygraph_decorated_func(x): x = base.dygraph.to_variable(x) if paddle.mean(x) > 0: @@ -31,7 +31,7 @@ def dygraph_decorated_func(x): return x_v -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def jit_decorated_func(x): x = base.dygraph.to_variable(x) if paddle.mean(x) > 0: @@ -41,26 +41,25 @@ def jit_decorated_func(x): return x_v -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def decorated_call_decorated(x): return jit_decorated_func(x) class DoubleDecorated: @classmethod - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def double_decorated_func1(self, x): return dygraph_decorated_func(x) @classmethod - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def double_decorated_func2(self, x): return jit_decorated_func(x) -@dy2static_unittest -class TestFullNameDecorator(unittest.TestCase): - @ast_only_test +class TestFullNameDecorator(Dy2StTestBase): + @test_ast_only def test_run_success(self): x = np.ones([1, 2]).astype("float32") answer = np.zeros([1, 2]).astype("float32") diff --git a/test/dygraph_to_static/test_grad.py b/test/dygraph_to_static/test_grad.py index ceca09e789548..5bef08d9232d9 100644 --- a/test/dygraph_to_static/test_grad.py +++ b/test/dygraph_to_static/test_grad.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle @@ -65,8 +65,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestGrad(unittest.TestCase): +class TestGrad(Dy2StTestBase): def setUp(self): self.func = paddle.jit.to_static(GradLayer()) self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') @@ -84,7 +83,6 @@ def test_forward(self): np.testing.assert_allclose(static_res, dygraph_res, rtol=1e-05) -@dy2static_unittest class TestGradLinear(TestGrad): def setUp(self): self.func = paddle.jit.to_static(GradLinearLayer()) diff --git a/test/dygraph_to_static/test_gradient_aggregation.py b/test/dygraph_to_static/test_gradient_aggregation.py index 4172fb87197df..67b3ca8a987c7 100644 --- a/test/dygraph_to_static/test_gradient_aggregation.py +++ b/test/dygraph_to_static/test_gradient_aggregation.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -40,9 +37,8 @@ def forward(self, x): # return [out2, out1] # 梯度正常 -@dy2static_unittest -class TestGradientAggregationInDy2Static(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestGradientAggregationInDy2Static(Dy2StTestBase): + @test_legacy_and_pir def test_to_static(self): def simplenet_grad(inp, to_static=False): net = SimpleNet() diff --git a/test/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py index ca63511a5f6a3..7b46961207af4 100644 --- a/test/dygraph_to_static/test_gradname_parse.py +++ b/test/dygraph_to_static/test_gradname_parse.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle.nn import BatchNorm, Linear @@ -41,8 +41,7 @@ def forward(self, x): return dx[0] -@dy2static_unittest -class TestGradNameParse(unittest.TestCase): +class TestGradNameParse(Dy2StTestBase): def test_grad_name_parse(self): net = SimpleNet() opt = paddle.optimizer.Adam( @@ -69,8 +68,7 @@ def tanh_high_order_grad(x): return paddle.grad(y, x, create_graph=True)[0] -@dy2static_unittest -class TestTanhHighOrderGrad(unittest.TestCase): +class TestTanhHighOrderGrad(Dy2StTestBase): def setUp(self): self.func = tanh_high_order_grad @@ -118,7 +116,6 @@ def matmul_high_order_grad(x, y): return g[0] -@dy2static_unittest class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad): def setUp(self): self.func = matmul_high_order_grad @@ -138,7 +135,6 @@ def setUp(self): self.dy2st_grad_input = (x2,) -@dy2static_unittest class TestMatMulHighOrderGrad2(TestTanhHighOrderGrad): def setUp(self): self.func = matmul_high_order_grad diff --git a/test/dygraph_to_static/test_grid_generator.py b/test/dygraph_to_static/test_grid_generator.py index 7c1a9189366e0..586302f385574 100644 --- a/test/dygraph_to_static/test_grid_generator.py +++ b/test/dygraph_to_static/test_grid_generator.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import ParamAttr, nn @@ -133,7 +130,7 @@ class TestGridGenerator(Dy2StTestBase): def setUp(self): self.x = paddle.uniform(shape=[1, 20, 2], dtype='float32') - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 6c141aed8ff13..0fb5e5eb3c343 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, +) from ifelse_simple_func import ( NetWithControlFlowIf, add_fn, @@ -55,14 +59,14 @@ place = base.CPUPlace() -@dy2static_unittest -class TestDy2staticException(unittest.TestCase): +class TestDy2staticException(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = None self.error = "Your if/else have different number of return value." - @ast_only_test + @test_ast_only + @test_legacy_and_pir def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -72,8 +76,7 @@ def test_error(self): paddle.jit.enable_to_static(False) -@dy2static_unittest -class TestDygraphIfElse(unittest.TestCase): +class TestDygraphIfElse(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. @@ -95,6 +98,7 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() + @test_legacy_and_pir def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -123,11 +127,27 @@ def setUp(self): self.dyfunc = dyfunc_with_if_else_with_list_generator -class TestDygraphNestedIfElse(TestDygraphIfElse): +class TestDygraphNestedIfElse(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = paddle.to_tensor(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + @test_legacy_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphNestedIfElse2(TestDygraphIfElse): def setUp(self): @@ -233,14 +253,29 @@ def setUp(self): self.dyfunc = if_with_class_var -class TestDygraphIfTensor(TestDygraphIfElse): +class TestDygraphIfTensor(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = if_tensor_case + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = paddle.to_tensor(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + @test_legacy_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + -@dy2static_unittest -class TestDygraphIfElseNet(unittest.TestCase): +class TestDygraphIfElseNet(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. @@ -265,6 +300,7 @@ def _run(self, to_static=False): ret = net(x_v) return ret.numpy() + @test_legacy_and_pir def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -318,6 +354,10 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.Net = NetWithExternalFunc + @test_legacy_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class DiffModeNet1(paddle.nn.Layer): def __init__(self, mode): @@ -352,8 +392,7 @@ def forward(self, x, y): raise ValueError('Illegal mode') -@dy2static_unittest -class TestDiffModeNet(unittest.TestCase): +class TestDiffModeNet(Dy2StTestBase): """ TestCase for the net with different modes """ @@ -373,6 +412,7 @@ def _run(self, mode, to_static): ret = net(self.x, self.y) return ret.numpy() + @test_legacy_and_pir def test_train_mode(self): self.assertTrue( ( @@ -381,6 +421,7 @@ def test_train_mode(self): ).all() ) + @test_legacy_and_pir def test_infer_mode(self): self.assertTrue( ( @@ -395,8 +436,8 @@ def init_net(self): self.Net = DiffModeNet2 -@dy2static_unittest -class TestNewVarCreateInOneBranch(unittest.TestCase): +class TestNewVarCreateInOneBranch(Dy2StTestBase): + @test_legacy_and_pir def test_var_used_in_another_for(self): def case_func(training): # targets and targets_list is dynamically defined by training @@ -419,8 +460,7 @@ def case_func(training): self.assertEqual(paddle.jit.to_static(case_func)(True), -2) -@dy2static_unittest -class TestDy2StIfElseRetInt1(unittest.TestCase): +class TestDy2StIfElseRetInt1(Dy2StTestBase): def setUp(self): self.x = np.random.random([5]).astype('float32') self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int1) @@ -433,7 +473,8 @@ def get_dy2stat_out(self): paddle.jit.enable_to_static(False) return out - @ast_only_test + @test_ast_only + @test_legacy_and_pir def test_ast_to_func(self): self.setUp() self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor)) @@ -447,26 +488,26 @@ def setUp(self): self.dyfunc = dyfunc_ifelse_ret_int2 -@dy2static_unittest class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1): def setUp(self): self.x = np.random.random([5]).astype('float32') self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() - @ast_only_test + @test_ast_only + @test_legacy_and_pir def test_ast_to_func(self): self.setUp() self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor)) -@dy2static_unittest class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1): def setUp(self): self.x = np.random.random([5]).astype('float32') self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4) - @ast_only_test + @test_ast_only + @test_legacy_and_pir def test_ast_to_func(self): paddle.jit.enable_to_static(True) with self.assertRaises(Dygraph2StaticException): @@ -501,8 +542,8 @@ def forward(self, a, b, c): return b -@dy2static_unittest -class TestDy2StIfElseBackward(unittest.TestCase): +class TestDy2StIfElseBackward(Dy2StTestBase): + # TODO(zhangbo): open pir test (IfOp grad execution not yet supported) def test_run_backward(self): a = paddle.randn((4, 3), dtype='float32') a.stop_gradient = False diff --git a/test/dygraph_to_static/test_inplace_assign.py b/test/dygraph_to_static/test_inplace_assign.py index 7eaaba72e2733..bbc46524b072b 100644 --- a/test/dygraph_to_static/test_inplace_assign.py +++ b/test/dygraph_to_static/test_inplace_assign.py @@ -15,12 +15,12 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle -class TestInplaceAssign(unittest.TestCase): +class TestInplaceAssign(Dy2StTestBase): def test_case0(self): a = paddle.ones((1024, 2)) * 1 b = paddle.ones((1024, 3)) * 2 @@ -45,7 +45,7 @@ def func(x): y.mean().backward() np.testing.assert_array_equal(x.grad.numpy(), np.array([2.0])) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_case2(self): @paddle.jit.to_static def func(a, x): diff --git a/test/dygraph_to_static/test_isinstance.py b/test/dygraph_to_static/test_isinstance.py index 7dfd05989dabe..23dcc38edddf8 100644 --- a/test/dygraph_to_static/test_isinstance.py +++ b/test/dygraph_to_static/test_isinstance.py @@ -26,10 +26,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import nn @@ -78,7 +75,7 @@ def forward(self, x): return res -@test_and_compare_with_new_ir(True) +@compare_legacy_with_pir def train(model, to_static): paddle.jit.enable_to_static(to_static) @@ -88,8 +85,7 @@ def train(model, to_static): return out.numpy() -@dy2static_unittest -class TestIsinstance(unittest.TestCase): +class TestIsinstance(Dy2StTestBase): def test_isinstance_simple_return_layer(self): model = IsInstanceLayer(SimpleReturnLayer()) self._test_model(model) diff --git a/test/dygraph_to_static/test_jit_property_save.py b/test/dygraph_to_static/test_jit_property_save.py index 965168dedc6ea..6a254215fc816 100644 --- a/test/dygraph_to_static/test_jit_property_save.py +++ b/test/dygraph_to_static/test_jit_property_save.py @@ -14,13 +14,12 @@ import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle -@dy2static_unittest -class TestPropertySave(unittest.TestCase): +class TestPropertySave(Dy2StTestBase): """test jit property save""" def setUp(self): diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 219e6a6c9de74..0496c413aca50 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -16,14 +16,13 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle import paddle.nn.functional as F -@dy2static_unittest -class TestSetItemBase(unittest.TestCase): +class TestSetItemBase(Dy2StTestBase): def setUp(self) -> None: pass @@ -244,16 +243,12 @@ def foo(x, H, W): pad_list[3] = H // 2 pad_list[1] = W // 2 - # 问题在这里,进去F.pad以后,pad_list是初始变量而非赋值后的变量 - # 在修改前,赋值前后的变量是同一个,没有问题 - # 修改后,期望接收赋值后的变量,接收赋值前变量结果是不对的 x = F.pad(x, pad_list, data_format="NHWC") return x return foo def run_dygraph(self, func): - # 注释这句看结果diff x = paddle.ones((1, 6, 6, 3)) H = paddle.full([1], 6, dtype='int32') W = paddle.full([1], 6, dtype='int32') diff --git a/test/dygraph_to_static/test_lac.py b/test/dygraph_to_static/test_lac.py index 461b03fe7a5ed..d1feacae22262 100644 --- a/test/dygraph_to_static/test_lac.py +++ b/test/dygraph_to_static/test_lac.py @@ -22,7 +22,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "2" -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import _legacy_C_ops, base @@ -515,8 +515,7 @@ def create_dataloader(reader, place): return data_loader -@dy2static_unittest -class TestLACModel(unittest.TestCase): +class TestLACModel(Dy2StTestBase): def setUp(self): self.args = Args() self.place = ( diff --git a/test/dygraph_to_static/test_lambda.py b/test/dygraph_to_static/test_lambda.py index add572cb6dfcf..5f80f85ba5cfb 100644 --- a/test/dygraph_to_static/test_lambda.py +++ b/test/dygraph_to_static/test_lambda.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle import paddle.nn.functional as F @@ -80,8 +80,7 @@ def call_lambda_with_ifExpr2(x): return out -@dy2static_unittest -class TestLambda(unittest.TestCase): +class TestLambda(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.x = np.array([1, 3]).astype('float32') diff --git a/test/dygraph_to_static/test_layer_hook.py b/test/dygraph_to_static/test_layer_hook.py index d19b9ea9abfc9..7f4979b620e74 100644 --- a/test/dygraph_to_static/test_layer_hook.py +++ b/test/dygraph_to_static/test_layer_hook.py @@ -17,10 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle @@ -59,8 +56,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestNestLayerHook(unittest.TestCase): +class TestNestLayerHook(Dy2StTestBase): def setUp(self): paddle.seed(2022) self.x = paddle.randn([4, 10]) @@ -70,7 +66,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def train_net(self, to_static=False): paddle.seed(2022) net = SimpleNet() diff --git a/test/dygraph_to_static/test_len.py b/test/dygraph_to_static/test_len.py index 340ba86ff50c2..33c984a5520b2 100644 --- a/test/dygraph_to_static/test_len.py +++ b/test/dygraph_to_static/test_len.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import base @@ -43,8 +43,7 @@ def len_with_lod_tensor_array(x): return arr_len -@dy2static_unittest -class TestLen(unittest.TestCase): +class TestLen(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -115,8 +114,7 @@ def len_with_selected_rows(place): return result -@dy2static_unittest -class TestLenWithSelectedRows(unittest.TestCase): +class TestLenWithSelectedRows(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) diff --git a/test/dygraph_to_static/test_list.py b/test/dygraph_to_static/test_list.py index 51b28ce3fe38a..111a3109b786c 100644 --- a/test/dygraph_to_static/test_list.py +++ b/test/dygraph_to_static/test_list.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import base @@ -208,8 +208,7 @@ def test_list_pop_in_while_loop(x, iter_num): return a[0], b[2] -@dy2static_unittest -class TestListWithoutControlFlow(unittest.TestCase): +class TestListWithoutControlFlow(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -337,7 +336,6 @@ def __init__(self): # Add *args to test function.__self__ in FunctionSpec. # DO NOT remove *args. - @paddle.jit.to_static def forward(self, x, index, *args): y = paddle.nn.functional.relu(x) a = [] @@ -356,13 +354,12 @@ def forward(self, x, index, *args): return z -@dy2static_unittest -class TestListWithCondGradInferVarType(unittest.TestCase): +class TestListWithCondGradInferVarType(Dy2StTestBase): def test_to_static(self): net = ListWithCondNet() x = paddle.to_tensor([2, 3, 4], dtype='float32') index = paddle.to_tensor([1]) - res = net(x, index) + res = paddle.jit.to_static(net)(x, index) self.assertEqual(res, 48.0) diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py index 1e36145537f43..65f16a8bdcb2d 100644 --- a/test/dygraph_to_static/test_load_transformer.py +++ b/test/dygraph_to_static/test_load_transformer.py @@ -16,13 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - IrMode, - ToStaticMode, - disable_test_case, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -51,8 +45,7 @@ class TestFallback(Dy2StTestBase): def setUp(self): self.x = paddle.to_tensor(1.0).astype('int') - @test_and_compare_with_new_ir(False) - @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) + @test_legacy_and_pir def test_name_load(self): net_dy = Net() net_st = Net() @@ -62,8 +55,7 @@ def test_name_load(self): class TestLoad2(Dy2StTestBase): - @test_and_compare_with_new_ir(False) - @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) + @test_legacy_and_pir def test_name_load_nograd(self): @paddle.no_grad() def func(x): diff --git a/test/dygraph_to_static/test_logical.py b/test/dygraph_to_static/test_logical.py index a05f91b7c0493..8a768a41e1340 100644 --- a/test/dygraph_to_static/test_logical.py +++ b/test/dygraph_to_static/test_logical.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import base @@ -29,7 +29,6 @@ np.random.seed(22) -@paddle.jit.to_static def test_logical_not(x): x = paddle.to_tensor(x) if not x: @@ -51,7 +50,6 @@ def test_logical_not(x): return x -@paddle.jit.to_static def test_logical_not_2(x): x = paddle.to_tensor(x) @@ -64,7 +62,6 @@ def test_logical_not_2(x): return x -@paddle.jit.to_static def test_logical_and(x): x = paddle.to_tensor(x) @@ -82,7 +79,6 @@ def test_logical_and(x): return x -@paddle.jit.to_static def test_logical_and_2(x): x = paddle.to_tensor(x) @@ -106,7 +102,6 @@ def test_logical_and_2(x): return x -@paddle.jit.to_static def test_logical_or(x): x = paddle.to_tensor(x) @@ -124,7 +119,6 @@ def test_logical_or(x): return x -@paddle.jit.to_static def test_logical_or_2(x): x = paddle.to_tensor(x) @@ -136,7 +130,6 @@ def test_logical_or_2(x): return x -@paddle.jit.to_static def test_logical_not_and_or(x): x = paddle.to_tensor(x) @@ -148,7 +141,6 @@ def test_logical_not_and_or(x): return x -@paddle.jit.to_static def test_shape_equal(x): x = paddle.to_tensor(x) y = paddle.zeros([1, 2, 3]) @@ -158,7 +150,6 @@ def test_shape_equal(x): return paddle.ones([1, 2, 3]) -@paddle.jit.to_static def test_shape_not_equal(x): x = paddle.to_tensor(x) y = paddle.zeros([1, 2, 3]) @@ -168,8 +159,7 @@ def test_shape_not_equal(x): return paddle.ones([1, 2, 3]) -@dy2static_unittest -class TestLogicalBase(unittest.TestCase): +class TestLogicalBase(Dy2StTestBase): def setUp(self): self.input = np.array([3]).astype('int32') self.place = ( @@ -187,7 +177,7 @@ def _set_test_func(self): def _run(self, to_static): paddle.jit.enable_to_static(to_static) with base.dygraph.guard(self.place): - result = self.dygraph_func(self.input) + result = paddle.jit.to_static(self.dygraph_func)(self.input) return result.numpy() def _run_dygraph(self): @@ -264,8 +254,7 @@ def _set_test_func(self): self.dygraph_func = test_shape_not_equal -@dy2static_unittest -class TestCmpopNodeToStr(unittest.TestCase): +class TestCmpopNodeToStr(Dy2StTestBase): def test_exception(self): with self.assertRaises(KeyError): cmpop_node_to_str(gast.Or()) diff --git a/test/dygraph_to_static/test_loop.py b/test/dygraph_to_static/test_loop.py index 422508d6cd97e..3aefa231d6d27 100644 --- a/test/dygraph_to_static/test_loop.py +++ b/test/dygraph_to_static/test_loop.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle import paddle.nn.functional as F @@ -230,8 +230,7 @@ def for_loop_dufunc_with_listcomp(array): return res -@dy2static_unittest -class TestNameVisitor(unittest.TestCase): +class TestNameVisitor(Dy2StTestBase): def setUp(self): self.loop_funcs = [ while_loop_dyfunc, @@ -301,8 +300,7 @@ def test_nested_loop_vars(self): i += 1 -@dy2static_unittest -class TestTransformWhileLoop(unittest.TestCase): +class TestTransformWhileLoop(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -381,8 +379,7 @@ def _init_dyfunc(self): self.dyfunc = loop_var_contains_property -@dy2static_unittest -class TestTransformForLoop(unittest.TestCase): +class TestTransformForLoop(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -464,8 +461,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestForLoopMeetDict(unittest.TestCase): +class TestForLoopMeetDict(Dy2StTestBase): def test_start(self): net = Net() model = paddle.jit.to_static( diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index 79329d1d92bcb..09c89f78223dc 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle import nn @@ -45,8 +45,7 @@ def forward(self, x): return x -@dy2static_unittest -class TestLstm(unittest.TestCase): +class TestLstm(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.net = Net(12, 2) @@ -71,8 +70,7 @@ def test_lstm_to_static(self): static_out = self.run_lstm(to_static=True) np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05) - @ast_only_test - def test_save_in_eval(self, with_training=True): + def save_in_eval(self, with_training: bool): paddle.jit.enable_to_static(True) net = self.net x = self.inputs @@ -115,8 +113,13 @@ def test_save_in_eval(self, with_training=True): err_msg=f'dygraph_out is {dygraph_out}\n static_out is \n{train_out}', ) + @test_ast_only def test_save_without_training(self): - self.test_save_in_eval(with_training=False) + self.save_in_eval(with_training=False) + + @test_ast_only + def test_save_with_training(self): + self.save_in_eval(with_training=True) class TestLstmWithProjsize(unittest.TestCase): @@ -139,8 +142,7 @@ def forward(self, x): return y -@dy2static_unittest -class TestSaveInEvalMode(unittest.TestCase): +class TestSaveInEvalMode(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -183,8 +185,7 @@ def test_save_in_eval(self): ) -@dy2static_unittest -class TestEvalAfterSave(unittest.TestCase): +class TestEvalAfterSave(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index 984176a83afe0..9f3a307c44bb3 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -18,10 +18,10 @@ from time import time import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + compare_legacy_with_pir, + test_ast_only, ) from predictor_utils import PredictorTools @@ -130,8 +130,7 @@ def inference(self, inputs): return x -@dy2static_unittest -class TestMNIST(unittest.TestCase): +class TestMNIST(Dy2StTestBase): def setUp(self): self.epoch_num = 1 self.batch_size = 64 @@ -158,14 +157,14 @@ class TestMNISTWithToStatic(TestMNIST): still works if model is trained in dygraph mode. """ - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def train_static(self): return self.train(to_static=True) def train_dygraph(self): return self.train(to_static=False) - @ast_only_test + @test_ast_only def test_mnist_to_static(self): dygraph_loss = self.train_dygraph() static_loss = self.train_static() @@ -199,7 +198,7 @@ def train(self, to_static=False): base.default_startup_program().random_seed = SEED mnist = MNIST() if to_static: - mnist = paddle.jit.to_static(mnist) + mnist = paddle.jit.to_static(mnist, full_graph=True) adam = Adam(learning_rate=0.001, parameters=mnist.parameters()) for epoch in range(self.epoch_num): diff --git a/test/dygraph_to_static/test_mnist_amp.py b/test/dygraph_to_static/test_mnist_amp.py index 3e4b9d1b11657..8f0755351d767 100644 --- a/test/dygraph_to_static/test_mnist_amp.py +++ b/test/dygraph_to_static/test_mnist_amp.py @@ -16,7 +16,7 @@ from time import time import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import test_legacy_and_pir from test_mnist import MNIST, SEED, TestMNIST import paddle @@ -33,7 +33,7 @@ def train_static(self): def train_dygraph(self): return self.train(to_static=False) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_mnist_to_static(self): dygraph_loss = self.train_dygraph() static_loss = self.train_static() diff --git a/test/dygraph_to_static/test_mnist_pure_fp16.py b/test/dygraph_to_static/test_mnist_pure_fp16.py index c1489cc6e9158..7ba230c2a4686 100644 --- a/test/dygraph_to_static/test_mnist_pure_fp16.py +++ b/test/dygraph_to_static/test_mnist_pure_fp16.py @@ -16,7 +16,7 @@ from time import time import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import test_legacy_and_pir from test_mnist import MNIST, SEED, TestMNIST import paddle @@ -32,7 +32,7 @@ def train_static(self): def train_dygraph(self): return self.train(to_static=False) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_mnist_to_static(self): if paddle.base.is_compiled_with_cuda(): dygraph_loss = self.train_dygraph() diff --git a/test/dygraph_to_static/test_mobile_net.py b/test/dygraph_to_static/test_mobile_net.py index cca77999d5e7d..b594f3a6817c4 100644 --- a/test/dygraph_to_static/test_mobile_net.py +++ b/test/dygraph_to_static/test_mobile_net.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest, test_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only from predictor_utils import PredictorTools import paddle @@ -656,8 +656,7 @@ def predict_analysis_inference(args, data): return out -@dy2static_unittest -class TestMobileNet(unittest.TestCase): +class TestMobileNet(Dy2StTestBase): def setUp(self): self.args = Args() self.temp_dir = tempfile.TemporaryDirectory() @@ -727,8 +726,8 @@ def assert_same_predict(self, model_name): err_msg=f'inference_pred_res:\n {predictor_pre}\n, st_pre: \n{st_pre}.', ) - @test_with_new_ir - def test_mobile_net_new_ir(self): + @test_pir_only + def test_mobile_net_pir(self): # MobileNet-V1 self.assert_same_loss("MobileNetV1") # MobileNet-V2 diff --git a/test/dygraph_to_static/test_multi_forward.py b/test/dygraph_to_static/test_multi_forward.py index 2cf8e592f3fa0..bdcbda03de259 100644 --- a/test/dygraph_to_static/test_multi_forward.py +++ b/test/dygraph_to_static/test_multi_forward.py @@ -14,10 +14,7 @@ import unittest -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -36,9 +33,8 @@ def forward(self, x): return self.linear(x) -@dy2static_unittest -class TestBackward(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestBackward(Dy2StTestBase): + @test_legacy_and_pir def test_order_0(self): """ loss = 1 * w * 1 + 2 * w * 2 @@ -53,7 +49,7 @@ def test_order_0(self): loss.backward() self.assertEqual(model.linear.weight.grad, 5) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_order_1(self): """ loss = 2 * w * 2 + 1 * w * 1 diff --git a/test/dygraph_to_static/test_no_gradient.py b/test/dygraph_to_static/test_no_gradient.py new file mode 100644 index 0000000000000..b3bc726762ee4 --- /dev/null +++ b/test/dygraph_to_static/test_no_gradient.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy +from dygraph_to_static_utils_new import Dy2StTestBase + +import paddle + + +def static_func(x, no_grad_x): + tx = 2 * no_grad_x + tx.stop_gradient = True + return 2 * x + + +def main_func(x, index): + tmp = paddle.gather(x, index) + out = paddle.jit.to_static(static_func)(x, tmp) + return out + + +class TestNoGradientCase(Dy2StTestBase): + def test_no_gradient(self): + paddle.disable_static() + x = paddle.randn([10, 3]) + index = paddle.arange(0, 10, 1, dtype='int32') + x.stop_gradient = False + index.stop_gradient = True + + func = main_func + output = func(x, index).mean() + output.backward() + + self.assertTrue(x.grad is not None) + self.assertTrue( + numpy.all(x.grad.numpy() == paddle.full([10, 3], 2.0 / 30).numpy()) + ) + self.assertTrue(index.grad is None) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/dygraph_to_static/test_op_attr.py b/test/dygraph_to_static/test_op_attr.py index 6aaf1cdbf2138..012a10c3aa4a3 100644 --- a/test/dygraph_to_static/test_op_attr.py +++ b/test/dygraph_to_static/test_op_attr.py @@ -14,7 +14,7 @@ import unittest -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle.static import InputSpec @@ -42,7 +42,7 @@ def forward(self, x): out = self.bn(out) return out - @paddle.jit.to_static(input_spec=[InputSpec([10, 16])]) + @paddle.jit.to_static(input_spec=[InputSpec([10, 16])], full_graph=True) def with_cond(self, x): if paddle.mean(x) > 0.0: out = self.linear(x) @@ -52,8 +52,7 @@ def with_cond(self, x): return out -@dy2static_unittest -class CheckOpAttr(unittest.TestCase): +class CheckOpAttr(Dy2StTestBase): def setUp(self): self.in_num = 16 self.out_num = 16 @@ -78,7 +77,7 @@ def expected_results(self): 'elementwise_sub': self.sub_attrs, } - @ast_only_test + @test_ast_only def test_set_op_attrs(self): net = NetWithOpAttr(self.in_num, self.out_num) # set attrs @@ -120,7 +119,7 @@ def check_op_attrs(self, main_program): else: self.assertEqual(op_val, expect_val) - @ast_only_test + @test_ast_only def test_set_op_attrs_with_sub_block(self): net = NetWithOpAttr(self.in_num, self.out_num) # set attrs diff --git a/test/dygraph_to_static/test_origin_info.py b/test/dygraph_to_static/test_origin_info.py index be38650b750c2..183db1b0e60af 100644 --- a/test/dygraph_to_static/test_origin_info.py +++ b/test/dygraph_to_static/test_origin_info.py @@ -16,7 +16,7 @@ import sys import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase from paddle.jit.api import to_static from paddle.jit.dy2static import DygraphToStaticAst @@ -56,8 +56,7 @@ def decorated_func2(x): return x -@dy2static_unittest -class TestOriginInfo(unittest.TestCase): +class TestOriginInfo(Dy2StTestBase): def setUp(self): self.set_test_func() self.dygraph_func = unwrap(self.func) diff --git a/test/dygraph_to_static/test_param_guard.py b/test/dygraph_to_static/test_param_guard.py index c6787db58fc89..8e2e917c6af05 100644 --- a/test/dygraph_to_static/test_param_guard.py +++ b/test/dygraph_to_static/test_param_guard.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle from paddle.jit import to_static @@ -53,8 +50,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestParameterList(unittest.TestCase): +class TestParameterList(Dy2StTestBase): def setUp(self): self.seed = 2021 self.iter_num = 5 @@ -79,7 +75,7 @@ def train(self, is_iter, to_static): return loss - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_parameter_list(self): static_loss = self.train(False, to_static=True) dygraph_loss = self.train(False, to_static=False) @@ -106,8 +102,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestRawParameterList(unittest.TestCase): +class TestRawParameterList(Dy2StTestBase): def setUp(self): self.seed = 2021 self.iter_num = 5 @@ -133,7 +128,7 @@ def train(self, to_static): return loss - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_parameter_list(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) diff --git a/test/dygraph_to_static/test_params_no_grad.py b/test/dygraph_to_static/test_params_no_grad.py index 3b3f3949fad57..0ee66206a48a4 100644 --- a/test/dygraph_to_static/test_params_no_grad.py +++ b/test/dygraph_to_static/test_params_no_grad.py @@ -14,7 +14,7 @@ import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle import paddle.distributed as dist @@ -54,8 +54,7 @@ def train(): print(loss) -@dy2static_unittest -class TestParamsNoGrad(unittest.TestCase): +class TestParamsNoGrad(Dy2StTestBase): def test_two_card(self): if ( paddle.is_compiled_with_cuda() diff --git a/test/dygraph_to_static/test_partial_program.py b/test/dygraph_to_static/test_partial_program.py index f742a0cdb5337..cc3c5678c4843 100644 --- a/test/dygraph_to_static/test_partial_program.py +++ b/test/dygraph_to_static/test_partial_program.py @@ -17,11 +17,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - IrMode, - ToStaticMode, - ast_only_test, - disable_test_case, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from test_fetch_feed import Linear @@ -84,14 +81,15 @@ def _run(self, to_static): self.fake_input() if to_static: - out = paddle.jit.to_static(nested_input)(self.x, self.y) + out = paddle.jit.to_static(nested_input, full_graph=True)( + self.x, self.y + ) else: out = nested_input(self.x, self.y) return out.numpy() - @test_and_compare_with_new_ir(False) - @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) + @test_legacy_and_pir def test_nest(self): dygraph_res = self._run(to_static=False) static_res = self._run(to_static=True) @@ -110,13 +108,15 @@ def _run(self, to_static): self.y = fake_data([10, 16]) if to_static: - out = paddle.jit.to_static(nested_output)(self.x, self.y) + out = paddle.jit.to_static(nested_output, full_graph=True)( + self.x, self.y + ) else: out = nested_output(self.x, self.y) return out - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_nest(self): dygraph_res = self._run(to_static=False) dygraph_res = paddle.utils.flatten(dygraph_res) @@ -136,12 +136,12 @@ def test_nest(self): class TestWithTrainAndEval(Dy2StTestBase): - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_switch_eval_and_train(self): with base.dygraph.guard(): linear_net = Linear() - linear_net = paddle.jit.to_static(linear_net) + linear_net = paddle.jit.to_static(linear_net, full_graph=True) x_data = np.random.random((4, 10)).astype('float32') x = base.dygraph.to_variable(x_data) linear_net(x) @@ -169,12 +169,12 @@ def test_switch_eval_and_train(self): class TestWithNoGrad(Dy2StTestBase): - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_with_no_grad(self): with base.dygraph.guard(): linear_net = Linear() - linear_net = paddle.jit.to_static(linear_net) + linear_net = paddle.jit.to_static(linear_net, full_graph=True) x_data = np.random.random((5, 10)).astype('float32') x = base.dygraph.to_variable(x_data) @@ -197,7 +197,7 @@ def __init__(self): np.random.rand(2, 3).astype('float32') ) - @to_static + @to_static(full_graph=True) def forward(self, x): x = paddle.reshape(x, shape=[-1, 6]) x1, x2, x3 = paddle.split(x=x, axis=1, num_or_sections=3) @@ -205,7 +205,7 @@ def forward(self, x): class TestPruneUnusedParamInProgram(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_prune(self): input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32") diff --git a/test/dygraph_to_static/test_partial_program_hook.py b/test/dygraph_to_static/test_partial_program_hook.py index c10194f6187ad..1b50b5b4add91 100644 --- a/test/dygraph_to_static/test_partial_program_hook.py +++ b/test/dygraph_to_static/test_partial_program_hook.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle.base import core +from paddle.jit.api import ENV_ENABLE_SOT from paddle.jit.dy2static import partial_program, program_translator -@dy2static_unittest -class TestPartiaProgramLayerHook(unittest.TestCase): +class TestPartiaProgramLayerHook(Dy2StTestBase): def setUp(self): - os.environ["ENABLE_FALL_BACK"] = "False" + ENV_ENABLE_SOT.set(False) self._hook = partial_program.PartialProgramLayerHook() def test_before_append_backward(self): @@ -38,10 +37,9 @@ def test_after_infer(self): self.assertIsNone(self._hook.after_infer(None)) -@dy2static_unittest -class TestPrimHook(unittest.TestCase): +class TestPrimHook(Dy2StTestBase): def setUp(self): - os.environ["ENABLE_FALL_BACK"] = "False" + ENV_ENABLE_SOT.set(False) core._set_prim_all_enabled(False) def f(): diff --git a/test/dygraph_to_static/test_new_ir_selectedrows.py b/test/dygraph_to_static/test_pir_selectedrows.py similarity index 93% rename from test/dygraph_to_static/test_new_ir_selectedrows.py rename to test/dygraph_to_static/test_pir_selectedrows.py index e403cbd6089a1..f91c569e857fc 100644 --- a/test/dygraph_to_static/test_new_ir_selectedrows.py +++ b/test/dygraph_to_static/test_pir_selectedrows.py @@ -15,7 +15,7 @@ import random import unittest -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle.jit.api import to_static @@ -77,7 +77,7 @@ def train_dygraph(): return train(net, adam, x) -@test_and_compare_with_new_ir(True) +@compare_legacy_with_pir def train_static(): paddle.seed(100) net = IRSelectedRowsTestNet() @@ -87,10 +87,10 @@ def train_static(): parameters=net.parameters(), learning_rate=0.01, grad_clip=clip ) - return to_static(train)(net, adam, x) + return to_static(train, full_graph=True)(net, adam, x) -class TestSimnet(unittest.TestCase): +class TestSimnet(Dy2StTestBase): def test_dygraph_static_same_loss(self): dygraph_loss = train_dygraph() static_loss = train_static() diff --git a/test/dygraph_to_static/test_place.py b/test/dygraph_to_static/test_place.py index f1cb7e80589a3..f9aaca6932906 100644 --- a/test/dygraph_to_static/test_place.py +++ b/test/dygraph_to_static/test_place.py @@ -14,13 +14,12 @@ import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle -@dy2static_unittest -class TestPlace(unittest.TestCase): +class TestPlace(Dy2StTestBase): def test_place(self): paddle.enable_static() x = paddle.to_tensor([1, 2, 3, 4]) diff --git a/test/dygraph_to_static/test_print.py b/test/dygraph_to_static/test_print.py index 251bca776e700..35022512ce7f6 100644 --- a/test/dygraph_to_static/test_print.py +++ b/test/dygraph_to_static/test_print.py @@ -15,10 +15,7 @@ import unittest import numpy -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base @@ -87,8 +84,7 @@ def dyfunc_print_with_kwargs(x): print("Tensor", x_t, end='\n\n', sep=': ') -@dy2static_unittest -class TestPrintBase(unittest.TestCase): +class TestPrintBase(Dy2StTestBase): def setUp(self): self.input = numpy.ones(5).astype("int32") self.place = ( @@ -110,7 +106,7 @@ def _run(self, to_static): def get_dygraph_output(self): self._run(to_static=False) - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): self._run(to_static=True) diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index d2909d07a50b2..253a1a9b7d67e 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -18,7 +18,7 @@ import astor import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only from ifelse_simple_func import ( dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2, @@ -205,20 +205,19 @@ def false_fn_3(): class NetWithError(paddle.nn.Layer): - @to_static + @to_static(full_graph=True) def forward(self, x): linear = paddle.nn.Linear(32, 64) y = linear(x) return y -@dy2static_unittest -class TestEnableDeclarative(unittest.TestCase): +class TestEnableDeclarative(Dy2StTestBase): def setUp(self): self.x = np.random.randn(30, 10, 32).astype('float32') self.weight = np.random.randn(32, 64).astype('float32') - @ast_only_test + @test_ast_only def test_raise_error(self): with base.dygraph.guard(): paddle.jit.enable_to_static(True) @@ -263,14 +262,13 @@ def foo(self): return True -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def switch_mode_function(): return True -@dy2static_unittest -class TestFunctionTrainEvalMode(unittest.TestCase): - @ast_only_test +class TestFunctionTrainEvalMode(Dy2StTestBase): + @test_ast_only def test_switch_mode(self): paddle.disable_static() switch_mode_function.eval() @@ -299,8 +297,7 @@ def test_raise_error(self): net.foo.train() -@dy2static_unittest -class TestIfElseEarlyReturn(unittest.TestCase): +class TestIfElseEarlyReturn(Dy2StTestBase): def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1) @@ -314,8 +311,7 @@ def test_ifelse_early_return2(self): np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) -@dy2static_unittest -class TestRemoveCommentInDy2St(unittest.TestCase): +class TestRemoveCommentInDy2St(Dy2StTestBase): def func_with_comment(self): # Comment1 x = paddle.to_tensor([1, 2, 3]) @@ -356,8 +352,7 @@ def func1(x): return func1(data) -@dy2static_unittest -class TestParameterRecorder(unittest.TestCase): +class TestParameterRecorder(Dy2StTestBase): def test_recorder(self): """function calls nn.Layer case.""" net = Net() diff --git a/test/dygraph_to_static/test_ptb_lm.py b/test/dygraph_to_static/test_ptb_lm.py index 76a35d57ac9ba..87a6cbd5a8fe1 100644 --- a/test/dygraph_to_static/test_ptb_lm.py +++ b/test/dygraph_to_static/test_ptb_lm.py @@ -17,10 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base @@ -318,14 +315,13 @@ def train_dygraph(place): return train(place) -@test_and_compare_with_new_ir(True) +@compare_legacy_with_pir def train_static(place): paddle.jit.enable_to_static(True) return train(place) -@dy2static_unittest -class TestPtb(unittest.TestCase): +class TestPtb(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) diff --git a/test/dygraph_to_static/test_ptb_lm_v2.py b/test/dygraph_to_static/test_ptb_lm_v2.py index 92d4d43d9d4ea..abc351d17f1ec 100644 --- a/test/dygraph_to_static/test_ptb_lm_v2.py +++ b/test/dygraph_to_static/test_ptb_lm_v2.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle @@ -323,8 +323,7 @@ def train_static(place): return train(place) -@dy2static_unittest -class TestPtb(unittest.TestCase): +class TestPtb(Dy2StTestBase): def setUp(self): self.place = ( paddle.CUDAPlace(0) diff --git a/test/dygraph_to_static/test_pylayer.py b/test/dygraph_to_static/test_pylayer.py index 0e083a67b0e94..cde9c891a8e8b 100644 --- a/test/dygraph_to_static/test_pylayer.py +++ b/test/dygraph_to_static/test_pylayer.py @@ -26,7 +26,6 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest from test_jit_save_load import train import paddle @@ -178,7 +177,7 @@ def __init__(self, in_size, out_size): super().__init__() self.linear = paddle.nn.Linear(in_size, out_size) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, data): hidden = self.linear(data) z = cus_tanh_1.apply(hidden) @@ -213,7 +212,7 @@ class SimpleNetInplace(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, data): data = data**2 z = paddle.tanh(data) @@ -226,7 +225,7 @@ def __init__(self, in_size, out_size): super().__init__() self.linear = paddle.nn.Linear(in_size, out_size) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, x): y = self.linear(x) out = cus_tanh_2.apply(y, func1=paddle.tanh) @@ -240,7 +239,7 @@ def __init__(self, in_size, out_size): self.linear1 = paddle.nn.Linear(in_size, out_size) self.linear2 = paddle.nn.Linear(in_size, out_size) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, x1, x2): y1 = self.linear1(x1) y2 = self.linear1(x2) @@ -255,7 +254,7 @@ def __init__(self, in_size, out_size): super().__init__() self.linear = paddle.nn.Linear(in_size, out_size) - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def forward(self, x): y = self.linear(x) y.stop_gradient = True @@ -263,7 +262,6 @@ def forward(self, x): return out -@dy2static_unittest class TestPyLayerBase(unittest.TestCase): def setUp(self): self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu" @@ -361,7 +359,7 @@ def _run_and_compare(self, *args, **kwargs): class TestPyLayerWithoutContext(TestPyLayerBase): def test_single_in_single_out(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(x): y = scaled_layer_1.apply(x) return y @@ -374,7 +372,7 @@ def test_func(x): self._run_and_compare(input1) def test_multi_in_single_out(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(x1, x2): y = scaled_layer_2.apply(x1, x2) return y @@ -391,7 +389,7 @@ def test_func(x1, x2): class TestPyLayerWithContext(TestPyLayerBase): def test_single_in_single_out(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(x): y = cus_tanh_1.apply(x) return y @@ -404,7 +402,7 @@ def test_func(x): self._run_and_compare(input1) def test_nested_pylayer(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(x1, x2): y = nested_layer.apply(x1, x2) return y @@ -419,7 +417,7 @@ def test_func(x1, x2): self._run_and_compare(input1, input2) def test_apply_kwargs_pylayer(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(x1, x2): y = scaled_layer_2.apply(x1=x2, x2=x1) return y @@ -434,7 +432,7 @@ def test_func(x1, x2): self._run_and_compare(input1, input2) def test_non_variable_inputs(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(x): y = cus_tanh_2.apply(x, func1=paddle.tanh) return y @@ -447,7 +445,7 @@ def test_func(x): self._run_and_compare(input1) def test_simple_pylayer_return_none_with_no_grad(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(input1, input2): z = cus_tanh_3.apply(input1, input2, paddle.tanh, paddle.square) z = z[2] + z[3] @@ -463,7 +461,7 @@ def test_func(input1, input2): self._run_and_compare(input1, input2) def test_non_variable_inputs_and_userdefined_call(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def test_func(input1): y = cus_tanh_4.apply( input1, func=user_defined_square, name="cus_tanh_test" @@ -514,7 +512,6 @@ def test_pylayer_net_with_no_grad(self): self._run_and_compare(input1, input2) -@dy2static_unittest class PyLayerTrainHelper(unittest.TestCase): def setUp(self): self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu" @@ -533,7 +530,9 @@ def _run_train(self, to_static, layer_builder, build_strategy=None): # net = self.build_layer() net = layer_builder() if to_static: - net = paddle.jit.to_static(net, build_strategy=build_strategy) + net = paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) _, _, avg_loss = train(net) return avg_loss.numpy() @@ -586,7 +585,6 @@ def test_pylayer_net_no_grad(self): ) -@dy2static_unittest class TestPyLayerJitSaveLoad(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_reinforcement_learning.py b/test/dygraph_to_static/test_reinforcement_learning.py index ffbd0e315229d..a47607b561f8d 100644 --- a/test/dygraph_to_static/test_reinforcement_learning.py +++ b/test/dygraph_to_static/test_reinforcement_learning.py @@ -18,10 +18,7 @@ import gym import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle import paddle.nn.functional as F @@ -206,8 +203,7 @@ def finish_episode(): return np.array(loss_data) -@dy2static_unittest -class TestDeclarative(unittest.TestCase): +class TestDeclarative(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -216,7 +212,7 @@ def setUp(self): ) self.args = Args() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_train(self): st_out = train(self.args, self.place, to_static=True) dy_out = train(self.args, self.place, to_static=False) diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py index cb57ce234b263..f9318af86e9d1 100644 --- a/test/dygraph_to_static/test_resnet.py +++ b/test/dygraph_to_static/test_resnet.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest, test_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only from predictor_utils import PredictorTools import paddle @@ -386,8 +386,7 @@ def predict_analysis_inference(self, data): return out -@dy2static_unittest -class TestResnet(unittest.TestCase): +class TestResnet(Dy2StTestBase): def setUp(self): self.resnet_helper = ResNetHelper() @@ -420,8 +419,8 @@ def verify_predict(self): err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.', ) - @test_with_new_ir - def test_resnet_new_ir(self): + @test_pir_only + def test_resnet_pir(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) np.testing.assert_allclose( diff --git a/test/dygraph_to_static/test_resnet_amp.py b/test/dygraph_to_static/test_resnet_amp.py index 0255c0c00db3b..d8e3b6963fbec 100644 --- a/test/dygraph_to_static/test_resnet_amp.py +++ b/test/dygraph_to_static/test_resnet_amp.py @@ -16,10 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from test_resnet import SEED, ResNet, optimizer_setting import paddle @@ -114,13 +111,12 @@ def train(to_static, build_strategy=None): return total_loss.numpy() -@dy2static_unittest -class TestResnet(unittest.TestCase): +class TestResnet(Dy2StTestBase): def train(self, to_static): paddle.jit.enable_to_static(to_static) return train(to_static) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) diff --git a/test/dygraph_to_static/test_resnet_pure_fp16.py b/test/dygraph_to_static/test_resnet_pure_fp16.py index 771f9033f99d7..b5c132ce43df0 100644 --- a/test/dygraph_to_static/test_resnet_pure_fp16.py +++ b/test/dygraph_to_static/test_resnet_pure_fp16.py @@ -16,10 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from test_resnet import SEED, ResNet, optimizer_setting import paddle @@ -115,8 +112,7 @@ def train(to_static, build_strategy=None): return loss_data -@dy2static_unittest -class TestResnet(unittest.TestCase): +class TestResnet(Dy2StTestBase): def train(self, to_static): paddle.jit.enable_to_static(to_static) build_strategy = paddle.static.BuildStrategy() @@ -125,7 +121,7 @@ def train(self, to_static): build_strategy.enable_inplace = False return train(to_static, build_strategy) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_resnet(self): if base.is_compiled_with_cuda(): static_loss = self.train(to_static=True) diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py index 0f5d804427ca6..7adb93793bf95 100644 --- a/test/dygraph_to_static/test_resnet_v2.py +++ b/test/dygraph_to_static/test_resnet_v2.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest, test_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only from predictor_utils import PredictorTools import paddle @@ -242,8 +242,7 @@ def __len__(self): return len(self.img) -@dy2static_unittest -class TestResnet(unittest.TestCase): +class TestResnet(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -427,8 +426,8 @@ def verify_predict(self): err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.', ) - @test_with_new_ir - def test_resnet_new_ir(self): + @test_pir_only + def test_resnet_pir(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) np.testing.assert_allclose( diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py index 0cd14b94267cd..3c1e1136d7364 100644 --- a/test/dygraph_to_static/test_return.py +++ b/test/dygraph_to_static/test_return.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only from ifelse_simple_func import dyfunc_with_if_else import paddle @@ -28,13 +28,13 @@ np.random.seed(SEED) -@to_static +@to_static(full_graph=True) def test_return_base(x): x = base.dygraph.to_variable(x) return x -@to_static +@to_static(full_graph=True) def test_inside_func_base(x): x = base.dygraph.to_variable(x) @@ -44,7 +44,7 @@ def inner_func(x): return inner_func(x) -@to_static +@to_static(full_graph=True) def test_return_if(x): x = base.dygraph.to_variable(x) if x < 0: @@ -54,7 +54,7 @@ def test_return_if(x): return x -@to_static +@to_static(full_graph=True) def test_return_if_else(x): x = base.dygraph.to_variable(x) if x > 0: @@ -67,7 +67,7 @@ def test_return_if_else(x): x -= 8888 # useless statement to test our code can handle it. -@to_static +@to_static(full_graph=True) def test_return_in_while(x): x = base.dygraph.to_variable(x) i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0) @@ -80,7 +80,7 @@ def test_return_in_while(x): return x -@to_static +@to_static(full_graph=True) def test_return_in_for(x): x = base.dygraph.to_variable(x) for i in range(10): @@ -92,13 +92,13 @@ def test_return_in_for(x): return x - 1 -@to_static +@to_static(full_graph=True) def test_recursive_return(x): x = base.dygraph.to_variable(x) return dyfunc_with_if_else(x) -@to_static +@to_static(full_graph=True) def test_return_different_length_if_body(x): x = base.dygraph.to_variable(x) y = x + 1 @@ -109,7 +109,7 @@ def test_return_different_length_if_body(x): return x -@to_static +@to_static(full_graph=True) def test_return_different_length_else(x): x = base.dygraph.to_variable(x) y = x + 1 @@ -120,13 +120,13 @@ def test_return_different_length_else(x): return x -@to_static +@to_static(full_graph=True) def test_no_return(x): x = base.dygraph.to_variable(x) y = x + 1 -@to_static +@to_static(full_graph=True) def test_return_none(x): x = base.dygraph.to_variable(x) y = x + 1 @@ -137,7 +137,7 @@ def test_return_none(x): return x, y -@to_static +@to_static(full_graph=True) def test_return_no_variable(x): x = base.dygraph.to_variable(x) y = x + 1 @@ -148,14 +148,14 @@ def test_return_no_variable(x): return -@to_static +@to_static(full_graph=True) def test_return_list_one_value(x): x = base.dygraph.to_variable(x) x += 1 return [x] -@to_static +@to_static(full_graph=True) def test_return_list_many_values(x): x = base.dygraph.to_variable(x) x += 1 @@ -164,14 +164,14 @@ def test_return_list_many_values(x): return [x, y, z] -@to_static +@to_static(full_graph=True) def test_return_tuple_one_value(x): x = base.dygraph.to_variable(x) x += 1 return (x,) -@to_static +@to_static(full_graph=True) def test_return_tuple_many_values(x): x = base.dygraph.to_variable(x) x += 1 @@ -189,7 +189,7 @@ def inner_func(x): return y -@to_static +@to_static(full_graph=True) def test_return_without_paddle_cond(x): # y shape is [10] y = paddle.ones([10]) @@ -213,7 +213,7 @@ def diff_return_hepler(x): return two_value(x) -@to_static +@to_static(full_graph=True) def test_diff_return(x): x = paddle.to_tensor(x) y, z = diff_return_hepler(x) @@ -222,7 +222,7 @@ def test_diff_return(x): return y, z -@to_static +@to_static(full_graph=True) def test_return_if_else_2(x): rr = 0 if True: @@ -232,7 +232,7 @@ def test_return_if_else_2(x): a = 0 -@to_static +@to_static(full_graph=True) def test_return_in_while_2(x): while True: a = 12 @@ -240,7 +240,7 @@ def test_return_in_while_2(x): return 10 -@to_static +@to_static(full_graph=True) def test_return_in_for_2(x): a = 12 for i in range(10): @@ -248,7 +248,7 @@ def test_return_in_for_2(x): return 10 -@to_static +@to_static(full_graph=True) def test_return_nested(x): def func(): rr = 0 @@ -264,8 +264,7 @@ def func(): return func() -@dy2static_unittest -class TestReturnBase(unittest.TestCase): +class TestReturnBase(Dy2StTestBase): def setUp(self): self.input = np.ones(1).astype('int32') self.place = ( @@ -303,6 +302,7 @@ def _test_value_impl(self): else: self.assertEqual(dygraph_res, static_res) + @test_ast_only def test_transformed_static_result(self): if hasattr(self, "error"): with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -351,20 +351,12 @@ def init_dygraph_func(self): self.dygraph_func = test_return_in_while_2 self.error = "Found return statement in While or For body and loop" - @ast_only_test - def test_transformed_static_result(self): - super().test_transformed_static_result() - class TestReturnInFor2(TestReturnBase): def init_dygraph_func(self): self.dygraph_func = test_return_in_for_2 self.error = "Found return statement in While or For body and loop" - @ast_only_test - def test_transformed_static_result(self): - super().test_transformed_static_result() - class TestRecursiveReturn(TestReturnBase): def init_dygraph_func(self): @@ -377,20 +369,12 @@ def init_dygraph_func(self): self.dygraph_func = test_return_different_length_if_body self.error = "Your if/else have different number of return value." - @ast_only_test - def test_transformed_static_result(self): - super().test_transformed_static_result() - class TestReturnDifferentLengthElse(TestReturnBase): def init_dygraph_func(self): self.dygraph_func = test_return_different_length_else self.error = "Your if/else have different number of return value." - @ast_only_test - def test_transformed_static_result(self): - super().test_transformed_static_result() - class TestNoReturn(TestReturnBase): def init_dygraph_func(self): @@ -402,20 +386,12 @@ def init_dygraph_func(self): self.dygraph_func = test_return_none self.error = "Your if/else have different number of return value." - @ast_only_test - def test_transformed_static_result(self): - super().test_transformed_static_result() - class TestReturnNoVariable(TestReturnBase): def init_dygraph_func(self): self.dygraph_func = test_return_no_variable self.error = "Your if/else have different number of return value." - @ast_only_test - def test_transformed_static_result(self): - super().test_transformed_static_result() - class TestReturnListOneValue(TestReturnBase): def init_dygraph_func(self): diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py index 7ee3456747b51..2cba4d9ed7d85 100644 --- a/test/dygraph_to_static/test_rollback.py +++ b/test/dygraph_to_static/test_rollback.py @@ -15,10 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -71,12 +71,11 @@ def foo(x, flag=False): return out -@dy2static_unittest -class TestRollBackPlainFunction(unittest.TestCase): +class TestRollBackPlainFunction(Dy2StTestBase): def setUp(self): paddle.set_device("cpu") - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_plain_func(self): st_foo = paddle.jit.to_static(foo) x = paddle.randn([3, 4]) @@ -91,13 +90,12 @@ def test_plain_func(self): np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy()) -@dy2static_unittest -class TestRollBackNet(unittest.TestCase): +class TestRollBackNet(Dy2StTestBase): def setUp(self): paddle.set_device("cpu") - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_net(self): net = paddle.jit.to_static(Net()) x = paddle.randn([3, 4]) @@ -143,10 +141,9 @@ def func(self, x): return x + 2 -@dy2static_unittest -class TestRollBackNotForward(unittest.TestCase): - @ast_only_test - @test_and_compare_with_new_ir(False) +class TestRollBackNotForward(Dy2StTestBase): + @test_ast_only + @test_legacy_and_pir def test_rollback(self): x = paddle.zeros([2, 2]) net = FuncRollback() diff --git a/test/dygraph_to_static/test_save_inference_model.py b/test/dygraph_to_static/test_save_inference_model.py index 468541cfde39e..5054bad197738 100644 --- a/test/dygraph_to_static/test_save_inference_model.py +++ b/test/dygraph_to_static/test_save_inference_model.py @@ -17,10 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + compare_legacy_with_pir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -42,7 +43,7 @@ def __init__(self, fc_size): super().__init__() self._linear = paddle.nn.Linear(fc_size, fc_size) - @to_static + @to_static(full_graph=True) def forward(self, x): y = self._linear(x) z = self._linear(y) @@ -69,7 +70,7 @@ def __init__(self, fc_size): super().__init__() self._linear = paddle.nn.Linear(fc_size, fc_size) - @to_static + @to_static(full_graph=True) def forward(self, x): y = self._linear(x) out = cus_tanh.apply(y) @@ -77,15 +78,14 @@ def forward(self, x): return loss, out -@dy2static_unittest -class TestDyToStaticSaveInferenceModel(unittest.TestCase): +class TestDyToStaticSaveInferenceModel(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() def tearDown(self): self.temp_dir.cleanup() - @ast_only_test + @test_ast_only def test_save_inference_model(self): fc_size = 20 x_data = np.random.random((fc_size, fc_size)).astype('float32') @@ -127,7 +127,7 @@ def test_save_inference_model(self): layer, [x_data], dygraph_out.numpy(), feed=[x] ) - @ast_only_test + @test_ast_only def test_save_pylayer_model(self): fc_size = 20 x_data = np.random.random((fc_size, fc_size)).astype('float32') @@ -162,17 +162,17 @@ def test_save_pylayer_model(self): loss_out_numpy = float(loss_out) self.check_save_inference_model( - layer, [x_data], loss_out_numpy, enable_new_ir=False + layer, [x_data], loss_out_numpy, enable_pir=False ) self.check_save_inference_model( - layer, [x_data], loss_out_numpy, fetch=[loss], enable_new_ir=False + layer, [x_data], loss_out_numpy, fetch=[loss], enable_pir=False ) self.check_save_inference_model( - layer, [x_data], loss_out_numpy, feed=[x], enable_new_ir=False + layer, [x_data], loss_out_numpy, feed=[x], enable_pir=False ) def check_save_inference_model( - self, model, inputs, gt_out, feed=None, fetch=None, enable_new_ir=True + self, model, inputs, gt_out, feed=None, fetch=None, enable_pir=True ): expected_persistable_vars = {p.name for p in model.parameters()} @@ -190,8 +190,8 @@ def check_save_inference_model( input_spec=feed if feed else None, output_spec=fetch if fetch else None, ) - if enable_new_ir: - wrapped_load_and_run_inference = test_and_compare_with_new_ir(True)( + if enable_pir: + wrapped_load_and_run_inference = compare_legacy_with_pir( self.load_and_run_inference ) infer_out = wrapped_load_and_run_inference( @@ -228,10 +228,9 @@ def load_and_run_inference( return np.array(results[0]) -@dy2static_unittest -class TestPartialProgramRaiseError(unittest.TestCase): - @ast_only_test - @test_and_compare_with_new_ir(False) +class TestPartialProgramRaiseError(Dy2StTestBase): + @test_ast_only + @test_legacy_and_pir def test_param_type(self): paddle.jit.enable_to_static(True) x_data = np.random.random((20, 20)).astype('float32') diff --git a/test/dygraph_to_static/test_save_load.py b/test/dygraph_to_static/test_save_load.py index 92965aea2ccc2..674a7cfa1f559 100644 --- a/test/dygraph_to_static/test_save_load.py +++ b/test/dygraph_to_static/test_save_load.py @@ -17,10 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) from test_fetch_feed import Linear @@ -59,8 +59,7 @@ def forward_post_hook_for_prim_net(layer, input, output): return output * 2 -@dy2static_unittest -class TestDyToStaticSaveLoad(unittest.TestCase): +class TestDyToStaticSaveLoad(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.model_path = os.path.join( @@ -116,8 +115,8 @@ def test_save_load_same_result(self): dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05 ) - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_save_load_prim(self): with base.dygraph.guard(place): self.x = paddle.randn([4, 2, 6, 6], dtype="float32") @@ -158,8 +157,8 @@ def test_save_load_prim(self): self.assertIn("pool2d", load_op_type_list) np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05) - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_save_load_prim_with_hook(self): with base.dygraph.guard(place): self.x = paddle.randn([4, 2, 6, 6], dtype="float32") diff --git a/test/dygraph_to_static/test_se_resnet.py b/test/dygraph_to_static/test_se_resnet.py index 3ef1e62bf1cda..f779babe69bb2 100644 --- a/test/dygraph_to_static/test_se_resnet.py +++ b/test/dygraph_to_static/test_se_resnet.py @@ -20,10 +20,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + compare_legacy_with_pir, + test_ast_only, ) from predictor_utils import PredictorTools @@ -321,7 +321,7 @@ def __init__(self, layers=50, class_dim=102): ), ) - @to_static + @to_static(full_graph=True) def forward(self, inputs, label): if self.layers == 50 or self.layers == 101: y = self.conv0(inputs) @@ -351,8 +351,7 @@ def forward(self, inputs, label): return out, avg_loss, acc_top1, acc_top5 -@dy2static_unittest -class TestSeResnet(unittest.TestCase): +class TestSeResnet(Dy2StTestBase): def setUp(self): self.train_reader = paddle.batch( paddle.dataset.flowers.train(use_xmap=False, cycle=True), @@ -374,7 +373,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def train(self, train_reader, to_static): paddle.jit.enable_to_static(to_static) @@ -496,7 +495,7 @@ def predict_dygraph(self, data): return pred_res.numpy() - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def predict_static(self, data): paddle.enable_static() exe = base.Executor(place) @@ -570,7 +569,7 @@ def verify_predict(self): ), ) - @ast_only_test + @test_ast_only def test_check_result(self): pred_1, loss_1, acc1_1, acc5_1 = self.train( self.train_reader, to_static=False diff --git a/test/dygraph_to_static/test_sentiment.py b/test/dygraph_to_static/test_sentiment.py index 60d3678a5a72b..3c6a52dd9bad0 100644 --- a/test/dygraph_to_static/test_sentiment.py +++ b/test/dygraph_to_static/test_sentiment.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from test_lac import DynamicGRU import paddle @@ -372,12 +369,11 @@ def train(args, to_static): return loss_data -@dy2static_unittest -class TestSentiment(unittest.TestCase): +class TestSentiment(Dy2StTestBase): def setUp(self): self.args = Args() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def train_model(self, model_type='cnn_net'): self.args.model_type = model_type st_out = train(self.args, True) diff --git a/test/dygraph_to_static/test_seq2seq.py b/test/dygraph_to_static/test_seq2seq.py index b97752d4c57cb..db2d1e70c7a46 100644 --- a/test/dygraph_to_static/test_seq2seq.py +++ b/test/dygraph_to_static/test_seq2seq.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase from seq2seq_dygraph_model import AttentionModel, BaseModel from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter @@ -175,8 +175,7 @@ def infer(args, attn_model=False): return outputs.numpy() -@dy2static_unittest -class TestSeq2seq(unittest.TestCase): +class TestSeq2seq(Dy2StTestBase): def setUp(self): self.args = Seq2SeqModelHyperParams self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_set_dynamic_shape.py b/test/dygraph_to_static/test_set_dynamic_shape.py new file mode 100644 index 0000000000000..3a3843846a9a4 --- /dev/null +++ b/test/dygraph_to_static/test_set_dynamic_shape.py @@ -0,0 +1,42 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only + +import paddle + + +class TestSetDynamicShape(Dy2StTestBase): + @test_ast_only + def test_start(self): + def dygraph_func(loop_number): + mask = paddle.randn([2, 2]) + paddle.jit.dy2static.utils_helper.set_dynamic_shape(mask, [-1, 2]) + n = paddle.randn([1, 2]) + for i in range(loop_number): + mask = paddle.concat([mask, n], axis=0) + if mask.shape[0] == 5: + break + return mask + + loop_num = paddle.to_tensor(10) + expected_shape = dygraph_func(loop_num).shape + actual_shape = paddle.jit.to_static(dygraph_func)(loop_num).shape + self.assertEqual(expected_shape, actual_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/dygraph_to_static/test_simnet.py b/test/dygraph_to_static/test_simnet.py index 90dce27f87eef..75d13d48fb292 100644 --- a/test/dygraph_to_static/test_simnet.py +++ b/test/dygraph_to_static/test_simnet.py @@ -17,10 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from simnet_dygraph_model import BOW, HingeLoss import paddle @@ -179,9 +176,8 @@ def train(conf_dict, to_static): return losses -@dy2static_unittest -class TestSimnet(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestSimnet(Dy2StTestBase): + @test_legacy_and_pir def test_dygraph_static_same_loss(self): if base.is_compiled_with_cuda(): base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_simnet_v2.py b/test/dygraph_to_static/test_simnet_v2.py index 16fccfd731be0..9f05ca54759e8 100644 --- a/test/dygraph_to_static/test_simnet_v2.py +++ b/test/dygraph_to_static/test_simnet_v2.py @@ -17,10 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from simnet_dygraph_model_v2 import BOW, HingeLoss import paddle @@ -179,9 +176,8 @@ def train(conf_dict, to_static): return losses -@dy2static_unittest -class TestSimnet(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestSimnet(Dy2StTestBase): + @test_legacy_and_pir def test_dygraph_static_same_loss(self): if paddle.is_compiled_with_cuda(): paddle.base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index 3bd4c5f8a2c83..17a4e8410d612 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle.static import InputSpec @@ -108,8 +108,7 @@ def forward(self, x): return x -@dy2static_unittest -class TestSliceWithoutControlFlow(unittest.TestCase): +class TestSliceWithoutControlFlow(Dy2StTestBase): def setUp(self): self.init_input() self.place = ( @@ -170,8 +169,7 @@ def init_dygraph_func(self): self.dygraph_func = test_set_value -@dy2static_unittest -class TestSetValueWithLayerAndSave(unittest.TestCase): +class TestSetValueWithLayerAndSave(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.model_path = os.path.join( @@ -181,7 +179,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @ast_only_test + @test_ast_only def test_set_value_with_save(self): paddle.jit.enable_to_static(True) model = LayerWithSetValue(input_dim=10, hidden=1) @@ -191,8 +189,7 @@ def test_set_value_with_save(self): ) -@dy2static_unittest -class TestSliceSupplementSpecialCase(unittest.TestCase): +class TestSliceSupplementSpecialCase(Dy2StTestBase): # unittest for slice index which abs(step)>0. eg: x[::2] def test_static_slice_step(self): paddle.enable_static() @@ -235,8 +232,7 @@ def func(inps): ) -@dy2static_unittest -class TestPaddleStridedSlice(unittest.TestCase): +class TestPaddleStridedSlice(Dy2StTestBase): def test_compare_paddle_strided_slice_with_numpy(self): paddle.disable_static() array = np.arange(5) @@ -297,8 +293,7 @@ def slice_zero_shape_tensor(x): return y -@dy2static_unittest -class TestSliceZeroShapeTensor(unittest.TestCase): +class TestSliceZeroShapeTensor(Dy2StTestBase): def test_slice(self): paddle.disable_static() x = paddle.ones([0, 0, 0, 0]) diff --git a/test/dygraph_to_static/test_spec_names.py b/test/dygraph_to_static/test_spec_names.py index 72ffdc845134a..7f2f9683e0951 100644 --- a/test/dygraph_to_static/test_spec_names.py +++ b/test/dygraph_to_static/test_spec_names.py @@ -16,8 +16,8 @@ from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -48,8 +48,8 @@ def read_from_dataset(self): self.m = paddle.randn([4, 2, 8]) self.n = paddle.randn([4, 2, 8]) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_spec_name_hash(self): net = Net() net = paddle.jit.to_static(net) diff --git a/test/dygraph_to_static/test_tensor_hook.py b/test/dygraph_to_static/test_tensor_hook.py index 06b1b288ad899..3a4174f0febcf 100644 --- a/test/dygraph_to_static/test_tensor_hook.py +++ b/test/dygraph_to_static/test_tensor_hook.py @@ -15,15 +15,14 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import nn from paddle.jit import to_static -@dy2static_unittest -class TestStaticAnalysis(unittest.TestCase): +class TestStaticAnalysis(Dy2StTestBase): def test_hook_for_different_parameter(self): def f(x): def h(g): diff --git a/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py b/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py index a8e955be9e863..315a252d3bd24 100644 --- a/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py +++ b/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -40,13 +41,14 @@ def tensor_copy_to_cuda_with_warning(x, device_id=None, blocking=True): return y -class TestTensorCopyToCpuOnDefaultCPU(unittest.TestCase): +class TestTensorCopyToCpuOnDefaultCPU(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x1 = paddle.ones([1, 2, 3]) x2 = tensor_copy_to_cpu(x1) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pir def test_tensor_cpu_on_default_cpu(self): paddle.base.framework._set_expected_place(paddle.CPUPlace()) dygraph_x1_place, dygraph_place, dygraph_res = self._run( @@ -60,13 +62,14 @@ def test_tensor_cpu_on_default_cpu(self): self.assertTrue(static_place.is_cpu_place()) -class TestTensorCopyToCUDAOnDefaultCPU(unittest.TestCase): +class TestTensorCopyToCUDAOnDefaultCPU(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x1 = paddle.ones([1, 2, 3]) x2 = tensor_copy_to_cuda(x1) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pir def test_tensor_cuda_on_default_cpu(self): if not paddle.base.is_compiled_with_cuda(): return diff --git a/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py b/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py index 30e8e55611959..45aa125fdd5d5 100644 --- a/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py +++ b/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -41,13 +42,14 @@ def tensor_copy_to_cuda_with_warning(x, device_id=None, blocking=True): return y -class TestTensorCopyToCpuOnDefaultGPU(unittest.TestCase): +class TestTensorCopyToCpuOnDefaultGPU(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x1 = paddle.ones([1, 2, 3]) x2 = tensor_copy_to_cpu(x1) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pir def test_tensor_cpu_on_default_gpu(self): if paddle.base.is_compiled_with_cuda(): place = paddle.CUDAPlace( @@ -67,13 +69,14 @@ def test_tensor_cpu_on_default_gpu(self): self.assertTrue(static_place.is_cpu_place()) -class TestTensorCopyToCUDAOnDefaultGPU(unittest.TestCase): +class TestTensorCopyToCUDAOnDefaultGPU(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x1 = paddle.ones([1, 2, 3]) x2 = tensor_copy_to_cuda(x1) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pir def test_tensor_cuda_on_default_gpu(self): if paddle.base.is_compiled_with_cuda(): place = paddle.CUDAPlace( diff --git a/test/dygraph_to_static/test_tensor_methods.py b/test/dygraph_to_static/test_tensor_methods.py index 65981d65825a4..401428908f763 100644 --- a/test/dygraph_to_static/test_tensor_methods.py +++ b/test/dygraph_to_static/test_tensor_methods.py @@ -15,10 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -31,14 +31,13 @@ def tensor_clone(x): return y -@dy2static_unittest -class TestTensorClone(unittest.TestCase): +class TestTensorClone(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x = paddle.ones([1, 2, 3]) return tensor_clone(x).numpy() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_tensor_clone(self): paddle.disable_static() dygraph_res = self._run(to_static=False) @@ -46,23 +45,22 @@ def test_tensor_clone(self): np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def tensor_numpy(x): x = paddle.to_tensor(x) x.clear_gradient() return x -@dy2static_unittest -class TestTensorDygraphOnlyMethodError(unittest.TestCase): +class TestTensorDygraphOnlyMethodError(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x = paddle.zeros([2, 2]) y = tensor_numpy(x) return y.numpy() - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_to_static_numpy_report_error(self): paddle.disable_static() dygraph_res = self._run(to_static=False) @@ -70,15 +68,14 @@ def test_to_static_numpy_report_error(self): static_res = self._run(to_static=True) -@paddle.jit.to_static +@paddle.jit.to_static(full_graph=True) def tensor_item(x): x = paddle.to_tensor(x) y = x.clone() return y.item() -@dy2static_unittest -class TestTensorItem(unittest.TestCase): +class TestTensorItem(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x = paddle.ones([1]) @@ -86,7 +83,7 @@ def _run(self, to_static): return tensor_item(x).numpy() return tensor_item(x) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_tensor_clone(self): paddle.disable_static() dygraph_res = self._run(to_static=False) @@ -102,8 +99,7 @@ def tensor_size(x): return y -@dy2static_unittest -class TestTensorSize(unittest.TestCase): +class TestTensorSize(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x = paddle.ones([1, 2, 3]) @@ -114,7 +110,7 @@ def _run(self, to_static): ret = ret.numpy() return ret - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_tensor_clone(self): paddle.disable_static() dygraph_res = self._run(to_static=False) @@ -128,15 +124,14 @@ def true_div(x, y): return z -@dy2static_unittest -class TestTrueDiv(unittest.TestCase): +class TestTrueDiv(Dy2StTestBase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) x = paddle.to_tensor([3], dtype='int64') y = paddle.to_tensor([4], dtype='int64') return true_div(x, y).numpy() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_ture_div(self): paddle.disable_static() dygraph_res = self._run(to_static=False) diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py index d8c13cff35193..23dccb0f61093 100644 --- a/test/dygraph_to_static/test_tensor_shape.py +++ b/test/dygraph_to_static/test_tensor_shape.py @@ -17,8 +17,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + compare_legacy_with_pir, + test_ast_only, ) import paddle @@ -266,7 +266,7 @@ def _run(self, to_static): def get_dygraph_output(self): return self._run(to_static=False) - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): return self._run(to_static=True) @@ -293,7 +293,7 @@ def _compute_op_num(self, program): [op for op in block.ops if op.type == "slice"] ) - @ast_only_test + @test_ast_only def test_op_num(self): static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec) program = static_layer.main_program @@ -526,7 +526,7 @@ def _compute_op_num(self, program): [op for op in block.ops if op.type == "slice"] ) - @ast_only_test + @test_ast_only def test_op_num(self): static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec) program = static_layer.main_program @@ -617,7 +617,7 @@ def dyfunc_with_static_convert_var_shape(x): class TestFindStatiConvertVarShapeSuffixVar(Dy2StTestBase): - @ast_only_test + @test_ast_only def test(self): x_spec = paddle.static.InputSpec(shape=[None, 10]) func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec]) diff --git a/test/dygraph_to_static/test_to_tensor.py b/test/dygraph_to_static/test_to_tensor.py index b211e09254ede..d818a3740b4e8 100644 --- a/test/dygraph_to_static/test_to_tensor.py +++ b/test/dygraph_to_static/test_to_tensor.py @@ -15,7 +15,7 @@ import unittest import numpy -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle.base import core @@ -100,8 +100,7 @@ def case_to_tensor_default_dtype(): return paddle.to_tensor(1) -@dy2static_unittest -class TestToTensorReturnVal(unittest.TestCase): +class TestToTensorReturnVal(Dy2StTestBase): def test_to_tensor_badreturn(self): paddle.disable_static() x = paddle.to_tensor([3]) @@ -173,8 +172,7 @@ def test_to_tensor_err_log(self): ) -@dy2static_unittest -class TestStatic(unittest.TestCase): +class TestStatic(Dy2StTestBase): def test_static(self): paddle.enable_static() main_prog = Program() @@ -202,5 +200,18 @@ def test_static(self): res = exe.run(fetch_list=[x, out]) +class TestInt16(unittest.TestCase): + def test_static(self): + import numpy as np + + paddle.enable_static() + data = np.array([1, 2], dtype="int16") + x = paddle.to_tensor(data) + self.assertTrue(x.dtype == paddle.framework.core.VarDesc.VarType.INT16) + + y = paddle.to_tensor([1, 2], dtype="int16") + self.assertTrue(y.dtype == paddle.framework.core.VarDesc.VarType.INT16) + + if __name__ == '__main__': unittest.main() diff --git a/test/dygraph_to_static/test_train_step.py b/test/dygraph_to_static/test_train_step.py index 3c003f0725909..b2d336d8c8b2e 100644 --- a/test/dygraph_to_static/test_train_step.py +++ b/test/dygraph_to_static/test_train_step.py @@ -17,10 +17,7 @@ from functools import partial import numpy as np -from dygraph_to_static_util import ( - enable_fallback_guard, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -53,7 +50,7 @@ def forward(self, data): return self.layer1(data) -class TestTrainStepTinyModel(unittest.TestCase): +class TestTrainStepTinyModel(Dy2StTestBase): def setUp(self): self.input = paddle.randn([10000, 10]) self.net_creator = TinyModel @@ -80,14 +77,16 @@ def get_train_step_losses(self, func, steps): losses.append(loss) return losses - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_train_step(self): reset_seed() dygraph_losses = self.get_train_step_losses( self.train_step_func, self.steps ) reset_seed() - static_func = paddle.jit.to_static(self.train_step_func) + static_func = paddle.jit.to_static( + self.train_step_func, full_graph=True + ) static_losses = self.get_train_step_losses(static_func, self.steps) self.assertEqual(len(dygraph_losses), len(static_losses)) for dygraph_loss, static_loss in zip(dygraph_losses, static_losses): @@ -437,6 +436,22 @@ def setUp(self): self.rtol = 1e-4 +class TestTrainStepTinyModelCosineAnnealingWarmRestarts(TestTrainStepTinyModel): + def setUp(self): + self.input = paddle.randn([10000, 10]) + self.net_creator = TinyModel + self.lr_creator = partial( + paddle.optimizer.lr.CosineAnnealingWarmRestarts, + learning_rate=0.5, + T_0=1, + T_mult=1, + ) + self.optimizer_creator = paddle.optimizer.SGD + self.loss_fn = loss_fn_tiny_model + self.train_step_func = train_step_tiny_model + self.steps = 3 + self.rtol = 1e-4 + + if __name__ == "__main__": - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_train_step_resnet18_adam.py b/test/dygraph_to_static/test_train_step_resnet18_adam.py index 95fd040282b92..c8b34fe84f113 100644 --- a/test/dygraph_to_static/test_train_step_resnet18_adam.py +++ b/test/dygraph_to_static/test_train_step_resnet18_adam.py @@ -15,7 +15,6 @@ import platform import unittest -from dygraph_to_static_util import enable_fallback_guard from test_train_step import ( TestTrainStepTinyModel, loss_fn_tiny_model, @@ -41,5 +40,4 @@ def setUp(self): if __name__ == "__main__": - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_train_step_resnet18_sgd.py b/test/dygraph_to_static/test_train_step_resnet18_sgd.py index f6139e62dc216..a73d945aa9524 100644 --- a/test/dygraph_to_static/test_train_step_resnet18_sgd.py +++ b/test/dygraph_to_static/test_train_step_resnet18_sgd.py @@ -15,7 +15,6 @@ import platform import unittest -from dygraph_to_static_util import enable_fallback_guard from test_train_step import ( TestTrainStepTinyModel, loss_fn_tiny_model, @@ -41,5 +40,4 @@ def setUp(self): if __name__ == "__main__": - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_transformer.py b/test/dygraph_to_static/test_transformer.py index 29dda3916f3ab..2e8aefc568510 100644 --- a/test/dygraph_to_static/test_transformer.py +++ b/test/dygraph_to_static/test_transformer.py @@ -20,10 +20,7 @@ import numpy as np import transformer_util as util -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir from transformer_dygraph_model import ( CrossEntropyCriterion, Transformer, @@ -39,7 +36,7 @@ STEP_NUM = 10 -@test_and_compare_with_new_ir(True) +@compare_legacy_with_pir def train_static(args, batch_generator): paddle.enable_static() paddle.seed(SEED) @@ -422,7 +419,7 @@ def predict_dygraph(args, batch_generator): return seq_ids, seq_scores -@test_and_compare_with_new_ir(True) +@compare_legacy_with_pir def predict_static(args, batch_generator): test_prog = base.Program() with base.program_guard(test_prog): @@ -530,8 +527,7 @@ def predict_static(args, batch_generator): return seq_ids, seq_scores -@dy2static_unittest -class TestTransformer(unittest.TestCase): +class TestTransformer(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_tsm.py b/test/dygraph_to_static/test_tsm.py index 2cef9e7df4ded..83e7a27cad09c 100644 --- a/test/dygraph_to_static/test_tsm.py +++ b/test/dygraph_to_static/test_tsm.py @@ -19,10 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from tsm_config_utils import merge_configs, parse_config, print_configs import paddle @@ -387,9 +384,8 @@ def train(args, fake_data_reader, to_static): return ret -@dy2static_unittest -class TestTsm(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestTsm(Dy2StTestBase): + @test_legacy_and_pir def test_dygraph_static_same_loss(self): if base.is_compiled_with_cuda(): base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_typehint.py b/test/dygraph_to_static/test_typehint.py index 563db1d7a1df0..bf2f37fa86b7a 100644 --- a/test/dygraph_to_static/test_typehint.py +++ b/test/dygraph_to_static/test_typehint.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base @@ -36,8 +33,7 @@ def function(x: A) -> A: return 2 * x -@dy2static_unittest -class TestTransformWhileLoop(unittest.TestCase): +class TestTypeHint(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -50,7 +46,7 @@ def setUp(self): def _init_dyfunc(self): self.dyfunc = function - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run_static(self): return self._run(to_static=True) @@ -77,10 +73,5 @@ def test_ast_to_func(self): np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05) -class TestTypeHint(TestTransformWhileLoop): - def _init_dyfunc(self): - self.dyfunc = function - - if __name__ == '__main__': unittest.main() diff --git a/test/dygraph_to_static/test_typing.py b/test/dygraph_to_static/test_typing.py index c6810120488d0..71b098d1ca9ea 100644 --- a/test/dygraph_to_static/test_typing.py +++ b/test/dygraph_to_static/test_typing.py @@ -17,10 +17,7 @@ from typing import Dict, List, Tuple import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -72,8 +69,7 @@ def forward(self, x) -> Dict[str, paddle.Tensor]: return {'out': out2} -@dy2static_unittest -class TestTyping(unittest.TestCase): +class TestTyping(Dy2StTestBase): def setUp(self): self.in_num = 16 self.out_num = 16 @@ -97,7 +93,7 @@ def run_dy(self): out, _ = self.net(self.x) return out - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_type(self): self.net = self.build_net() out = self.run_dy() diff --git a/test/dygraph_to_static/test_unuseful_inputs.py b/test/dygraph_to_static/test_unuseful_inputs.py index 8f83f015db431..6a1d60ed7170d 100644 --- a/test/dygraph_to_static/test_unuseful_inputs.py +++ b/test/dygraph_to_static/test_unuseful_inputs.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle from paddle import nn @@ -65,14 +62,13 @@ def forward(self, x): return val -@dy2static_unittest -class TestDuplicateOutput(unittest.TestCase): +class TestDuplicateOutput(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. """ - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_case(self): # create network layer = Layer0(0) diff --git a/test/dygraph_to_static/test_utils.py b/test/dygraph_to_static/test_utils.py index 180078c144829..68ad96a8085c9 100644 --- a/test/dygraph_to_static/test_utils.py +++ b/test/dygraph_to_static/test_utils.py @@ -15,13 +15,13 @@ import types import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir_api from paddle.jit.dy2static.utils import index_in_list, is_paddle_func -@dy2static_unittest -class TestIndexInList(unittest.TestCase): +class TestIndexInList(Dy2StTestBase): + @test_legacy_and_pir_api def test_index_in_list(self): list_to_test = [1, 2, 3, 4, 5] self.assertEqual(index_in_list(list_to_test, 4), 3) @@ -52,11 +52,11 @@ def dyfunc_assign(input): y = n -@dy2static_unittest -class TestIsPaddle(unittest.TestCase): +class TestIsPaddle(Dy2StTestBase): def fake_module(self): return types.ModuleType('paddlenlp') + @test_legacy_and_pir_api def test_func(self): m = self.fake_module() self.assertFalse(is_paddle_func(m)) diff --git a/test/dygraph_to_static/test_variable_trans_func.py b/test/dygraph_to_static/test_variable_trans_func.py index 0ca73fbf9dd75..cbfe76ec2824c 100644 --- a/test/dygraph_to_static/test_variable_trans_func.py +++ b/test/dygraph_to_static/test_variable_trans_func.py @@ -14,14 +14,13 @@ import unittest -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase from paddle.jit.dy2static.utils import ast_to_source_code from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node -@dy2static_unittest -class TestVariableTransFunc(unittest.TestCase): +class TestVariableTransFunc(Dy2StTestBase): def test_create_fill_constant_node(self): node = create_fill_constant_node("a", 1.0) source = "a = paddle.full(shape=[1], dtype='float64', fill_value=1.0, name='a')" diff --git a/test/dygraph_to_static/test_warning.py b/test/dygraph_to_static/test_warning.py index c6e255e36bdff..2a80d11375156 100644 --- a/test/dygraph_to_static/test_warning.py +++ b/test/dygraph_to_static/test_warning.py @@ -15,7 +15,7 @@ import unittest import warnings -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle.static.nn import cond @@ -39,9 +39,8 @@ def false_fn(): return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]] -@dy2static_unittest -class TestReturnNoneInIfelse(unittest.TestCase): - @ast_only_test +class TestReturnNoneInIfelse(Dy2StTestBase): + @test_ast_only def test_dy2static_warning(self): paddle.disable_static() with warnings.catch_warnings(record=True) as w: diff --git a/test/dygraph_to_static/test_word2vec.py b/test/dygraph_to_static/test_word2vec.py index 0f16f5b2a9d23..58a7cd79775fd 100644 --- a/test/dygraph_to_static/test_word2vec.py +++ b/test/dygraph_to_static/test_word2vec.py @@ -17,10 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle from paddle import base @@ -321,9 +318,8 @@ def train(to_static): return np.array(ret) -@dy2static_unittest -class TestWord2Vec(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestWord2Vec(Dy2StTestBase): + @test_legacy_and_pir def test_dygraph_static_same_loss(self): dygraph_loss = train(to_static=False) static_loss = train(to_static=True) diff --git a/test/dygraph_to_static/test_write_python_container.py b/test/dygraph_to_static/test_write_python_container.py index a175b881d86c7..c22a5c7cba0a9 100644 --- a/test/dygraph_to_static/test_write_python_container.py +++ b/test/dygraph_to_static/test_write_python_container.py @@ -14,10 +14,10 @@ import unittest -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - sot_only_test, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_sot_only, ) import paddle @@ -99,8 +99,7 @@ def func_ifelse_write_nest_list_dict(x): return res -@dy2static_unittest -class TestWriteContainer(unittest.TestCase): +class TestWriteContainer(Dy2StTestBase): def setUp(self): self.set_func() self.set_getitem_path() @@ -117,7 +116,7 @@ def get_raw_value(self, container, getitem_path): out = out[path] return out - @sot_only_test + @test_sot_only def test_write_container_sot(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3]) @@ -125,7 +124,7 @@ def test_write_container_sot(self): out_dygraph = self.get_raw_value(self.func(input), self.getitem_path) self.assertEqual(out_static, out_dygraph) - @ast_only_test + @test_ast_only def test_write_container(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3]) diff --git a/test/dygraph_to_static/test_yolov3.py b/test/dygraph_to_static/test_yolov3.py index 12830ca7bce55..5a848c3a92741 100644 --- a/test/dygraph_to_static/test_yolov3.py +++ b/test/dygraph_to_static/test_yolov3.py @@ -17,10 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from yolov3 import YOLOv3, cfg import paddle @@ -168,9 +165,8 @@ def train(to_static): return np.array(ret) -@dy2static_unittest -class TestYolov3(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestYolov3(Dy2StTestBase): + @test_legacy_and_pir def test_dygraph_static_same_loss(self): dygraph_loss = train(to_static=False) static_loss = train(to_static=True) diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py index 7fd263e72b62b..6ce8519926a49 100644 --- a/test/indexing/test_getitem.py +++ b/test/indexing/test_getitem.py @@ -15,170 +15,740 @@ import unittest import numpy as np +from op_test import convert_float_to_uint16, convert_uint16_to_float import paddle +from paddle.base import core from paddle.base.variable_index import _getitem_static class TestGetitemInDygraph(unittest.TestCase): def setUp(self): paddle.disable_static() + self.ndtype = np.float64 + self.dtype = 'float64' def test_combined_index_1(self): # int tensor + slice (without decreasing axes) - np_data = np.random.randn(3, 4, 5, 6) + np_data = np.random.randn(3, 4, 5, 6).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[[0, 1], :, [1, 2]] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[[0, 1], :, [1, 2]] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_2(self): # int tensor + slice (with decreasing axes) - np_data = np.random.randn(3, 4, 5, 6) - x = paddle.to_tensor(np_data) + np_data = np.random.randn(3, 4, 5, 6).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) np_res = np_data[:, 1, [1, 2], 0] y = x[:, 1, [1, 2], 0] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_3(self): # multiple int tensors, with one int tensor at first axis - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[[1, 0], :, [1, 4], 1:5:2, 4] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_4(self): # multiple not adjacent int tensors, with no int tensor at first axis - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[:, [1, 0], 0:4:2, [2, 3], 4] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_5(self): # multiple adjacent int tensors, with no int tensor at first axis - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[::2, [1, 0], [2, 3], 0:4:2] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_6(self): # multiple adjacent and not adjacent int tensors, with no int tensor at first axis - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_7(self): # multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_8(self): # multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[ [[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]] ] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_9(self): # multiple int tensors, with broadcast. - np_data = np.random.randn(3, 4, 5, 6, 7) + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_10(self): # only one bool tensor with basic-index - np_data = np.random.randn(3, 4, 5, 6) + np_data = np.random.randn(3, 4, 5, 6).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[:, [True, False, True, False], 4] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[:, [True, False, True, False], 4] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_combined_index_11(self): # only one bool tensor with all False - np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[:, [False, False, False, False], 4] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[:, [False, False, False, False], 4] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_index_has_range(self): - np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[:, range(3), 4] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[:, range(3), 4] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_indexing_with_bool_list1(self): # test bool-list indexing when axes num less than x.rank - np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[[True, False, True], [False, False, False, True]] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[[True, False, True], [False, False, False, True]] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_indexing_with_bool_list2(self): # test bool-list indexing when axes num less than x.rank - np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[ [True, False, True], [False, False, True, False], [True, False, False, True, False], ] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[ [True, False, True], [False, False, True, False], [True, False, False, True, False], ] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + np.testing.assert_allclose(y.numpy(), np_res) def test_indexing_is_multi_dim_list(self): # indexing is multi-dim int list, should be treat as one index, like numpy>=1.23 - np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)).astype(self.ndtype) + ) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + np_res = np_data[np.array([[2, 3, 4], [1, 2, 5]])] - x = paddle.to_tensor(np_data) + x = paddle.to_tensor(np_data, dtype=self.dtype) y = x[[[2, 3, 4], [1, 2, 5]]] y_index_tensor = x[paddle.to_tensor([[2, 3, 4], [1, 2, 5]])] + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + y_index_tensor = paddle.cast(y_index_tensor, dtype='float32') np.testing.assert_allclose(y.numpy(), np_res) np.testing.assert_allclose(y.numpy(), y_index_tensor.numpy()) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestFP16GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float16 + self.dtype = 'float16' + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestBF16GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'bfloat16' + + +class TestFP32GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'float32' + + +class TestUINT8GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.uint8 + self.dtype = 'uint8' + + +class TestINT8GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int8 + self.dtype = 'int8' + + +class TestINT16GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int16 + self.dtype = 'int16' + + +class TestINT32GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int32 + self.dtype = 'int32' + + +class TestINT64GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int64 + self.dtype = 'int64' + + +class TestBOOLGetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.bool8 + self.dtype = 'bool' + + +class TestComplex64GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'complex64' + + +class TestComplex128GetitemInDygraph(TestGetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float64 + self.dtype = 'complex128' + + +class TestGetitemGrad(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float64 + self.dtype = 'float64' + + def test_combined_index_1(self): + np_data = np.random.randn(3, 4, 5, 6).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[[0, 1], :, [1, 2]] = 1 + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[[0, 1], :, [1, 2]] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_2(self): + np_data = np.random.randn(3, 4, 5, 6).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[:, 1, [1, 2], 0] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + np_res = np_data[:, 1, [1, 2], 0] + y = x[:, 1, [1, 2], 0] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_3(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[[1, 0], :, [1, 4], 1:5:2, 4] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[[1, 0], :, [1, 4], 1:5:2, 4] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_4(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[:, [1, 0], 0:4:2, [2, 3], 4] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[:, [1, 0], 0:4:2, [2, 3], 4] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_5(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[::2, [1, 0], [2, 3], 0:4:2] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[::2, [1, 0], [2, 3], 0:4:2] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_6(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_7(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_8(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_9(self): + np_data = np.random.randn(3, 4, 5, 6, 7).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_combined_index_10(self): + np_data = np.random.randn(3, 4, 5, 6).astype(self.ndtype) + res = np.zeros(np_data.shape) + res[:, [True, False, True, False], 4] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[:, [True, False, True, False], 4] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_index_has_range(self): + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + res = np.zeros(np_data.shape) + res[:, range(3), 4] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[:, range(3), 4] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_indexing_with_bool_list1(self): + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + res = np.zeros(np_data.shape) + res[[True, False, True], [False, False, False, True]] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[[True, False, True], [False, False, False, True]] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + def test_indexing_with_bool_list2(self): + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + res = np.zeros(np_data.shape) + res[ + [True, False, True], + [False, False, True, False], + [True, False, False, True, False], + ] = 1 + + x = paddle.to_tensor(np_data, dtype=self.dtype, stop_gradient=False) + if self.dtype == 'bool': + x = x.astype('int') + y = x[ + [True, False, True], + [False, False, True, False], + [True, False, False, True, False], + ] + z = y + 1 + z.backward() + if self.dtype == 'bfloat16': + np.testing.assert_allclose(x.grad.cast('float32').numpy(), res) + elif self.dtype == 'bool': + self.assertIsNone(x.grad) + else: + np.testing.assert_allclose(x.grad.numpy(), res) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestFP16GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float16 + self.dtype = 'float16' + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestBF16GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'bfloat16' + + +class TestFP32GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'float32' + + +class TestBOOLGetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.bool8 + self.dtype = 'bool' + + +class TestINT8GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int8 + self.dtype = 'int8' + + +class TestINT16GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int16 + self.dtype = 'int16' + + +class TestINT32GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int32 + self.dtype = 'int32' + + +class TestINT64GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int64 + self.dtype = 'int64' + + +class TestComplex64GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'complex64' + + +class TestComplex128GetitemGradInDygraph(TestGetitemGrad): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float64 + self.dtype = 'complex128' + + class TestGetitemInStatic(unittest.TestCase): def setUp(self): paddle.enable_static() diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index 3c3f8deb3955e..6620922603e7a 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -15,118 +15,188 @@ import unittest import numpy as np +from op_test import convert_float_to_uint16, convert_uint16_to_float import paddle +from paddle.base import core from paddle.base.variable_index import _setitem_static class TestSetitemInDygraph(unittest.TestCase): def setUp(self): paddle.disable_static() + self.ndtype = np.float64 + self.dtype = 'float64' def test_combined_index_1(self): - np_data = np.zeros((3, 4, 5, 6), dtype='float32') - x = paddle.to_tensor(np_data) + np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) np_data[[0, 1], :, [1, 2]] = 10.0 x[[0, 1], :, [1, 2]] = 10.0 + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_combined_index_2(self): - np_data = np.ones((3, 4, 5, 6), dtype='float32') - x = paddle.to_tensor(np_data) + np_data = np.ones((3, 4, 5, 6), dtype='float32').astype(self.ndtype) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) np_data[:, 1, [1, 2], 0] = 10.0 x[:, 1, [1, 2], 0] = 10.0 + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_combined_index_3(self): - np_data = np.ones((3, 4, 5, 6), dtype='int32') - x = paddle.to_tensor(np_data) + np_data = np.ones((3, 4, 5, 6), dtype='int32').astype(self.ndtype) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) np_data[:, [True, False, True, False], [1, 4]] = 10 x[:, [True, False, True, False], [1, 4]] = 10 + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_index_has_range(self): - np_data = np.ones((3, 4, 5, 6), dtype='int32') - x = paddle.to_tensor(np_data) + np_data = np.ones((3, 4, 5, 6), dtype='int32').astype(self.ndtype) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) np_data[:, range(3), [1, 2, 4]] = 10 x[:, range(3), [1, 2, 4]] = 10 + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_src_value_with_different_dtype_1(self): # basic-indexing, with set_value op - np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data = np.ones((3, 4, 5, 6), dtype='int32').astype(self.ndtype) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data np_value = np.zeros((6,), dtype='float32') - x = paddle.to_tensor(np_data) - v = paddle.to_tensor(np_value) + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(np_value, dtype=self.dtype) np_data[0, 2, 3] = np_value x[0, 2, 3] = v + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_src_value_with_different_dtype_2(self): # combined-indexing, with index_put op - np_data = np.ones((3, 4, 5, 6), dtype='float32') + np_data = np.ones((3, 4, 5, 6), dtype='float32').astype(self.ndtype) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data np_value = np.zeros((6,), dtype='int64') - x = paddle.to_tensor(np_data) - v = paddle.to_tensor(np_value) + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(np_value, dtype=self.dtype) np_data[:, [1, 0], 3] = np_value x[:, [1, 0], 3] = v + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_indexing_with_bool_list1(self): # test bool-list indexing when axes num less than x.rank - np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) - np_data[[True, False, True], [False, False, False, True]] = 7 + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) - x = paddle.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_data[[True, False, True], [False, False, False, True]] = 7 x[[True, False, True], [False, False, False, True]] = 7 - np.testing.assert_allclose(x.numpy(), np_data) + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + + np.testing.assert_allclose(x.numpy(), np_data, verbose=True) def test_indexing_with_bool_list2(self): # test bool-list indexing when axes num less than x.rank - np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) + np_data[ [True, False, True], [False, False, True, False], [True, False, False, True, False], ] = 8 - x = paddle.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) x[ [True, False, True], [False, False, True, False], [True, False, False, True, False], ] = 8 + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_indexing_is_multi_dim_list(self): # indexing is multi-dim int list, should be treat as one index, like numpy>=1.23 - np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)) + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)).astype(self.ndtype) + ) + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_data = np_data + 1j * np_data + x = paddle.to_tensor(np_data, dtype=self.dtype) np_data[np.array([[2, 3, 4], [1, 2, 5]])] = 100 - - x = paddle.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)) x[[[2, 3, 4], [1, 2, 5]]] = 100 + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') np.testing.assert_allclose(x.numpy(), np_data) def test_inplace_with_stride(self): - v = paddle.randn((3, 1)) + np_v = np.random.randn(3, 1).astype(self.ndtype) + if self.dtype == 'bfloat16': + np_v = convert_uint16_to_float(convert_float_to_uint16(np_v)) + if self.dtype == 'complex64' or self.dtype == 'complex32': + np_v = np_v + 1j * np_v + v = paddle.to_tensor(np_v, dtype=self.dtype) v.stop_gradient = False - vv = v * 1 + vv = v zero = paddle.randn((3, 3, 5)) zero.stop_gradient = False @@ -138,7 +208,103 @@ def test_inplace_with_stride(self): loss.backward() expected_v_grad = np.ones((3, 1)) * 10.0 - np.testing.assert_equal(v.grad.numpy(), expected_v_grad) + if self.dtype == 'bfloat16': + np.testing.assert_allclose( + v.grad.cast('float32').numpy(), expected_v_grad + ) + elif self.dtype == 'bool': + np.testing.assert_equal( + v.grad.numpy(), expected_v_grad.astype('bool') + ) + else: + np.testing.assert_equal(v.grad.numpy(), expected_v_grad) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestFP16SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float16 + self.dtype = 'float16' + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestBF16SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'bfloat16' + + +class TestFP32SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'float32' + + +class TestUINT8SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.uint8 + self.dtype = 'uint8' + + +class TestINT8SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int8 + self.dtype = 'int8' + + +class TestINT16SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int16 + self.dtype = 'int16' + + +class TestINT32SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int32 + self.dtype = 'int32' + + +class TestINT64SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.int64 + self.dtype = 'int64' + + +class TestBOOLSetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.bool8 + self.dtype = 'bool' + + +class TestComplex64SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float32 + self.dtype = 'complex64' + + +class TestComplex128SetitemInDygraph(TestSetitemInDygraph): + def setUp(self): + paddle.disable_static() + self.ndtype = np.float64 + self.dtype = 'complex128' class TestSetitemInStatic(unittest.TestCase): diff --git a/test/ir/CMakeLists.txt b/test/ir/CMakeLists.txt index fab15cc488caf..232ef033e2b35 100644 --- a/test/ir/CMakeLists.txt +++ b/test/ir/CMakeLists.txt @@ -16,7 +16,7 @@ foreach(target ${TEST_IR_PASSES}) endforeach() add_subdirectory(inference) -add_subdirectory(new_ir) +add_subdirectory(pir) set_tests_properties(test_fuse_resnet_unit PROPERTIES TIMEOUT 120) set_tests_properties(test_convert_to_mixed_precision PROPERTIES TIMEOUT 300) diff --git a/test/ir/inference/inference_pass_test.py b/test/ir/inference/inference_pass_test.py index 48b1603728a1b..f5bfc9c767adb 100644 --- a/test/ir/inference/inference_pass_test.py +++ b/test/ir/inference/inference_pass_test.py @@ -225,9 +225,7 @@ def check_output_with_option( # Check whether the results calculated on CPU and on GPU are the same. self.assertTrue( len(paddle_outs) == len(inference_outs), - "The number of outputs is different between inference and training forward at {}".format( - device - ), + f"The number of outputs is different between inference and training forward at {device}", ) for out, inference_out in zip(paddle_outs, inference_outs): @@ -241,9 +239,7 @@ def check_output_with_option( inference_out, rtol=1e-05, atol=atol, - err_msg='Output has diff between inference and training forward at {} '.format( - device - ), + err_msg=f'Output has diff between inference and training forward at {device} ', ) # Check whether the trt results and the GPU results are the same. diff --git a/test/ir/inference/quant_dequant_test.py b/test/ir/inference/quant_dequant_test.py index 4f1e2335f5d49..8cabff3d2d5b3 100644 --- a/test/ir/inference/quant_dequant_test.py +++ b/test/ir/inference/quant_dequant_test.py @@ -330,9 +330,7 @@ def check_output_with_option( # Check whether the results calculated on CPU and on GPU are the same. self.assertTrue( len(paddle_outs) == len(inference_outs), - "The number of outputs is different between inference and training forward at {}".format( - device - ), + f"The number of outputs is different between inference and training forward at {device}", ) for out, inference_out in zip(paddle_outs, inference_outs): @@ -347,9 +345,7 @@ def check_output_with_option( inference_out, rtol=1e-05, atol=atol, - err_msg='Output has diff between inference and training forward at {} '.format( - device - ), + err_msg=f'Output has diff between inference and training forward at {device} ', ) # Check whether the trt results and the GPU results are the same. diff --git a/test/ir/inference/test_inference_predictor_run.py b/test/ir/inference/test_inference_predictor_run.py index c6a8c5db9f3c1..1c552bc82b77e 100644 --- a/test/ir/inference/test_inference_predictor_run.py +++ b/test/ir/inference/test_inference_predictor_run.py @@ -63,6 +63,7 @@ def tearDown(self): self.temp_dir.cleanup() def init_predictor(self): + paddle.set_flags({'FLAGS_enable_pir_in_executor': True}) config = Config( os.path.join( self.temp_dir.name, @@ -74,7 +75,9 @@ def init_predictor(self): ), ) config.enable_use_gpu(256, 0) - config.enable_memory_optim() + config.switch_ir_optim(False) + # config.enable_memory_optim() + config.enable_new_executor() predictor = create_predictor(config) return predictor @@ -89,9 +92,7 @@ def get_inputs(self): return [input0_tensor, input1_tensor] - def get_disorder_output(self): - predictor = self.init_predictor() - + def get_disorder_output(self, predictor): [input0_tensor, input1_tensor] = self.get_inputs() input_names = predictor.get_input_names() @@ -104,9 +105,7 @@ def get_disorder_output(self): return outputs[0] - def get_inorder_output(self): - predictor = self.init_predictor() - + def get_inorder_output(self, predictor): [input0_tensor, input1_tensor] = self.get_inputs() # inorder @@ -116,8 +115,9 @@ def get_inorder_output(self): return outputs[0] def test_output(self): - inorder_output = self.get_inorder_output() - disorder_output = self.get_disorder_output() + predictor = self.init_predictor() + inorder_output = self.get_inorder_output(predictor) + disorder_output = self.get_disorder_output(predictor) np.testing.assert_allclose( inorder_output.numpy().flatten(), disorder_output.numpy().flatten() diff --git a/test/ir/inference/test_trt_support_nhwc_pass.py b/test/ir/inference/test_trt_support_nhwc_pass.py index 0648202aba30c..bd585d1b5b850 100644 --- a/test/ir/inference/test_trt_support_nhwc_pass.py +++ b/test/ir/inference/test_trt_support_nhwc_pass.py @@ -93,6 +93,10 @@ def setUp(self): self.temp_dir.name, 'inference_pass', 'nhwc_converter', '' ) self.model_prefix = self.path + 'infer_model' + self.set_args() + + def set_args(self): + self.precision_mode = inference.PrecisionType.Float32 def create_model(self): image = static.data( @@ -115,7 +119,7 @@ def create_predictor(self): workspace_size=1 << 30, max_batch_size=1, min_subgraph_size=3, - precision_mode=inference.PrecisionType.Float32, + precision_mode=self.precision_mode, use_static=False, use_calib_mode=False, ) @@ -147,5 +151,44 @@ def tearDown(self): shutil.rmtree(self.path) +class TRTNHWCConvertAMPTest(TRTNHWCConvertTest): + def set_args(self): + self.precision_mode = inference.PrecisionType.Half + + def create_model(self): + train_prog = paddle.static.Program() + with paddle.static.program_guard(train_prog): + with paddle.static.amp.fp16_guard(): + image = paddle.static.data( + name='image', shape=[None, 224, 224, 4], dtype='float32' + ) + label = paddle.static.data( + name='label', shape=[None, 1], dtype='int64' + ) + predict = SimpleNet()(image) + cost = paddle.nn.functional.loss.cross_entropy( + input=predict, label=label + ) + avg_cost = paddle.mean(x=cost) + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + weight_decay=paddle.regularizer.L2Decay(4e-5), + ) + optimizer = paddle.static.amp.decorate( + optimizer, + use_dynamic_loss_scaling=False, + use_pure_fp16=False, + ) + optimizer.minimize(avg_cost) + val_prog = train_prog.clone(for_test=True) + + exe = paddle.static.Executor(self.place) + exe.run(paddle.static.default_startup_program()) + paddle.static.save_inference_model( + self.model_prefix, [image], [predict], exe, program=val_prog + ) + + if __name__ == '__main__': unittest.main() diff --git a/test/ir/new_ir/fused_pass/test_fused_gemm_epilogue_pass.py b/test/ir/new_ir/fused_pass/test_fused_gemm_epilogue_pass.py new file mode 100644 index 0000000000000..a3aa6a4458b96 --- /dev/null +++ b/test/ir/new_ir/fused_pass/test_fused_gemm_epilogue_pass.py @@ -0,0 +1,120 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.autograd.ir_backward import grad as ir_grad +from paddle.base import core + +np.random.seed(2013) + +import os +import re + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11060, + "core is not complied with CUDA or nvcc version is less than11.6", +) +class TestFusedgemm_epilogueAdd(unittest.TestCase): + def test_fused_gemm_epilogue_add(self): + with paddle.pir_utils.IrGuard(): + x_np = np.random.normal(3, 2.5, size=(1024, 1024)).astype( + np.float32 + ) + y_np = x_np + z_np = np.random.normal(3, 2.5, size=(1024)).astype(np.float32) + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x_ = paddle.static.data( + name="x", shape=[1024, 1024], dtype="float32" + ) + y_ = paddle.static.data( + name="y", shape=[1024, 1024], dtype="float32" + ) + z_ = paddle.static.data(name="z", shape=[1024], dtype="float32") + x_.stop_gradient = False + y_.stop_gradient = False + z_.stop_gradient = False + x = paddle.assign(x_) + y = paddle.assign(y_) + z = paddle.assign(z_) + res1 = paddle.matmul(x=x, y=y) + res2 = paddle.add(res1, z) + res3 = paddle.assign(res2) + + res4, res5, res6 = ir_grad(res3, [x, y, z]) + res4_ = paddle.assign(res4) + res5_ = paddle.assign(res5) + res6_ = paddle.assign(res6) + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue( + 'pd_op.matmul' in op_names and 'pd_op.add' in op_names + ) + self.assertTrue( + 'pd_op.add_grad' in op_names + and 'pd_op.matmul_grad' in op_names + ) + + with paddle.static.scope_guard(paddle.static.Scope()): + exe = paddle.base.Executor(paddle.base.CUDAPlace(0)) + fetches0 = exe.run( + main_program, + feed={"x": x_np, "y": y_np, "z": z_np}, + fetch_list=[res3, res4_, res5_, res6_], + ) + # main_program = main_program.clone() + + pm = paddle.pir.PassManager() + pm.add_pass( + 'fused_gemm_epilogue_pass' + ) # apply pass to elimitate dead code + pm.run(main_program) + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue( + 'pd_op.fused_gemm_epilogue' in op_names + and 'pd_op.fused_gemm_epilogue_grad' in op_names + ) + + with paddle.static.scope_guard(paddle.static.Scope()): + exe = paddle.base.Executor(paddle.base.CUDAPlace(0)) + fetches1 = exe.run( + main_program, + feed={"x": x_np, "y": y_np, "z": z_np}, + fetch_list=[res3, res4_, res5_, res6_], + ) + + np.array_equal(fetches0[0], fetches1[0]) + np.array_equal(fetches0[1], fetches1[1]) + np.array_equal(fetches0[2], fetches1[2]) + np.array_equal(fetches0[3], fetches1[3]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/new_ir/CMakeLists.txt b/test/ir/pir/CMakeLists.txt similarity index 80% rename from test/ir/new_ir/CMakeLists.txt rename to test/ir/pir/CMakeLists.txt index cad2633fb1aa4..ec950e711d894 100644 --- a/test/ir/new_ir/CMakeLists.txt +++ b/test/ir/pir/CMakeLists.txt @@ -6,12 +6,12 @@ string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") set(TEST_IR_SYSTEM_CASES test_build_model test_pd_inplace_pass test_symbol_overload - test_new_ir_to_static test_stop_gradient test_override_operator) + test_pir_to_static test_stop_gradient test_override_operator) list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES}) foreach(target ${TEST_INTERP_CASES}) py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 - FLAGS_enable_new_ir_in_executor=true) + FLAGS_enable_pir_in_executor=true) endforeach() foreach(target ${TEST_IR_SYSTEM_CASES}) @@ -19,3 +19,5 @@ foreach(target ${TEST_IR_SYSTEM_CASES}) endforeach() set_tests_properties(test_pd_inplace_pass PROPERTIES TIMEOUT 60) + +add_subdirectory(fused_pass) diff --git a/test/ir/pir/cinn/CMakeLists.txt b/test/ir/pir/cinn/CMakeLists.txt new file mode 100644 index 0000000000000..e25676c5e8756 --- /dev/null +++ b/test/ir/pir/cinn/CMakeLists.txt @@ -0,0 +1,21 @@ +if(WITH_GPU) + file( + GLOB CINN_PIR_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") + + foreach(cinn_pir_test_name ${CINN_PIR_TEST}) + string(REGEX REPLACE ".py" "" cinn_pir_test_name ${cinn_pir_test_name}) + add_test( + NAME ${cinn_pir_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH} + FLAGS_enable_pir_api=1 ${PYTHON_EXECUTABLE} + ${CMAKE_CURRENT_SOURCE_DIR}/${cinn_pir_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + set_tests_properties(${cinn_pir_test_name} PROPERTIES LABELS + "RUN_TYPE=CINN") + endforeach() + +endif() diff --git a/test/ir/pir/cinn/test_cinn_sub_graph.py b/test/ir/pir/cinn/test_cinn_sub_graph.py new file mode 100644 index 0000000000000..612942bc1853d --- /dev/null +++ b/test/ir/pir/cinn/test_cinn_sub_graph.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.jit import not_to_static + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) + + +# TODO(Aurelius84): support in next PR +@not_to_static +def softmax(x, axis): + """define composite rule of op softmax""" + is_amp = False + from paddle.base.data_feeder import convert_dtype + + # Softmax need fp32 compute since it has sum op in + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = paddle.cast(x, "float32") + if not x.shape: + # do not return 1, to ensure gradients + res = paddle.exp(x - x) + if is_amp: + res = paddle.cast(res, "float16") + return res + max_temp = paddle.max(x, axis, keepdim=True) + max_temp.stop_gradient = True + molecular = paddle.exp(x - max_temp) + denominator = paddle.sum(molecular, axis=axis, keepdim=True) + res = paddle.divide(molecular, denominator) + if is_amp: + res = paddle.cast(res, dtype) + return res + + +def exp_sub(x): + y = paddle.exp(x) + z = y - x + return z + + +class CINNSubGraphNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fn = exp_sub + # TODO(Aurelius84): support in next PR + # self.fn = softmax + + def forward(self, x, axis=1): + # out = self.fn(x, axis=axis) + out = self.fn(x) + return out + + +class TestCinnSubGraphBase(unittest.TestCase): + """ + Test Pir API + @to_static + CINN. + """ + + def setUp(self): + paddle.seed(2022) + self.shape = [64, 128] + self.axis = -1 + self.prepare_data() + + def prepare_data(self): + self.x = paddle.randn(self.shape, dtype="float32") + self.x.stop_gradient = False + + def train(self, use_cinn): + paddle.seed(2022) + net = CINNSubGraphNet() + net = apply_to_static(net, use_cinn) + net.eval() + out = net(self.x, self.axis) + return out + + def test_forward(self): + cinn_out = self.train(use_cinn=True) + dy_out = self.train(use_cinn=False) + np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/ir/pir/fused_pass/CMakeLists.txt b/test/ir/pir/fused_pass/CMakeLists.txt new file mode 100644 index 0000000000000..8876db2d4b794 --- /dev/null +++ b/test/ir/pir/fused_pass/CMakeLists.txt @@ -0,0 +1,9 @@ +file( + GLOB TEST_INTERP_CASES + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") + +foreach(target ${TEST_INTERP_CASES}) + py_test_modules(${target} MODULES ${target}) +endforeach() diff --git a/test/ir/pir/fused_pass/test_fused_dropout_add_pass.py b/test/ir/pir/fused_pass/test_fused_dropout_add_pass.py new file mode 100644 index 0000000000000..af413d0e2096e --- /dev/null +++ b/test/ir/pir/fused_pass/test_fused_dropout_add_pass.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.autograd.ir_backward import grad +from paddle.base import core + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not complied with CUDA", +) +class TestFusedDropoutAdd(unittest.TestCase): + def _test_fused_dropout_add(self): + with paddle.pir_utils.IrGuard(): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data(name="x", shape=[3, 2], dtype="float32") + y = paddle.static.data(name="y", shape=[3, 2], dtype="float32") + res1 = paddle.nn.functional.dropout(x=x, p=0.5, training=True) + res2 = paddle.add(res1, y) + res3 = paddle.sum(res2) + + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue('pd_op.dropout' in op_names) + self.assertTrue('pd_op.add' in op_names) + pm = paddle.pir.PassManager() + pm.add_pass( + 'fused_dropout_add_pass' + ) # apply pass to elimitate dead code + pm.run(main_program) + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue('pd_op.fused_dropout_add' in op_names) + self.assertTrue('pd_op.dropout' not in op_names) + + x_np = np.ones([3, 2]).astype("float32") + y_np = x_np + + exe = paddle.base.Executor(paddle.base.CUDAPlace(0)) + fetches = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[res3], + ) + + def test_fused_dropout_add_grad(self): + with paddle.pir_utils.IrGuard(): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data(name="x", shape=[3, 2], dtype="float32") + x.stop_gradient = False + y = paddle.static.data(name="y", shape=[3, 2], dtype="float32") + y.stop_gradient = False + dout = paddle.static.data( + name="dout", shape=[3, 2], dtype="float32" + ) + res0 = paddle.assign(x) + res1 = paddle.nn.functional.dropout( + x=res0, p=0.5, training=True + ) + res2 = paddle.add(res1, y) + res3 = paddle.sum(res2) + + # res4 = paddle.incubate.nn.functional.fused_dropout_add( x, y, p=0.5, training=True) + # res5 = paddle.sum(res4) + dx = grad(res3, x) + + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue( + 'pd_op.dropout' in op_names and 'pd_op.add' in op_names + ) + self.assertTrue( + 'pd_op.add_grad' in op_names + and 'pd_op.dropout_grad' in op_names + ) + pm = paddle.pir.PassManager() + pm.add_pass( + 'fused_dropout_add_pass' + ) # apply pass to elimitate dead code + pm.run(main_program) + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue( + 'pd_op.fused_dropout_add' in op_names + and 'pd_op.fused_dropout_add_grad' in op_names + ) + self.assertTrue( + 'pd_op.dropout' not in op_names + and 'pd_op.dropout_grad' not in op_names + ) + + x_np = np.ones([3, 2]).astype("float32") + y_np = x_np + + exe = paddle.base.Executor(paddle.base.CUDAPlace(0)) + fetches = exe.run( + main_program, + feed={"x": x_np, "y": y_np, "dout": y_np}, + fetch_list=[dx], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/new_ir/test_build_model.py b/test/ir/pir/test_build_model.py similarity index 100% rename from test/ir/new_ir/test_build_model.py rename to test/ir/pir/test_build_model.py diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/pir/test_build_op.py similarity index 82% rename from test/ir/new_ir/test_build_op.py rename to test/ir/pir/test_build_op.py index 68a6ce35c2bb7..cd0ae03b33958 100644 --- a/test/ir/new_ir/test_build_op.py +++ b/test/ir/pir/test_build_op.py @@ -33,16 +33,16 @@ def get_ir_program(): y_s = paddle.matmul(x_s, x_s) y_s = paddle.add(x_s, y_s) y_s = paddle.tanh(y_s) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestBuildOp(unittest.TestCase): def test_build_mean_op(self): - newir_program = get_ir_program() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.mean(tanh_out) self.assertEqual(out.get_defining_op().name(), "pd_op.mean") @@ -58,10 +58,10 @@ def test_build_mean_op(self): class TestBuildOp2(unittest.TestCase): def test_build_add_n_op(self): - newir_program = get_ir_program() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out1 = paddle.mean(tanh_out) out2 = paddle.mean(tanh_out) @@ -79,14 +79,14 @@ def test_build_add_n_op(self): class TestBuildOp3(unittest.TestCase): def test_insertion_point(self): - newir_program = get_ir_program() + pir_program = get_ir_program() with paddle.pir_utils.IrGuard(): - add_op = newir_program.global_block().ops[-2] - tanh_op = newir_program.global_block().ops[-1] + add_op = pir_program.global_block().ops[-2] + tanh_op = pir_program.global_block().ops[-1] add_out = add_op.result(0) tanh_operand = tanh_op.operands()[0] - with paddle.pir.core.program_guard(newir_program): + with paddle.pir.core.program_guard(pir_program): pir.set_insertion_point(tanh_op) full_out = paddle.tensor.fill_constant( shape=[4, 4], dtype="float", value=2 @@ -96,7 +96,7 @@ def test_insertion_point(self): out = paddle.mean(sum_out) tanh_operand.set_source(out) - print(newir_program) + print(pir_program) self.assertEqual( tanh_operand.source().get_defining_op().name(), "pd_op.mean" ) @@ -104,10 +104,10 @@ def test_insertion_point(self): class TestBuildOp4(unittest.TestCase): def test_build_concat_op(self): - newir_program = get_ir_program() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.concat([tanh_out, tanh_out], 0) self.assertEqual(out.get_defining_op().name(), "pd_op.concat") @@ -123,10 +123,10 @@ def test_build_concat_op(self): class TestBuildOp5(unittest.TestCase): def test_build_split_op(self): - newir_program = get_ir_program() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.split(tanh_out, [2, 2], 0) self.assertEqual(out[0].get_defining_op().name(), "builtin.split") diff --git a/test/ir/new_ir/test_data_op.py b/test/ir/pir/test_data_op.py similarity index 95% rename from test/ir/new_ir/test_data_op.py rename to test/ir/pir/test_data_op.py index a7659e32486c4..900686e1201b4 100644 --- a/test/ir/new_ir/test_data_op.py +++ b/test/ir/pir/test_data_op.py @@ -36,8 +36,8 @@ def data(): return out -class TestNewIr(unittest.TestCase): - def test_with_new_ir(self): +class TestPir(unittest.TestCase): + def test_with_pir(self): paddle.enable_static() place = paddle.CPUPlace() exe = paddle.static.Executor(place) diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/pir/test_ir_backward.py similarity index 74% rename from test/ir/new_ir/test_ir_backward.py rename to test/ir/pir/test_ir_backward.py index e6994d1b0dbef..7f704de11d25c 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/pir/test_ir_backward.py @@ -32,8 +32,8 @@ def get_ir_program_0(): x_s = paddle.static.data('x', [4, 4], x.dtype) x_s.stop_gradient = False k_s = paddle.tanh(x_s) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TesBackward_1(unittest.TestCase): @@ -41,11 +41,11 @@ def tearDown(self) -> None: paddle.framework.set_flags({"FLAGS_enable_pir_api": False}) def test_grad(self): - newir_program = get_ir_program_0() - input = newir_program.global_block().ops[-1].operand(0).source() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program_0() + input = pir_program.global_block().ops[-1].operand(0).source() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.mean(tanh_out) out2 = paddle.mean(tanh_out) @@ -66,16 +66,16 @@ def test_grad(self): def test_full(self): # test create output_grad in backward use full op - newir_program = get_ir_program_0() - input = newir_program.global_block().ops[-1].operand(0).source() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program_0() + input = pir_program.global_block().ops[-1].operand(0).source() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.mean(tanh_out) input_grad = grad(out, input) self.assertEqual( - newir_program.global_block().ops[-3].name(), "pd_op.full_like" + pir_program.global_block().ops[-3].name(), "pd_op.full_like" ) self.assertEqual( input_grad[0].get_defining_op().name(), "pd_op.tanh_grad" @@ -92,25 +92,25 @@ def test_full(self): def test_no_grad_set(self): # test create output_grad in backward use full op - newir_program = get_ir_program_0() - input = newir_program.global_block().ops[-1].operand(0).source() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program_0() + input = pir_program.global_block().ops[-1].operand(0).source() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.mean(tanh_out) input_grad = grad(out, input, no_grad_vars=[input]) self.assertEqual( - newir_program.global_block().ops[-1].name(), "pd_op.full" + pir_program.global_block().ops[-1].name(), "pd_op.full" ) def test_split(self): # test create output_grad in backward use full op - newir_program = get_ir_program_0() - input = newir_program.global_block().ops[-1].operand(0).source() - tanh_out = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program_0() + input = pir_program.global_block().ops[-1].operand(0).source() + tanh_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.split(tanh_out, [2, 2], 0) input_grad = grad(out, input) @@ -129,7 +129,7 @@ def test_split(self): "pd_op.concat", "pd_op.tanh_grad", ] - for i, op in enumerate(newir_program.global_block().ops): + for i, op in enumerate(pir_program.global_block().ops): self.assertEqual(op.name(), ops_name[i]) @@ -149,8 +149,8 @@ def get_ir_program_1(): k_s = paddle.tanh(x_s) z_x = paddle.tanh(x_s) out = paddle.add(z_x, k_s) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TesBackward_2(unittest.TestCase): @@ -158,33 +158,33 @@ def tearDown(self) -> None: paddle.framework.set_flags({"FLAGS_enable_pir_api": False}) def test_add_n(self): - newir_program = get_ir_program_1() - input_x = newir_program.global_block().ops[-3].operand(0).source() + pir_program = get_ir_program_1() + input_x = pir_program.global_block().ops[-3].operand(0).source() - add_out = newir_program.global_block().ops[-1].result(0) + add_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.mean(add_out) input_grad = grad(out, input_x) self.assertEqual( - newir_program.global_block().ops[-1].name(), "pd_op.add_n" + pir_program.global_block().ops[-1].name(), "pd_op.add_n" ) self.assertEqual( - newir_program.global_block().ops[-1].name(), "pd_op.add_n" + pir_program.global_block().ops[-1].name(), "pd_op.add_n" ) self.assertEqual( - newir_program.global_block().ops[-2].name(), "builtin.combine" + pir_program.global_block().ops[-2].name(), "builtin.combine" ) def test_concat(self): - newir_program = get_ir_program_1() - input_x = newir_program.global_block().ops[-3].operand(0).source() + pir_program = get_ir_program_1() + input_x = pir_program.global_block().ops[-3].operand(0).source() - add_out = newir_program.global_block().ops[-1].result(0) + add_out = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): out = paddle.concat([add_out, add_out]) input_grad = grad(out, input_x) @@ -210,7 +210,7 @@ def test_concat(self): "builtin.combine", "pd_op.add_n", ] - for i, op in enumerate(newir_program.global_block().ops): + for i, op in enumerate(pir_program.global_block().ops): self.assertEqual(op.name(), ops_name[i]) @@ -225,8 +225,8 @@ def get_ir_program_2(): x_s = paddle.static.data('x', [4, 4], x.dtype) x_s.stop_gradient = False k_s = paddle.sum(x_s, axis=(-1,), keepdim=False) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestBackward_3(unittest.TestCase): @@ -234,11 +234,11 @@ def tearDown(self) -> None: paddle.framework.set_flags({"FLAGS_enable_pir_api": False}) def test_basic_network(self): - newir_program = get_ir_program_2() - x = newir_program.global_block().ops[-1].operand(0).source() - sum_x = newir_program.global_block().ops[-1].result(0) + pir_program = get_ir_program_2() + x = pir_program.global_block().ops[-1].operand(0).source() + sum_x = pir_program.global_block().ops[-1].result(0) with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( - newir_program + pir_program ): norm = paddle.tensor.fill_constant( shape=[], diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/pir/test_ir_pybind.py similarity index 74% rename from test/ir/new_ir/test_ir_pybind.py rename to test/ir/pir/test_ir_pybind.py index 6434b0eb65268..4c42c0f6f77ae 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/pir/test_ir_pybind.py @@ -32,49 +32,51 @@ def get_ir_program(): y_s = paddle.matmul(x_s, x_s) z_s = paddle.add(y_s, y_s) k_s = paddle.tanh(z_s) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + q_s = paddle.unsqueeze(k_s, [2]) + + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestPybind(unittest.TestCase): def test_program(self): - newir_program = get_ir_program() - print(newir_program) + pir_program = get_ir_program() + print(pir_program) - block = newir_program.global_block() + block = pir_program.global_block() program = block.program - self.assertEqual(newir_program, program) + self.assertEqual(pir_program, program) def test_block(self): - newir_program = get_ir_program() - block = newir_program.global_block() + pir_program = get_ir_program() + block = pir_program.global_block() ops = block.ops self.assertEqual( - len(ops), 4 + len(ops), 6 ) # pir program add "builtin.get_parameter" by default, so size is 4 - block.remove_op(ops[3]) - self.assertEqual(len(block.ops), 3) + block.remove_op(ops[5]) + self.assertEqual(len(block.ops), 5) def test_operation(self): - newir_program = get_ir_program() - ops = newir_program.global_block().ops - matmul_op = newir_program.global_block().ops[1] - add_op = newir_program.global_block().ops[2] - tanh_op = newir_program.global_block().ops[3] + pir_program = get_ir_program() + ops = pir_program.global_block().ops + matmul_op = pir_program.global_block().ops[1] + add_op = pir_program.global_block().ops[2] + tanh_op = pir_program.global_block().ops[3] parent_block = tanh_op.get_parent_block() parent_ops_num = len(parent_block.ops) - self.assertEqual(parent_ops_num, 4) + self.assertEqual(parent_ops_num, 6) self.assertEqual(tanh_op.num_results(), 1) self.assertEqual(len(matmul_op.get_input_names()), 2) self.assertEqual(len(matmul_op.get_attr_names()), 2) self.assertEqual(len(matmul_op.get_output_names()), 1) def test_value(self): - newir_program = get_ir_program() - matmul_op = newir_program.global_block().ops[1] - add_op = newir_program.global_block().ops[2] - tanh_op = newir_program.global_block().ops[3] + pir_program = get_ir_program() + matmul_op = pir_program.global_block().ops[1] + add_op = pir_program.global_block().ops[2] + tanh_op = pir_program.global_block().ops[3] self.assertEqual( matmul_op.result(0).dtype, paddle.base.core.DataType.FLOAT32 @@ -136,9 +138,9 @@ def test_value(self): self.assertEqual(uninit_op_result.initialized(), False) def test_type(self): - newir_program = get_ir_program() - matmul_op = newir_program.global_block().ops[1] - add_op = newir_program.global_block().ops[2] + pir_program = get_ir_program() + matmul_op = pir_program.global_block().ops[1] + add_op = pir_program.global_block().ops[2] print(matmul_op.result(0).type()) self.assertEqual( matmul_op.result(0).type() == add_op.result(0).type(), True @@ -164,10 +166,10 @@ def test_attr(self): shape=[4, 4], dtype="float32", value=2 ) - newir_program = pir.translate_to_new_ir(main_program.desc) - print(newir_program) - conv_attr = newir_program.global_block().ops[3].attrs() - full_attr = newir_program.global_block().ops[8].attrs() + pir_program = pir.translate_to_pir(main_program.desc) + print(pir_program) + conv_attr = pir_program.global_block().ops[3].attrs() + full_attr = pir_program.global_block().ops[8].attrs() self.assertEqual(conv_attr["stop_gradient"], [False]) self.assertEqual(conv_attr["dilations"], [1, 1]) self.assertEqual(conv_attr["data_format"], "NCHW") @@ -179,17 +181,30 @@ def test_attr(self): self.assertTrue(isinstance(full_attr["place"], paddle.base.core.Place)) def test_operands(self): - newir_program = get_ir_program() - matmul_op = newir_program.global_block().ops[1] + pir_program = get_ir_program() + matmul_op = pir_program.global_block().ops[1] operands = matmul_op.operands() self.assertEqual(len(operands), 2) def test_results(self): - newir_program = get_ir_program() - matmul_op = newir_program.global_block().ops[1] + pir_program = get_ir_program() + matmul_op = pir_program.global_block().ops[1] results = matmul_op.results() self.assertEqual(len(results), 1) + def test_get_output_intermediate_status(self): + pir_program = get_ir_program() + unsqueeze_op = pir_program.global_block().ops[-1] + results = unsqueeze_op.get_output_intermediate_status() + self.assertEqual(results, [False, True]) + + def test_prog_seed(self): + p = pir.Program() + self.assertEqual(p._seed, 0) + + p.global_seed(10) + self.assertEqual(p._seed, 10) + if __name__ == "__main__": unittest.main() diff --git a/test/ir/new_ir/test_ir_vjp.py b/test/ir/pir/test_ir_vjp.py similarity index 67% rename from test/ir/new_ir/test_ir_vjp.py rename to test/ir/pir/test_ir_vjp.py index d0e630fccff72..8401761ba3a05 100644 --- a/test/ir/new_ir/test_ir_vjp.py +++ b/test/ir/pir/test_ir_vjp.py @@ -31,19 +31,25 @@ def get_ir_program(): x.stop_gradient = False paddle.tanh(x) paddle.tensor.fill_constant(shape=[4, 4], dtype='float32', value=2.0) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestTanhVjp(unittest.TestCase): def test_tanh_vjp1(self): - newir_program = get_ir_program() - tanh_op = newir_program.global_block().ops[-2] - fill_constant_op = newir_program.global_block().ops[-1] + pir_program = get_ir_program() + tanh_op = pir_program.global_block().ops[-2] + fill_constant_op = pir_program.global_block().ops[-1] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[False]] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(tanh_op, out_grads, stop_gradients) + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + tanh_op, + [[tanh_op.operand_source(0)]], + [[tanh_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.tanh_grad" ) @@ -65,16 +71,22 @@ def test_tanh_vjp1(self): .name(), "pd_op.full", ) - self.assertEqual(len(newir_program.global_block().ops), 4) + self.assertEqual(len(pir_program.global_block().ops), 4) def test_tanh_vjp2(self): - newir_program = get_ir_program() - tanh_op = newir_program.global_block().ops[-2] - fill_constant_op = newir_program.global_block().ops[-1] + pir_program = get_ir_program() + tanh_op = pir_program.global_block().ops[-2] + fill_constant_op = pir_program.global_block().ops[-1] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[True]] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(tanh_op, out_grads, stop_gradients) + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + tanh_op, + [[tanh_op.operand_source(0)]], + [[tanh_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual(grad_outs[0][0], None) @@ -89,13 +101,19 @@ def test_mean_vjp1(self): x.stop_gradient = False paddle.mean(x, axis=[0, 1]) paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0) - newir_program = pir.translate_to_new_ir(main_program.desc) - fill_constant_op = newir_program.global_block().ops[-1] - mean_op = newir_program.global_block().ops[-2] + pir_program = pir.translate_to_pir(main_program.desc) + fill_constant_op = pir_program.global_block().ops[-1] + mean_op = pir_program.global_block().ops[-2] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[False]] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(mean_op, out_grads, stop_gradients) + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + mean_op, + [[mean_op.operand_source(0)]], + [[mean_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.mean_grad" ) @@ -117,7 +135,7 @@ def test_mean_vjp1(self): .name(), "pd_op.full", ) - self.assertEqual(len(newir_program.global_block().ops), 4) + self.assertEqual(len(pir_program.global_block().ops), 4) def test_mean_vjp2(self): main_program, start_program = ( @@ -129,13 +147,19 @@ def test_mean_vjp2(self): x.stop_gradient = False paddle.mean(x, axis=[0, 1]) paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0) - newir_program = pir.translate_to_new_ir(main_program.desc) - fill_constant_op = newir_program.global_block().ops[-1] - mean_op = newir_program.global_block().ops[-2] + pir_program = pir.translate_to_pir(main_program.desc) + fill_constant_op = pir_program.global_block().ops[-1] + mean_op = pir_program.global_block().ops[-2] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[True]] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(mean_op, out_grads, stop_gradients) + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + mean_op, + [[mean_op.operand_source(0)]], + [[mean_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual(grad_outs[0][0], None) @@ -150,9 +174,9 @@ def test_has_vjp(self): x.stop_gradient = False paddle.mean(x, axis=[0, 1]) paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0) - newir_program = pir.translate_to_new_ir(main_program.desc) - fill_constant_op = newir_program.global_block().ops[-1] - mean_op = newir_program.global_block().ops[-2] + pir_program = pir.translate_to_pir(main_program.desc) + fill_constant_op = pir_program.global_block().ops[-1] + mean_op = pir_program.global_block().ops[-2] self.assertEqual(has_vjp(fill_constant_op), False) self.assertEqual(has_vjp(mean_op), True) diff --git a/test/ir/new_ir/test_override_operator.py b/test/ir/pir/test_override_operator.py similarity index 100% rename from test/ir/new_ir/test_override_operator.py rename to test/ir/pir/test_override_operator.py diff --git a/test/ir/new_ir/test_pass_manager.py b/test/ir/pir/test_pass_manager.py similarity index 92% rename from test/ir/new_ir/test_pass_manager.py rename to test/ir/pir/test_pass_manager.py index 5849b0bbdfeff..b9ca872e92d34 100644 --- a/test/ir/new_ir/test_pass_manager.py +++ b/test/ir/pir/test_pass_manager.py @@ -45,18 +45,18 @@ def test_op(self): attrs={"name": out.name}, ) - new_program = pir.translate_to_new_ir(main_program.desc) + new_program = pir.translate_to_pir(main_program.desc) op_names = [op.name() for op in new_program.global_block().ops] # print(op_names) self.assertTrue('pd_op.uniform' in op_names) pm = pir.PassManager() pm.add_pass( - 'dead_code_elimination' + 'dead_code_elimination_pass' ) # apply pass to elimitate dead code pm.run(new_program) op_names = [op.name() for op in new_program.global_block().ops] # print(op_names) - self.assertEqual(pm.passes(), ['dead_code_elimination']) + self.assertEqual(pm.passes(), ['dead_code_elimination_pass']) self.assertFalse(pm.empty()) self.assertTrue( 'pd_op.uniform' not in op_names diff --git a/test/ir/new_ir/test_pd_inplace_pass.py b/test/ir/pir/test_pd_inplace_pass.py similarity index 100% rename from test/ir/new_ir/test_pd_inplace_pass.py rename to test/ir/pir/test_pd_inplace_pass.py diff --git a/test/ir/new_ir/test_new_ir_to_static.py b/test/ir/pir/test_pir_to_static.py similarity index 90% rename from test/ir/new_ir/test_new_ir_to_static.py rename to test/ir/pir/test_pir_to_static.py index eb40ae632790f..c4dd6e4924ec0 100644 --- a/test/ir/new_ir/test_new_ir_to_static.py +++ b/test/ir/pir/test_pir_to_static.py @@ -19,13 +19,13 @@ import paddle -class TestDy2staticNewIR(unittest.TestCase): +class TestDy2staticPir(unittest.TestCase): def test_basic_network(self): def func(x): out = paddle.mean(x) return out - static_func = paddle.jit.to_static(func) + static_func = paddle.jit.to_static(func, full_graph=True) x = paddle.randn((3, 3)) y = paddle.randn((3, 3)) x.stop_gradient = False @@ -43,7 +43,7 @@ def func(x): return out # ==== dygraph computation ==== - static_func = paddle.jit.to_static(func) + static_func = paddle.jit.to_static(func, full_graph=True) x = paddle.randn((3, 3)) y = paddle.randn((3, 3)) x.stop_gradient = False @@ -64,7 +64,7 @@ def func(x): ) -class TestDy2staticNewIR2(unittest.TestCase): +class TestDy2staticPir2(unittest.TestCase): def test_basic_layer(self): class SimpleNet(paddle.nn.Layer): def __init__(self): @@ -78,14 +78,14 @@ def forward(self, x): x = paddle.randn((10, 10)) x.stop_gradient = False ans = net(x) - net = paddle.jit.to_static(net) + net = paddle.jit.to_static(net, full_graph=True) out = net(x) np.testing.assert_allclose( out.numpy(), ans.numpy(), rtol=1e-05, atol=1e-8 ) -class TestDy2staticNewIR3(unittest.TestCase): +class TestDy2staticPir3(unittest.TestCase): def test_complex_layer(self): def output_pure_func(x, y): outx = paddle.mean(x) @@ -101,7 +101,7 @@ def run_function(to_static=True): y.stop_gradient = True func = output_pure_func if to_static: - func = paddle.jit.to_static(func) + func = paddle.jit.to_static(func, full_graph=True) y, y_mean = func(x, y) loss = y.mean() loss.backward() @@ -134,7 +134,7 @@ def train_step(to_static=True): learning_rate=0.1, parameters=net.parameters() ) if to_static: - net = paddle.jit.to_static(net) + net = paddle.jit.to_static(net, full_graph=True) losses = [] for step in range(100): y_pred = net(x) @@ -152,7 +152,7 @@ def train_step(to_static=True): ) -class TestDy2staticNewIR5(unittest.TestCase): +class TestDy2staticPir5(unittest.TestCase): def test_run(self): # Dy2static RunProgramOp support nn.Layer's forward and backward training. class SimpleNet(paddle.nn.Layer): @@ -177,7 +177,7 @@ def train_step(to_static=True): learning_rate=0.1, parameters=net.parameters() ) if to_static: - net = paddle.jit.to_static(net) + net = paddle.jit.to_static(net, full_graph=True) losses = [] for step in range(100): y_pred = net(x, step % 2 == 1) @@ -195,7 +195,7 @@ def train_step(to_static=True): ) -class TestDy2staticNewIR6(unittest.TestCase): +class TestDy2staticPir6(unittest.TestCase): # test basic-indexing __getitem__ for OpResult def test_basic_network(self): def func(x): @@ -203,7 +203,7 @@ def func(x): out = shape[1:] return out - static_func = paddle.jit.to_static(func) + static_func = paddle.jit.to_static(func, full_graph=True) x = paddle.randn((2, 3, 4)) x.stop_gradient = False ans = func(x) diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/pir/test_special_op_translator.py similarity index 81% rename from test/ir/new_ir/test_special_op_translator.py rename to test/ir/pir/test_special_op_translator.py index 30757fbeb95bd..415ff2513b2f1 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/pir/test_special_op_translator.py @@ -35,10 +35,75 @@ def test_op(self): x = paddle.to_tensor([2, 3, 4], 'float64') y = paddle.cast(x, 'uint8') - _, mappings = pir.translate_to_new_ir_with_param_map(main_program.desc) + _, mappings = pir.translate_to_pir_with_param_map(main_program.desc) assert len(str(mappings)) > 0, "no mapping found" +class TestCondWithInplace(unittest.TestCase): + def test_op(self): + def cond_with_inplace(): + x = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + y = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + running_mean = paddle.to_tensor([0], dtype="float32") + running_variance = paddle.to_tensor([1], dtype="float32") + weight = paddle.to_tensor([2], dtype="float32") + bias = paddle.to_tensor([1], dtype="float32") + if x > y: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + + legacy_program = paddle.jit.to_static( + cond_with_inplace, + input_spec=[], + full_graph=True, + ) + + l = pir.translate_to_pir(legacy_program.main_program.desc) + assert l is not None + + def test_nested_op(self): + def cond_with_inplace(): + x = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + y = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + z = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + running_mean = paddle.to_tensor([0], dtype="float32") + running_variance = paddle.to_tensor([1], dtype="float32") + weight = paddle.to_tensor([2], dtype="float32") + bias = paddle.to_tensor([1], dtype="float32") + if x > y: + if y > z: + z = paddle.nn.functional.batch_norm( + z, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + else: + if y > z: + z = paddle.nn.functional.batch_norm( + z, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + + legacy_program = paddle.jit.to_static( + cond_with_inplace, + input_spec=[], + full_graph=True, + ) + + l = pir.translate_to_pir(legacy_program.main_program.desc) + assert l is not None + + class TestElementwiseOpTranscriber(unittest.TestCase): def test_elementwise_without_y_grad(self): place = core.Place() @@ -120,7 +185,7 @@ def test_add_inplace(self): outputs={"Out": y}, attrs={"axis": -1}, ) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestEmbeddingOpTranscriber(unittest.TestCase): @@ -137,7 +202,7 @@ def test_op(self): ) output = embedding(x) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestIncrementOpTranscriber(unittest.TestCase): @@ -151,7 +216,7 @@ def test_op(self): data = paddle.zeros(shape=[1], dtype='float32') counter = paddle.increment(data) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestAssignValueOpTranscriber(unittest.TestCase): @@ -168,7 +233,7 @@ def test_op(self): stop_gradient=False, ) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestRnnOpTranscriber(unittest.TestCase): @@ -185,7 +250,7 @@ def test_op(self): cell = paddle.nn.SimpleRNNCell(16, 32) y, h = cell(x, prev_h) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestEmptyVarTranslate(unittest.TestCase): @@ -207,7 +272,7 @@ def test_op(self): out2 = paddle.mean(out1) sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.1) sgd_optimizer.minimize(out2) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestOneHotOpTranscriber(unittest.TestCase): @@ -226,7 +291,7 @@ def test_mutable_attribute(self): x=label, num_classes=depth ) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) def test_normal_attribute(self): place = core.Place() @@ -243,7 +308,7 @@ def test_normal_attribute(self): x=label, num_classes=depth ) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestReduceOpTranscriber(unittest.TestCase): @@ -293,7 +358,7 @@ def test_op(self): value = paddle.randn([2]) y = paddle.index_put(x, indices, value, False) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestGradAddOpTranscriber(unittest.TestCase): @@ -319,7 +384,7 @@ def test_op(self): attrs={"axis": -1}, ) - _ = pir.translate_to_new_ir(main_program.desc) + _ = pir.translate_to_pir(main_program.desc) class TestShadowOutputSlice(unittest.TestCase): @@ -344,7 +409,7 @@ def test_op(self): attrs={"name": out.name}, ) - l = pir.translate_to_new_ir(main_program.desc) + l = pir.translate_to_pir(main_program.desc) class TestSetValueOp(unittest.TestCase): @@ -448,12 +513,30 @@ def test_program(self): inputs={"X": x}, outputs={"Out": y, "XOut": x}, ) - l = pir.translate_to_new_ir(main_program.desc) + l = pir.translate_to_pir(main_program.desc) assert ( l.global_block().ops[2].name() == "pd_op.share_data" ), "share_buffer should be translated to share_data" +class TestDataOp(unittest.TestCase): + def test_data_op(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + _ = paddle.static.data(name="y", shape=[3, 9, 5], dtype="int64") + l = pir.translate_to_pir(main_program.desc) + self.assertTrue(len(l.global_block().ops) > 0) + self.assertTrue(l.global_block().ops[0].name() == "pd_op.data") + data_op = l.global_block().ops[0] + self.assertIn("dtype", data_op.attrs()) + self.assertEqual(str(data_op.attrs()["dtype"]), "DataType.INT64") + + class TestCheckUnregisteredOp(unittest.TestCase): def test_program(self): main_program = paddle.static.Program() diff --git a/test/ir/new_ir/test_standalone_new_ir.py b/test/ir/pir/test_standalone_pir.py similarity index 92% rename from test/ir/new_ir/test_standalone_new_ir.py rename to test/ir/pir/test_standalone_pir.py index 51843b8b5037e..298e40b3299e1 100644 --- a/test/ir/new_ir/test_standalone_new_ir.py +++ b/test/ir/pir/test_standalone_pir.py @@ -21,8 +21,8 @@ import paddle -class TestNewIr(unittest.TestCase): - def test_with_new_ir(self): +class TestPir(unittest.TestCase): + def test_with_pir(self): paddle.enable_static() place = ( paddle.CUDAPlace(0) @@ -47,7 +47,7 @@ def test_with_new_ir(self): class TestCombineOp(unittest.TestCase): - def test_with_new_ir(self): + def test_with_pir(self): paddle.enable_static() place = ( paddle.CUDAPlace(0) @@ -73,7 +73,7 @@ def test_with_new_ir(self): class TestFeedOp(unittest.TestCase): - def test_with_new_ir(self): + def test_with_pir(self): paddle.enable_static() place = ( paddle.CUDAPlace(0) @@ -105,7 +105,7 @@ def test_with_new_ir(self): class TestSelectedRows(unittest.TestCase): - def test_with_new_ir(self): + def test_with_pir(self): # TODO(phlrain): support selected rows in GPU paddle.enable_static() place = paddle.CPUPlace() @@ -129,7 +129,7 @@ def test_with_new_ir(self): class TestAddGradOp(unittest.TestCase): - def test_with_new_ir(self): + def test_with_pir(self): paddle.enable_static() place = ( paddle.CUDAPlace(0) @@ -163,8 +163,8 @@ def test_with_new_ir(self): np.testing.assert_array_equal(out[0], gold_res) -class TestNewIrDygraph(unittest.TestCase): - def test_with_new_ir(self): +class TestPirDygraph(unittest.TestCase): + def test_with_pir(self): paddle.disable_static() @paddle.jit.to_static @@ -179,8 +179,8 @@ def func(x, y): np.testing.assert_array_equal(z.numpy(), gold_res) -class TestNewIrBackwardDygraph(unittest.TestCase): - def test_with_new_ir(self): +class TestPirBackwardDygraph(unittest.TestCase): + def test_with_pir(self): paddle.disable_static() build_strategy = paddle.static.BuildStrategy() build_strategy.enable_inplace = False @@ -204,8 +204,8 @@ def func(x, y): np.testing.assert_array_equal(y.gradient(), gold_res) -class TestNewIrReshapeBackwardDygraph(unittest.TestCase): - def test_with_new_ir(self): +class TestPirReshapeBackwardDygraph(unittest.TestCase): + def test_with_pir(self): paddle.disable_static() build_strategy = paddle.static.BuildStrategy() build_strategy.enable_inplace = False @@ -233,7 +233,7 @@ def func(x, y): class TestSplitOp(unittest.TestCase): - def test_with_new_ir(self): + def test_with_pir(self): paddle.enable_static() place = ( paddle.CUDAPlace(0) @@ -260,8 +260,8 @@ def test_with_new_ir(self): np.testing.assert_array_equal(out[0], np_a[0:2]) -class TestNewIrPrint(unittest.TestCase): - def test_with_new_ir(self): +class TestPirPrint(unittest.TestCase): + def test_with_pir(self): paddle.enable_static() place = ( paddle.CUDAPlace(0) @@ -290,12 +290,12 @@ def test_with_new_ir(self): class TestJitSaveOp(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() - self.model_path = os.path.join(self.temp_dir.name, "new_ir_save_load") + self.model_path = os.path.join(self.temp_dir.name, "pir_save_load") def tearDown(self): self.temp_dir.cleanup() - def test_with_new_ir(self): + def test_with_pir(self): paddle.disable_static() linear = paddle.nn.Linear(10, 10) @@ -328,8 +328,8 @@ def test_with_new_ir(self): ) -class TestNewIrConcatDygraph(unittest.TestCase): - def test_with_new_ir(self): +class TestPirConcatDygraph(unittest.TestCase): + def test_with_pir(self): paddle.disable_static() @paddle.jit.to_static @@ -346,8 +346,8 @@ def func(x, y): # TODO(phlrain): open this after fix pr(55509) confict -# class TestNewIrLogicalDygraph(unittest.TestCase): -# def test_with_new_ir(self): +# class TestPirLogicalDygraph(unittest.TestCase): +# def test_with_pir(self): # paddle.disable_static() # @paddle.jit.to_static diff --git a/test/ir/new_ir/test_stop_gradient.py b/test/ir/pir/test_stop_gradient.py similarity index 100% rename from test/ir/new_ir/test_stop_gradient.py rename to test/ir/pir/test_stop_gradient.py diff --git a/test/ir/new_ir/test_symbol_overload.py b/test/ir/pir/test_symbol_overload.py similarity index 100% rename from test/ir/new_ir/test_symbol_overload.py rename to test/ir/pir/test_symbol_overload.py diff --git a/test/ir/test_op_input_grad_semantic.py b/test/ir/test_op_input_grad_semantic.py index ab1ca4f61d191..ab4ca0c2c347b 100644 --- a/test/ir/test_op_input_grad_semantic.py +++ b/test/ir/test_op_input_grad_semantic.py @@ -20,7 +20,7 @@ paddle.enable_static() -def get_gather_program_new_ir(): +def get_gather_program_pir(): main_program, start_program = ( paddle.static.Program(), paddle.static.Program(), @@ -32,11 +32,11 @@ def get_gather_program_new_ir(): index = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1.0) axis = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=2.0) out = paddle.gather(x, index, axis) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program -def get_multiply_program_new_ir(): +def get_multiply_program_pir(): main_program, start_program = ( paddle.static.Program(), paddle.static.Program(), @@ -49,21 +49,21 @@ def get_multiply_program_new_ir(): shape=[3, 4], dtype='float32', value=3.0 ) out = paddle.multiply(x, y) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestOpInputGradSemantic(unittest.TestCase): def test_gather_op_input_grad_semantic(self): - newir_program = get_gather_program_new_ir() - gather_op = newir_program.global_block().ops[-1] + pir_program = get_gather_program_pir() + gather_op = pir_program.global_block().ops[-1] self.assertEqual( gather_op.get_input_grad_semantics(), [True, False, False] ) def test_multiply_op_input_grad_semantic(self): - newir_program = get_multiply_program_new_ir() - multiply_op = newir_program.global_block().ops[-1] + pir_program = get_multiply_program_pir() + multiply_op = pir_program.global_block().ops[-1] self.assertEqual(multiply_op.get_input_grad_semantics(), [True, True]) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 96a15b04ab8a2..03860aaa01a20 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -589,10 +589,9 @@ py_test_modules( py_test_modules(test_install_check MODULES test_install_check ENVS FLAGS_cudnn_deterministic=1) set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST") -py_test_modules(test_install_check_new_ir MODULES test_install_check ENVS - FLAGS_cudnn_deterministic=1 FLAGS_enable_new_ir_in_executor=1) -set_tests_properties(test_install_check_new_ir PROPERTIES LABELS - "RUN_TYPE=DIST") +py_test_modules(test_install_check_pir MODULES test_install_check ENVS + FLAGS_cudnn_deterministic=1 FLAGS_enable_pir_in_executor=1) +set_tests_properties(test_install_check_pir PROPERTIES LABELS "RUN_TYPE=DIST") if((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.6)) py_test_modules(test_fused_gemm_epilogue_op MODULES @@ -1344,18 +1343,16 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS}) FLAGS_new_executor_static_build=true) endforeach() -set(NEW_IR_COVERAGE_TESTS test_fused_feedforward_pass) +set(PIR_COVERAGE_TESTS test_fused_feedforward_pass) if(NOT WITH_GPU) - list(REMOVE_ITEM NEW_IR_COVERAGE_TESTS test_fused_feedforward_pass) + list(REMOVE_ITEM PIR_COVERAGE_TESTS test_fused_feedforward_pass) endif() -foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS}) - py_test_modules( - ${NEW_IR_COVERAGE_TEST}_new_ir MODULES ${NEW_IR_COVERAGE_TEST} ENVS - FLAGS_enable_new_ir_in_executor=true) - set_tests_properties(${NEW_IR_COVERAGE_TEST}_new_ir PROPERTIES TIMEOUT 120) - message( - STATUS "NewIR Copied OpTest: ${NEW_IR_COVERAGE_TEST}_new_ir in legacy_test") +foreach(PIR_COVERAGE_TEST ${PIR_COVERAGE_TESTS}) + py_test_modules(${PIR_COVERAGE_TEST}_pir MODULES ${PIR_COVERAGE_TEST} ENVS + FLAGS_enable_pir_in_executor=true) + set_tests_properties(${PIR_COVERAGE_TEST}_pir PROPERTIES TIMEOUT 120) + message(STATUS "PIR Copied OpTest: ${PIR_COVERAGE_TEST}_pir in legacy_test") endforeach() set_tests_properties(test_decoupled_py_reader_static_build PROPERTIES TIMEOUT @@ -1380,29 +1377,37 @@ set_tests_properties(test_sync_batch_norm_op_static_build set_tests_properties(test_sync_batch_norm_op_static_build PROPERTIES TIMEOUT 250) -file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/new_ir_op_test_white_list" - NEW_IR_OP_TESTS) -foreach(IR_OP_TEST ${NEW_IR_OP_TESTS}) +file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/pir_op_test_white_list" + PIR_OP_TESTS) +foreach(IR_OP_TEST ${PIR_OP_TESTS}) if(TEST ${IR_OP_TEST}) set_tests_properties( - ${IR_OP_TEST} PROPERTIES ENVIRONMENT - "FLAGS_NEW_IR_OPTEST_WHITE_LIST=True") + ${IR_OP_TEST} PROPERTIES ENVIRONMENT "FLAGS_PIR_OPTEST_WHITE_LIST=True") + else() + message(STATUS "PIR OpTest: not found ${IR_OP_TEST} in legacy_test") + endif() +endforeach() + +file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/pir_op_test_no_check_list" + PIR_OP_NO_CHECK_TESTS) +foreach(IR_OP_TEST ${PIR_OP_NO_CHECK_TESTS}) + if(TEST ${IR_OP_TEST}) + set_tests_properties(${IR_OP_TEST} PROPERTIES ENVIRONMENT + "FLAGS_PIR_NO_CHECK=True") else() - message(STATUS "NewIR OpTest: not found ${IR_OP_TEST} in legacy_test") + message(STATUS "PIR OpTest: not found ${IR_OP_TEST} in legacy_test") endif() endforeach() file(STRINGS - "${CMAKE_SOURCE_DIR}/test/white_list/new_ir_op_test_precision_white_list" - NEW_IR_OP_RELAXED_TESTS) -foreach(IR_OP_TEST ${NEW_IR_OP_RELAXED_TESTS}) + "${CMAKE_SOURCE_DIR}/test/white_list/pir_op_test_precision_white_list" + PIR_OP_RELAXED_TESTS) +foreach(IR_OP_TEST ${PIR_OP_RELAXED_TESTS}) if(TEST ${IR_OP_TEST}) set_tests_properties( - ${IR_OP_TEST} PROPERTIES ENVIRONMENT - "FLAGS_NEW_IR_OPTEST_RELAX_CHECK=True") + ${IR_OP_TEST} PROPERTIES ENVIRONMENT "FLAGS_PIR_OPTEST_RELAX_CHECK=True") else() - message( - STATUS "NewIR Relaxed OpTest: not found ${IR_OP_TEST} in legacy_test") + message(STATUS "PIR Relaxed OpTest: not found ${IR_OP_TEST} in legacy_test") endif() endforeach() diff --git a/test/legacy_test/auto_parallel_gpt_model.py b/test/legacy_test/auto_parallel_gpt_model.py index 3f25aeb19b64c..ebecb08d45c7f 100644 --- a/test/legacy_test/auto_parallel_gpt_model.py +++ b/test/legacy_test/auto_parallel_gpt_model.py @@ -266,7 +266,9 @@ def forward( if self.use_new_recompute and self.recompute_granularity == "core_attn": out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask) else: - out, weights = self.core_attn(q, k, v, attn_mask) + out, weights = auto.exclude_ops_in_recompute(self.core_attn)( + q, k, v, attn_mask + ) # project to output out = self.out_proj(out) diff --git a/test/legacy_test/distributed_fused_lamb_test_base.py b/test/legacy_test/distributed_fused_lamb_test_base.py index ea011becc9090..348191e66d7d5 100644 --- a/test/legacy_test/distributed_fused_lamb_test_base.py +++ b/test/legacy_test/distributed_fused_lamb_test_base.py @@ -270,7 +270,10 @@ def setUpClass(cls): paddle.enable_static() paddle.set_flags({'FLAGS_cudnn_deterministic': True}) _clip_by_global_norm_using_mp_type(True) - if os.environ.get("FLAGS_dynamic_static_unified_comm") == "1": + if ( + os.environ.get("FLAGS_dynamic_static_unified_comm", "false").lower() + == "true" + ): paddle.distributed.collective._init_parallel_env("nccl") else: fleet.init(role_maker=get_role_maker()) diff --git a/test/legacy_test/gradient_checker.py b/test/legacy_test/gradient_checker.py index 67e18075e60a0..ac01c0756287f 100644 --- a/test/legacy_test/gradient_checker.py +++ b/test/legacy_test/gradient_checker.py @@ -224,6 +224,7 @@ def grad_check( x_init=None, place=None, program=None, + scope=None, eps=1e-6, atol=1e-5, rtol=1e-3, @@ -254,40 +255,6 @@ def fail_test(msg): raise RuntimeError(msg) return False - # check input arguments - x = _as_list(x) - y = _as_list(y) - - for v in x: - v.stop_gradient = False - v.persistable = True - for u in y: - u.stop_gradient = False - u.persistable = True - if place is None: - place = base.CPUPlace() - if program is None: - program = base.default_main_program() - - # init variable in startup program - scope = base.executor.global_scope() - exe = base.Executor(place) - exe.run(base.default_startup_program()) - - x_init = _as_list(x_init) - # init inputs if x_init is not None - if x_init: - if len(x_init) != len(x): - raise ValueError( - 'len(x_init) (=%d) is not the same' - ' as len(x) (= %d)' % (len(x_init), len(x)) - ) - # init variable in main program - for var, arr in zip(x, x_init): - assert var.shape == arr.shape - feeds = {k.name: v for k, v in zip(x, x_init)} - exe.run(program, feed=feeds, scope=scope) - # [x_idx, y_idx] numerical = [ _compute_numerical_jacobian(program, xi, y, place, scope, eps) @@ -371,23 +338,12 @@ def double_grad_check( u.stop_gradient = False u.persistable = True - if program is None: - program = base.default_main_program() - + scope = base.executor.global_scope() if y_grads is None: - scope = base.executor.global_scope() - y_grads = [] y_grads_init = [] for yi in y: - dyi_name = _append_grad_suffix_(yi.name) np_type = dtype_to_np_dtype(yi.dtype) - dy = program.global_block().create_var( - name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True - ) - dy.stop_gradient = False v = np.random.random(size=yi.shape).astype(np_type) - set_var_in_scope(scope, place, dyi_name, v) - y_grads.append(dy) y_grads_init.append(v) else: y_grads = _as_list(y_grads) @@ -395,16 +351,13 @@ def double_grad_check( var_to_np_array_in_scope(scope, place, v.name) for v in y_grads ] - # append first order grads - target_grads = base.gradients(y, x, y_grads) - - # y_grads are the input of first-order backward, - # so, they are also the input of second-order backward. - x += y_grads x_init = _as_list(x_init) - x_init += y_grads_init - grad_check(x, target_grads, x_init, place, program, eps, atol, rtol) + grad_res, x, target_grads, program, scope = get_static_double_grad( + x, y, x_init, y_grads_init, place + ) + + grad_check(x, target_grads, x_init, place, program, scope, eps, atol, rtol) # TODO(jiabin): We currently support only triple grad check here, extend this to support @@ -456,23 +409,12 @@ def triple_grad_check( u.stop_gradient = False u.persistable = True - if program is None: - program = base.default_main_program() - + scope = base.executor.global_scope() if y_grads is None: - scope = base.executor.global_scope() - y_grads = [] y_grads_init = [] for yi in y: - dyi_name = _append_grad_suffix_(yi.name) np_type = dtype_to_np_dtype(yi.dtype) - dy = program.global_block().create_var( - name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True - ) - dy.stop_gradient = False v = np.random.random(size=yi.shape).astype(np_type) - set_var_in_scope(scope, place, dyi_name, v) - y_grads.append(dy) y_grads_init.append(v) else: y_grads = _as_list(y_grads) @@ -480,52 +422,18 @@ def triple_grad_check( var_to_np_array_in_scope(scope, place, v.name) for v in y_grads ] - # append first order grads - target_grads = base.gradients(y, x, y_grads) - - if x_grads_grads is None: - scope = base.executor.global_scope() - x_grads_grads = [] - x_grads_grads_init = [] - for dxi in target_grads: - ddxi_name = _append_grad_suffix_(dxi.name) - np_type = dtype_to_np_dtype(dxi.dtype) - ddx = program.global_block().create_var( - name=ddxi_name, shape=dxi.shape, dtype=np_type, persistable=True - ) - ddx.stop_gradient = False - v = np.random.random(size=dxi.shape).astype(np_type) - set_var_in_scope(scope, place, ddxi_name, v) - x_grads_grads.append(ddx) - x_grads_grads_init.append(v) - else: - x_grads_grads = _as_list(x_grads_grads) - x_grads_grads_init = [ - var_to_np_array_in_scope(scope, place, v.name) - for v in x_grads_grads - ] - x += y_grads x_init = _as_list(x_init) - x_init += y_grads_init - - # append second order grads - target_grads_grads = base.gradients(target_grads, x, x_grads_grads) - - # filter None in target_grads_grads for Dy/Dx may be None in kernel - filted = [ - (i, dyi) for i, dyi in enumerate(target_grads_grads) if dyi is not None - ] - filted_idx, filted_target_grads_grads = zip(*filted) - - x += x_grads_grads - x_init += x_grads_grads_init # x <=> [x, dout, ddx] + grad_res, x, target_grads_grads, program, scope = get_static_triple_grad( + x, y, x_init, y_grads_init, place + ) grad_check( x=x, - y=filted_target_grads_grads, + y=target_grads_grads, x_init=x_init, place=place, + scope=scope, program=program, eps=eps, atol=atol, @@ -610,7 +518,6 @@ def get_static_double_grad( for var, arr in zip(x, x_init): assert var.shape == arr.shape feeds = {k.name: v for k, v in zip(x, x_init)} - exe.run(program, feed=feeds, scope=scope) dys = [] for yi in y: @@ -633,9 +540,9 @@ def get_static_double_grad( # only fetch not None dx in exe.run filted = [(i, dxi) for i, dxi in enumerate(ddx) if dxi is not None] filted_idx, filted_ddx = zip(*filted) - ddx_res = exe.run(program, scope=scope, fetch_list=filted_ddx) + ddx_res = exe.run(program, feed=feeds, scope=scope, fetch_list=filted_ddx) - return ddx_res + return ddx_res, x, filted_dx, program, scope def get_eager_double_grad( @@ -767,7 +674,7 @@ def fail_test(msg): eager_double_grad = get_eager_double_grad(func, x_init, y_grads_init, place) paddle.enable_static() - static_double_grad = get_static_double_grad( + static_double_grad, _, _, _, _ = get_static_double_grad( x, y, x_init, y_grads_init, place ) @@ -930,7 +837,7 @@ def fail_test(msg): eager_triple_grad = get_eager_triple_grad(func, x_init, y_grads_init, place) paddle.enable_static() - static_triple_grad = get_static_triple_grad( + static_triple_grad, _, _, _, _ = get_static_triple_grad( x, y, x_init, y_grads_init, place ) diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py index 873e38110dd3e..b181e549eed7c 100644 --- a/test/legacy_test/op_test.py +++ b/test/legacy_test/op_test.py @@ -95,9 +95,7 @@ def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs): shape, dtype = spec else: raise ValueError( - "Value of in_specs[{}] should contains two elements: [shape, dtype]".format( - index - ) + f"Value of in_specs[{index}] should contains two elements: [shape, dtype]" ) input_t.append( paddle.static.data( @@ -1300,9 +1298,7 @@ def _need_fetch(self, sig_name): return True return False - def _calc_new_ir_output( - self, place, no_check_set=None, inps=None, oups=None - ): + def _calc_pir_output(self, place, no_check_set=None, inps=None, oups=None): """set egr_inps and egr_oups = None if you want to create it by yourself.""" def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): @@ -1367,6 +1363,8 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): ret_tuple, paddle.base.libpaddle.pir.OpResult ): fetch_list.append(ret_tuple) + elif ret_tuple is None: + pass else: raise ValueError( "output of python api should be OpResult or list of OpResult or tuple of OpResult" @@ -1393,9 +1391,9 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): return result def _check_ir_output(self, place, program, feed_map, fetch_list, outs): - if os.getenv("FLAGS_NEW_IR_OPTEST") is None: + if os.getenv("FLAGS_PIR_OPTEST") is None: return - if os.getenv("FLAGS_NEW_IR_OPTEST_WHITE_LIST") is None: + if os.getenv("FLAGS_PIR_OPTEST_WHITE_LIST") is None: return if self.check_prim or self.check_prim_pir: return @@ -1403,15 +1401,15 @@ def _check_ir_output(self, place, program, feed_map, fetch_list, outs): return stored_flag = get_flags( [ - 'FLAGS_enable_new_ir_in_executor', - "FLAGS_new_ir_apply_inplace_pass", + 'FLAGS_enable_pir_in_executor', + "FLAGS_pir_apply_inplace_pass", ] ) try: set_flags( { - "FLAGS_enable_new_ir_in_executor": True, - "FLAGS_new_ir_apply_inplace_pass": 0, + "FLAGS_enable_pir_in_executor": True, + "FLAGS_pir_apply_inplace_pass": 0, } ) new_scope = paddle.static.Scope() @@ -1435,10 +1433,12 @@ def _check_ir_output(self, place, program, feed_map, fetch_list, outs): ), "Fetch result should have same length when executed in pir" check_method = np.testing.assert_array_equal - if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None): + if os.getenv("FLAGS_PIR_OPTEST_RELAX_CHECK", None) == "True": check_method = lambda x, y, z: np.testing.assert_allclose( x, y, err_msg=z, atol=1e-6, rtol=1e-6 ) + if os.getenv("FLAGS_PIR_NO_CHECK", None) == "True": + check_method = lambda x, y, err_msg: None for i in range(len(outs)): check_method( @@ -1904,7 +1904,7 @@ def check_inplace_output_with_place( if getattr(self, "no_need_check_inplace", False): return - if os.getenv("FLAGS_enable_new_ir_in_executor"): + if os.getenv("FLAGS_enable_pir_in_executor"): return has_infer_inplace = base.core.has_infer_inplace(self.op_type) @@ -2326,25 +2326,25 @@ def _is_skip_name(self, name): return True return super()._is_skip_name(name) - class NewIRChecker(Checker): + class PirChecker(Checker): def init(self): self.checker_name = "pir checker" def calculate_output(self): self.is_python_api_test = True - new_ir_outs = self.op_test._calc_new_ir_output(place) - if new_ir_outs is None: + pir_outs = self.op_test._calc_pir_output(place) + if pir_outs is None: self.is_python_api_test = False # missing KernelSignature, fall back to eager middle output. - new_ir_outs = self.op_test._calc_dygraph_output( + pir_outs = self.op_test._calc_dygraph_output( place, no_check_set=no_check_set ) - self.outputs = new_ir_outs + self.outputs = pir_outs if self.op_test.is_compared_with_fp32(): self.op_test.enable_cal_ref_output() self.is_python_api_test = True - self.ref_outputs = self.op_test._calc_new_ir_output(place) + self.ref_outputs = self.op_test._calc_pir_output(place) if self.ref_outputs is None: self.is_python_api_test = False # missing KernelSignature, fall back to eager middle output. @@ -2392,12 +2392,12 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np): expect_np = convert_uint16_to_float(expect_np) return actual_np, expect_np - def find_imperative_actual(target_name, new_ir_outs, place): - for name in new_ir_outs: + def find_imperative_actual(target_name, pir_outs, place): + for name in pir_outs: if name == target_name: - return new_ir_outs[name][0] + return pir_outs[name][0] - var_list = new_ir_outs[name] + var_list = pir_outs[name] for i, var in enumerate(var_list): if isinstance(var, list): for tensor in var: @@ -2407,19 +2407,19 @@ def find_imperative_actual(target_name, new_ir_outs, place): isinstance(var, paddle.Tensor) and var.name == target_name ): - return new_ir_outs[name][i] + return pir_outs[name][i] self.assertTrue( False, - f"Found failed {new_ir_outs.keys()} {target_name}", + f"Found failed {pir_outs.keys()} {target_name}", ) - def find_imperative_expect(self, target_name, new_ir_outs, place): - for name in new_ir_outs: + def find_imperative_expect(self, target_name, pir_outs, place): + for name in pir_outs: if name == target_name: - return new_ir_outs[name][0] + return pir_outs[name][0] self.assertTrue( False, - f"Found failed {new_ir_outs.keys()} {target_name}", + f"Found failed {pir_outs.keys()} {target_name}", ) def find_actual_value(self, target_name): @@ -2544,8 +2544,8 @@ def _is_skip_name(self, name): or type(place) is paddle.base.libpaddle.CUDAPlace ): with paddle.pir_utils.IrGuard(): - new_ir_checker = NewIRChecker(self, self.outputs) - new_ir_checker.check() + pir_checker = PirChecker(self, self.outputs) + pir_checker.check() # Note(zhiqiu): inplace_atol should be only set when op doesn't ensure # computational consistency. @@ -2695,7 +2695,7 @@ def check_output( self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST ): - if os.getenv("FLAGS_enable_new_ir_in_executor"): + if os.getenv("FLAGS_enable_pir_in_executor"): return self.check_compile_vs_runtime(fetch_list, outs) @@ -2713,7 +2713,7 @@ def check_output_customized( checker(outs) if check_pir: with paddle.pir_utils.IrGuard(): - outs_p = self._calc_new_ir_output(place) + outs_p = self._calc_pir_output(place) outs_p = [outs_p[out] for out in outs_p] outs_p.sort(key=len) checker(outs_p[0]) @@ -2727,10 +2727,10 @@ def check_output_with_place_customized( checker(outs) if check_pir: with paddle.pir_utils.IrGuard(): - outs_p = self._calc_new_ir_output(place) - outs_p = [outs_p[out] for out in outs_p] + outs_p = self._calc_pir_output(place) + outs_p = [outs_p[out][0] for out in outs_p] outs_p.sort(key=len) - checker(outs_p[0]) + checker(outs_p) def _assert_is_close( self, @@ -3132,7 +3132,7 @@ def check_grad_with_place( or type(place) is paddle.base.libpaddle.CUDAPlace ): with paddle.pir_utils.IrGuard(): - new_ir_grad = self._get_ir_gradient( + pir_grad = self._get_ir_gradient( inputs_to_check, place, output_names, @@ -3140,7 +3140,7 @@ def check_grad_with_place( no_grad_set, ) fp32_analytic_grads = [] - for grad in new_ir_grad: + for grad in pir_grad: if grad.dtype == np.uint16: grad = convert_uint16_to_float(grad) max_relative_error = ( @@ -3149,7 +3149,7 @@ def check_grad_with_place( else max_relative_error ) fp32_analytic_grads.append(grad) - new_ir_grad = fp32_analytic_grads + pir_grad = fp32_analytic_grads if self.is_float16_op(): max_relative_error = ( 0.01 @@ -3158,7 +3158,7 @@ def check_grad_with_place( ) self._assert_is_close( numeric_grads, - new_ir_grad, + pir_grad, inputs_to_check, max_relative_error, "Gradient Check On %s" % str(place), @@ -3331,9 +3331,9 @@ def cast_bf16_output(self, block, cast_inputs): def _check_ir_grad_output( self, place, program, scope, feed_dict, fetch_list, gradients ): - if os.getenv("FLAGS_NEW_IR_OPTEST") is None: + if os.getenv("FLAGS_PIR_OPTEST") is None: return - if os.getenv("FLAGS_NEW_IR_OPTEST_WHITE_LIST") is None: + if os.getenv("FLAGS_PIR_OPTEST_WHITE_LIST") is None: return if self.check_prim or self.check_prim_pir: return @@ -3342,15 +3342,15 @@ def _check_ir_grad_output( stored_flag = get_flags( [ - 'FLAGS_enable_new_ir_in_executor', - "FLAGS_new_ir_apply_inplace_pass", + 'FLAGS_enable_pir_in_executor', + "FLAGS_pir_apply_inplace_pass", ] ) try: set_flags( { - "FLAGS_enable_new_ir_in_executor": True, - "FLAGS_new_ir_apply_inplace_pass": 0, + "FLAGS_enable_pir_in_executor": True, + "FLAGS_pir_apply_inplace_pass": 0, } ) executor = Executor(place) @@ -3368,11 +3368,14 @@ def _check_ir_grad_output( ) check_method = np.testing.assert_array_equal - if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None): + if os.getenv("FLAGS_PIR_OPTEST_RELAX_CHECK", None) == "True": check_method = lambda x, y, z: np.testing.assert_allclose( x, y, err_msg=z, atol=1e-6, rtol=1e-6 ) + if os.getenv("FLAGS_PIR_NO_CHECK", None) == "True": + check_method = lambda x, y, err_msg: None + for i in range(len(new_gradients)): check_method( gradients[i], @@ -3460,7 +3463,7 @@ def _get_gradient( tensor = true_var.get_tensor() tensor.set(grad_out_value, place) grad_outputs.append(var) - if os.getenv("FLAGS_NEW_IR_OPTEST") is not None: + if os.getenv("FLAGS_PIR_OPTEST") is not None: ir_true_var = ir_scope.var(var.name) ir_tensor = ir_true_var.get_tensor() ir_tensor.set(grad_out_value, place) diff --git a/test/legacy_test/prim_op_test.py b/test/legacy_test/prim_op_test.py index bf32aefcebeae..88843a1e55081 100644 --- a/test/legacy_test/prim_op_test.py +++ b/test/legacy_test/prim_op_test.py @@ -235,7 +235,9 @@ def is_bfloat16_type(cls, np_type): def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimNet(paddle.nn.Layer): diff --git a/test/legacy_test/spawn_runner_base.py b/test/legacy_test/spawn_runner_base.py index 3b94904214224..0a8703725aca7 100644 --- a/test/legacy_test/spawn_runner_base.py +++ b/test/legacy_test/spawn_runner_base.py @@ -83,7 +83,5 @@ def check_dist_result_with_spawn_func(self, test_class, delta=1e-3): dist_loss, delta=delta, msg="The results of single-card execution and multi-card execution are inconsistent." - "signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n".format( - loss, dist_loss - ), + f"signal-card loss is:\n{loss}\nmulti-card average loss is:\n{dist_loss}\n", ) diff --git a/test/legacy_test/test_accuracy_op.py b/test/legacy_test/test_accuracy_op.py index 2acb9aa121e18..44c4cfa7c49ac 100755 --- a/test/legacy_test/test_accuracy_op.py +++ b/test/legacy_test/test_accuracy_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api def accuracy_wrapper(infer, indices, label): @@ -53,7 +54,7 @@ def init_dtype(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAccuracyOpFp16(TestAccuracyOp): @@ -61,7 +62,7 @@ def init_dtype(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(atol=1e-3) + self.check_output(atol=1e-3, check_pir=True) @unittest.skipIf( @@ -103,7 +104,7 @@ def init_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-2) + self.check_output_with_place(place, atol=1e-2, check_pir=True) class TestAccuracyOpError(unittest.TestCase): @@ -142,35 +143,38 @@ def test_value_errors(self): class TestAccuracyAPI1(unittest.TestCase): - def setUp(self): + def run_api(self, accuracy_api): with paddle_static_guard(): - self.predictions = paddle.static.data( - shape=[2, 5], name="predictions", dtype="float32" - ) - self.label = paddle.static.data( - shape=[2, 1], name="labels", dtype="int64" - ) - self.result = paddle.static.accuracy( - input=self.predictions, label=self.label, k=1 - ) - self.input_predictions = np.array( - [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]], - dtype="float32", - ) - self.input_labels = np.array([[2], [0]], dtype="int64") - self.expect_value = np.array([0.5], dtype='float32') + with paddle.static.program_guard(paddle.static.Program()): + self.predictions = paddle.static.data( + shape=[2, 5], name="predictions", dtype="float32" + ) + self.label = paddle.static.data( + shape=[2, 1], name="labels", dtype="int64" + ) + self.result = accuracy_api( + input=self.predictions, label=self.label, k=1 + ) + self.input_predictions = np.array( + [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]], + dtype="float32", + ) + self.input_labels = np.array([[2], [0]], dtype="int64") + self.expect_value = np.array([0.5], dtype='float32') + exe = paddle.static.Executor() + (result,) = exe.run( + feed={ + "predictions": self.input_predictions, + 'labels': self.input_labels, + }, + fetch_list=[self.result], + ) + self.assertEqual((result == self.expect_value).all(), True) + @test_with_pir_api def test_api(self): - with paddle_static_guard(): - exe = paddle.static.Executor() - (result,) = exe.run( - feed={ - "predictions": self.input_predictions, - 'labels': self.input_labels, - }, - fetch_list=[self.result.name], - ) - self.assertEqual((result == self.expect_value).all(), True) + self.run_api(accuracy_api=paddle.static.accuracy) + self.run_api(accuracy_api=paddle.metric.accuracy) class TestAccuracyAPI2(unittest.TestCase): diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 0872fdee435d2..1e7550f35dd80 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -27,6 +27,7 @@ from paddle import base, static from paddle.base import Program, core, program_guard from paddle.base.layer_helper import LayerHelper +from paddle.pir_utils import test_with_pir_api @contextmanager @@ -204,9 +205,10 @@ def init_dtype(self): class Test_Exp_Op_Fp16(unittest.TestCase): + @test_with_pir_api def test_api_fp16(self): with static_guard(): - with static.program_guard( + with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() ): np_x = np.array([[2, 3, 4], [7, 8, 9]]) @@ -387,10 +389,19 @@ def init_dtype(self): def if_enable_cinn(self): pass + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', max_relative_error=0.01, check_prim=True) + self.check_grad( + ['X'], + 'Out', + max_relative_error=0.01, + check_prim=True, + check_pir=True, + ) class TestSigmoid_Complex64(TestSigmoid): @@ -398,7 +409,8 @@ def init_dtype(self): self.dtype = np.complex64 def test_check_output(self): - self.check_output(check_prim=False) + with paddle.static.scope_guard(paddle.static.Scope()): + self.check_output(check_prim=False) def test_check_grad(self): self.check_grad( @@ -406,6 +418,7 @@ def test_check_grad(self): 'Out', max_relative_error=0.006, check_prim=False, + check_pir=True, ) @@ -414,11 +427,7 @@ def init_dtype(self): self.dtype = np.complex128 def test_check_grad(self): - self.check_grad( - ['X'], - 'Out', - check_prim=False, - ) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) class TestSigmoid_ZeroDim(TestSigmoid): @@ -459,12 +468,13 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - # elementwise_pow doesn't support bfloat16, skip check_prim here. - self.check_output_with_place(place, check_prim=True) + self.check_output_with_place(place, check_prim=True, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', check_prim=True) + self.check_grad_with_place( + place, ['X'], 'Out', check_prim=True, check_pir=True + ) ''' @@ -551,6 +561,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -616,7 +627,7 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', max_relative_error=0.008) + self.check_grad(['X'], 'Out', max_relative_error=0.008, check_pir=True) class TestLogSigmoidComplex64(TestLogSigmoid): @@ -645,6 +656,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -746,11 +758,19 @@ class TestTanh_Complex64(TestTanh): def init_dtype(self): self.dtype = np.complex64 + def test_check_output(self): + with paddle.static.scope_guard(paddle.static.Scope()): + self.check_output(check_pir=True) + class TestTanh_Complex128(TestTanh): def init_dtype(self): self.dtype = np.complex128 + def test_check_output(self): + with paddle.static.scope_guard(paddle.static.Scope()): + self.check_output(check_pir=True) + class TestTanh_ZeroDim(TestTanh): def init_shape(self): @@ -773,6 +793,7 @@ def setUp(self): def executed_api(self): self.tanh = F.tanh + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -840,11 +861,15 @@ def setUp(self): self.outputs = {'Out': out} self.convert_input_output() + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) + @test_with_pir_api def test_out_name(self): with static_guard(): with base.program_guard(base.Program()): @@ -903,10 +928,13 @@ def setUp(self): self.convert_input_output() + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestSinh_Complex64(TestSinh): @@ -933,10 +961,13 @@ def test_dygraph(self): z_expected = np.sinh(np_x) np.testing.assert_allclose(z, z_expected, rtol=1e-05) + @test_with_pir_api def test_api(self): with static_guard(): test_data_shape = [11, 17] - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input_x = np.random.uniform(0.1, 1, test_data_shape).astype( "float32" ) @@ -948,9 +979,9 @@ def test_api(self): pd_sinh_out = paddle.sinh(data_x) exe = base.Executor(place=base.CPUPlace()) - exe.run(base.default_startup_program()) + exe.run(paddle.static.default_startup_program()) (np_sinh_res,) = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"data_x": input_x}, fetch_list=[pd_sinh_out], ) @@ -1214,10 +1245,13 @@ def init_shape(self): def set_attrs(self): pass + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestHardShrink_threshold_negative(TestHardShrink): @@ -1244,6 +1278,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -1310,6 +1345,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -1383,10 +1419,13 @@ def setUp(self): self.attrs = {"lambda": threshold} + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestSoftshrink_ZeroDim(TestSoftshrink): @@ -1406,6 +1445,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -1814,6 +1854,9 @@ def init_shape(self): def if_enable_cinn(self): pass + def test_check_output(self): + self.check_output(check_pir=True) + # the gradient on floor, ceil, round is undefined. # we return zero as gradient, but the numpy return nan # The same reason with TestFloor @@ -1832,6 +1875,7 @@ def test_check_grad_for_prim(self): 'Out', check_prim=True, only_check_prim=True, + check_pir=True, ) @@ -1865,6 +1909,9 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return @@ -1872,10 +1919,14 @@ def test_check_grad(self): if self.dtype == np.complex64 or self.dtype == np.complex128: # Complex64 [GPU]: AssertionError: 0.0057843705 not less than or equal to 0.005 self.check_grad( - ['X'], 'Out', check_prim=False, max_relative_error=0.006 + ['X'], + 'Out', + check_prim=False, + max_relative_error=0.006, + check_pir=True, ) else: - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def if_enable_cinn(self): pass @@ -1925,10 +1976,13 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestTan_float32(TestTan): @@ -1969,6 +2023,7 @@ def test_dygraph_api(self): out_ref = np.tan(self.x_np) np.testing.assert_allclose(out_ref, out_test.numpy(), rtol=1e-05) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2016,13 +2071,16 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) -class TestAcos_Comple64(TestAcos): +class TestAcos_Complex64(TestAcos): def init_dtype(self): self.dtype = np.complex64 @@ -2062,14 +2120,22 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + @test_with_pir_api + def test_out_name(self): + # inherit from `TestParameter` + super().test_out_name() + + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return # TODO(ScottWong98): set `check_prim=False` when `fill_any_like` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: - self.check_grad(['X'], 'Out', check_prim=False) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) else: - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def if_enable_cinn(self): pass @@ -2113,10 +2179,13 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestAsin_Complex64(TestAsin): @@ -2157,14 +2226,19 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return if self.dtype == np.complex64: # Complex64[CPU]: AssertionError: 0.012431525 not less than or equal to 0.005 - self.check_grad(['X'], 'Out', max_relative_error=0.02) + self.check_grad( + ['X'], 'Out', max_relative_error=0.02, check_pir=True + ) else: - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestAcosh_Complex64(TestAcosh): @@ -2205,14 +2279,19 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return if self.dtype == np.complex64 or self.dtype == np.complex128: # Complex64 [CPU]: AssertionError: 0.006898686 not less than or equal to 0.005 - self.check_grad(['X'], 'Out', max_relative_error=0.007) + self.check_grad( + ['X'], 'Out', max_relative_error=0.007, check_pir=True + ) else: - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestAsinh_Complex64(TestAsinh): @@ -2253,10 +2332,13 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestAtanh_Complex64(TestAtanh): @@ -2292,6 +2374,9 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): pass @@ -2324,10 +2409,12 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) + self.check_grad( + ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True + ) def test_check_output(self): - self.check_output(check_prim=True, check_pir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def if_enable_cinn(self): pass @@ -2353,6 +2440,7 @@ def setUp(self): def executed_api(self): self.relu = F.relu + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2435,12 +2523,12 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestLeakyReluAlpha1(TestLeakyRelu): @@ -2477,6 +2565,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2649,6 +2738,7 @@ def setUp(self): self.rev_comp_rtol = 1e-8 self.rev_comp_atol = 1e-8 + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2718,10 +2808,13 @@ def setUp(self): self.convert_input_output() self.attrs = {'t_min': t_min, 't_max': t_max} + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) def ref_relu6(x, threshold=6.0): @@ -2755,7 +2848,7 @@ def init_shape(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestRelu6_ZeroDim(TestRelu6): @@ -2775,6 +2868,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2798,6 +2892,7 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_base_api(self): with static_guard(): with base.program_guard(base.Program()): @@ -2906,13 +3001,15 @@ def test_check_grad(self): if self.dtype not in [np.complex64, np.complex128] else False, only_check_prim=self.if_only_check_prim(), + check_pir=True, ) def test_check_output(self): self.check_output( check_prim=True if self.dtype not in [np.complex64, np.complex128] - else False + else False, + check_pir=True, ) @@ -2941,6 +3038,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2964,6 +3062,7 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_base_api(self): with static_guard(): with base.program_guard(base.Program()): @@ -3177,7 +3276,7 @@ def init_shape(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestCELU_ZeroDim(TestCELU): @@ -3200,6 +3299,7 @@ def setUp(self): def executed_api(self): self.celu = F.celu + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -3381,8 +3481,9 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) + @test_with_pir_api def test_api(self): with static_guard(): with paddle.static.program_guard( @@ -3430,6 +3531,7 @@ def test_api_int(self): np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) paddle.enable_static() + @test_with_pir_api def test_api_bf16(self): with static_guard(): with static.program_guard( @@ -3458,10 +3560,13 @@ def setUp(self): self.outputs = {'Out': out} self.convert_input_output() + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestLog10_ZeroDim(TestLog10): @@ -3480,21 +3585,23 @@ def test_api_int(self): np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) paddle.enable_static() + @test_with_pir_api def test_api_bf16(self): - with static_guard(): - with static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): - x = [[2, 3, 4], [7, 8, 9]] - x = paddle.to_tensor(x, dtype='bfloat16') - out = paddle.log10(x) - if core.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - (res,) = exe.run(fetch_list=[out]) + paddle.enable_static() + with static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = [[2, 3, 4], [7, 8, 9]] + x = paddle.to_tensor(x, dtype='bfloat16') + out = paddle.log10(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) class TestLog10API(unittest.TestCase): + @test_with_pir_api def test_api(self): with static_guard(): with paddle.static.program_guard( @@ -3642,10 +3749,10 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007, check_pir=True) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestSquare_ZeroDim(TestSquare): @@ -3677,11 +3784,13 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=0.5) + self.check_grad_with_place( + place, ['X'], 'Out', numeric_grad_delta=0.5, check_pir=True + ) class TestPow(TestActivation): @@ -3801,7 +3910,10 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) + + def test_check_output(self): + self.check_output(check_pir=True) class TestSTanhScaleA(TestSTanh): @@ -3848,6 +3960,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -3867,6 +3980,7 @@ def test_dygraph_api(self): for r in [out]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_base_api(self): with static_guard(): with base.program_guard(base.Program()): @@ -3938,10 +4052,13 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestSoftplus_Complex64(TestSoftplus): @@ -3949,7 +4066,7 @@ def init_dtype(self): self.dtype = np.complex64 def test_check_grad(self): - self.check_grad(['X'], 'Out', max_relative_error=0.06) + self.check_grad(['X'], 'Out', max_relative_error=0.06, check_pir=True) class TestSoftplus_Complex128(TestSoftplus): @@ -3987,11 +4104,13 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=0.05) + self.check_grad_with_place( + place, ['X'], 'Out', numeric_grad_delta=0.05, check_pir=True + ) class TestSoftplusAPI(unittest.TestCase): @@ -4007,6 +4126,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4062,6 +4182,11 @@ def setUp(self): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + if self.dtype == np.complex64 or self.dtype == np.complex128: + x = ( + np.random.uniform(-1, 1, self.shape) + + 1j * np.random.uniform(-1, 1, self.shape) + ).astype(self.dtype) out = ref_softsign(x) self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)} @@ -4071,10 +4196,23 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) + + +class TestSoftsign_Complex64(TestSoftsign): + def init_dtype(self): + self.dtype = np.complex64 + + +class TestSoftsign_Complex128(TestSoftsign): + def init_dtype(self): + self.dtype = np.complex128 class TestSoftsign_ZeroDim(TestSoftsign): @@ -4093,6 +4231,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4261,6 +4400,14 @@ def init_shape(self): def set_attrs(self): pass + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out', check_pir=True) + class TestHardSigmoidFP32(TestHardSigmoid): def set_attrs(self): @@ -4288,6 +4435,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4311,6 +4459,7 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_base_api(self): with static_guard(): with base.program_guard(base.Program()): @@ -4373,6 +4522,7 @@ def test_check_grad(self): self.check_grad( ['X'], 'Out', + check_pir=True, ) @@ -4392,6 +4542,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4415,6 +4566,7 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_base_api(self): with static_guard(): with base.program_guard(base.Program()): @@ -4628,7 +4780,9 @@ def test_check_grad(self): TestExpFp32_Prim, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class(TestExpm1) -create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestSigmoid, check_prim=True, enable_cinn=True, check_pir=True +) create_test_act_fp16_class( TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True ) @@ -4637,8 +4791,8 @@ def test_check_grad(self): TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True ) create_test_act_fp16_class(TestTanhshrink) -create_test_act_fp16_class(TestHardShrink) -create_test_act_fp16_class(TestSoftshrink) +create_test_act_fp16_class(TestHardShrink, check_pir=True) +create_test_act_fp16_class(TestSoftshrink, check_pir=True) create_test_act_fp16_class( TestSqrt, check_prim=True, @@ -4658,22 +4812,30 @@ def test_check_grad(self): ) create_test_act_fp16_class(TestCeil, grad_check=False, check_pir=True) create_test_act_fp16_class( - TestFloor, check_prim=True, grad_check=False, enable_cinn=True + TestFloor, + check_prim=True, + grad_check=False, + enable_cinn=True, + check_pir=True, ) -create_test_act_fp16_class(TestCos) -create_test_act_fp16_class(TestTan) +create_test_act_fp16_class(TestCos, check_pir=True) +create_test_act_fp16_class(TestTan, check_pir=True) create_test_act_fp16_class(TestCosh) -create_test_act_fp16_class(TestAcos) -create_test_act_fp16_class(TestSin) +create_test_act_fp16_class(TestAcos, check_pir=True) +create_test_act_fp16_class(TestSin, check_pir=True) create_test_act_fp16_class(TestSinh) -create_test_act_fp16_class(TestAsin) -create_test_act_fp16_class(TestAtan) -create_test_act_fp16_class(TestAcosh) -create_test_act_fp16_class(TestAsinh) -create_test_act_fp16_class(TestAtanh) -create_test_act_fp16_class(TestRound, grad_check=False) +create_test_act_fp16_class(TestAsin, check_pir=True) +create_test_act_fp16_class(TestAtan, check_pir=True) +create_test_act_fp16_class(TestAcosh, check_pir=True) +create_test_act_fp16_class(TestAsinh, check_pir=True) +create_test_act_fp16_class(TestAtanh, check_pir=True) +create_test_act_fp16_class(TestRound, grad_check=False, check_pir=True) create_test_act_fp16_class( - TestRelu, check_prim=True, enable_cinn=True, check_pir=True + TestRelu, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, ) create_test_act_fp16_class( TestGelu, @@ -4686,7 +4848,7 @@ def test_check_grad(self): cinn_rtol=1e-3, cinn_atol=1e-3, ) -create_test_act_fp16_class(TestBRelu) +create_test_act_fp16_class(TestBRelu, check_pir=True) create_test_act_fp16_class(TestRelu6) create_test_act_fp16_class(TestSoftRelu, check_dygraph=False) create_test_act_fp16_class(TestELU) @@ -4694,23 +4856,25 @@ def test_check_grad(self): create_test_act_fp16_class(TestReciprocal) create_test_act_fp16_class(TestLog, check_prim=True, check_pir=True) if core.is_compiled_with_rocm(): - create_test_act_fp16_class(TestLog2) + create_test_act_fp16_class(TestLog2, check_pir=True) else: - create_test_act_fp16_class(TestLog2) -create_test_act_fp16_class(TestLog10) + create_test_act_fp16_class(TestLog2, check_pir=True) +create_test_act_fp16_class(TestLog10, check_pir=True) create_test_act_fp16_class(TestLog1p) -create_test_act_fp16_class(TestSquare) +create_test_act_fp16_class(TestSquare, check_pir=True) create_test_act_fp16_class(TestPow, check_prim=True, check_prim_pir=True) create_test_act_fp16_class(TestPow_API) create_test_act_fp16_class(TestSTanh) -create_test_act_fp16_class(TestSoftplus) -create_test_act_fp16_class(TestSoftsign) +create_test_act_fp16_class(TestSoftplus, check_pir=True) +create_test_act_fp16_class(TestSoftsign, check_pir=True) create_test_act_fp16_class(TestThresholdedRelu) -create_test_act_fp16_class(TestHardSigmoid) +create_test_act_fp16_class(TestHardSigmoid, check_pir=True) create_test_act_fp16_class(TestSwish) -create_test_act_fp16_class(TestHardSwish, check_prim=True) +create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True) create_test_act_fp16_class(TestMish) -create_test_act_fp16_class(TestLeakyRelu, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True +) create_test_act_fp16_class( TestLeakyReluAlpha1, check_prim=True, enable_cinn=True ) @@ -4738,6 +4902,7 @@ def create_test_act_bf16_class( check_prim=False, enable_cinn=False, check_pir=False, + check_prim_pir=False, grad_atol=1e-2, **kwargs ): @@ -4770,6 +4935,7 @@ def test_check_output(self): atol=atol, check_prim=check_prim, check_pir=check_pir, + check_prim_pir=check_prim_pir, ) def test_check_grad(self): @@ -4782,6 +4948,7 @@ def test_check_grad(self): max_relative_error=grad_atol, check_prim=check_prim, check_pir=check_pir, + check_prim_pir=check_prim_pir, ) cls_name = "{}_{}".format(parent.__name__, "BF16OP") @@ -4794,13 +4961,13 @@ def test_check_grad(self): TestExpFp32_Prim, check_prim=True, check_prim_pir=True ) create_test_act_bf16_class(TestExpm1) -create_test_act_bf16_class(TestSigmoid, check_prim=True) +create_test_act_bf16_class(TestSigmoid, check_prim=True, check_pir=True) create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestLogSigmoid) create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestTanhshrink) -create_test_act_bf16_class(TestHardShrink) -create_test_act_bf16_class(TestSoftshrink) +create_test_act_bf16_class(TestHardShrink, check_pir=True) +create_test_act_bf16_class(TestSoftshrink, check_pir=True) create_test_act_bf16_class( TestSqrt, check_prim=True, check_pir=True, check_prim_pir=True ) @@ -4809,20 +4976,24 @@ def test_check_grad(self): ) create_test_act_bf16_class(TestAbs, check_prim=True, check_pir=True) create_test_act_bf16_class(TestCeil, grad_check=False, check_pir=True) -create_test_act_bf16_class(TestFloor, grad_check=False, check_prim=True) -create_test_act_bf16_class(TestCos) -create_test_act_bf16_class(TestTan) +create_test_act_bf16_class( + TestFloor, grad_check=False, check_prim=True, check_pir=True +) +create_test_act_bf16_class(TestCos, check_pir=True) +create_test_act_bf16_class(TestTan, check_pir=True) create_test_act_bf16_class(TestCosh) -create_test_act_bf16_class(TestAcos) -create_test_act_bf16_class(TestSin) +create_test_act_bf16_class(TestAcos, check_pir=True) +create_test_act_bf16_class(TestSin, check_pir=True) create_test_act_bf16_class(TestSinh) -create_test_act_bf16_class(TestAsin) -create_test_act_bf16_class(TestAtan) -create_test_act_bf16_class(TestAcosh) -create_test_act_bf16_class(TestAsinh) -create_test_act_bf16_class(TestAtanh) -create_test_act_bf16_class(TestRound, grad_check=False) -create_test_act_bf16_class(TestRelu, check_prim=True, check_pir=True) +create_test_act_bf16_class(TestAsin, check_pir=True) +create_test_act_bf16_class(TestAtan, check_pir=True) +create_test_act_bf16_class(TestAcosh, check_pir=True) +create_test_act_bf16_class(TestAsinh, check_pir=True) +create_test_act_bf16_class(TestAtanh, check_pir=True) +create_test_act_bf16_class(TestRound, grad_check=False, check_pir=True) +create_test_act_bf16_class( + TestRelu, check_prim=True, check_pir=True, check_prim_pir=True +) create_test_act_bf16_class( TestGelu, check_prim=True, @@ -4832,7 +5003,7 @@ def test_check_grad(self): cinn_rtol=1e-2, cinn_atol=1e-2, ) -create_test_act_bf16_class(TestBRelu) +create_test_act_bf16_class(TestBRelu, check_pir=True) create_test_act_bf16_class(TestRelu6) create_test_act_bf16_class(TestSoftRelu, check_dygraph=False) create_test_act_bf16_class(TestELU) @@ -4840,23 +5011,23 @@ def test_check_grad(self): create_test_act_bf16_class(TestReciprocal) create_test_act_bf16_class(TestLog, check_prim=True, check_pir=True) if core.is_compiled_with_rocm(): - create_test_act_bf16_class(TestLog2) + create_test_act_bf16_class(TestLog2, check_pir=True) else: - create_test_act_bf16_class(TestLog2) -create_test_act_bf16_class(TestLog10) + create_test_act_bf16_class(TestLog2, check_pir=True) +create_test_act_bf16_class(TestLog10, check_pir=True) create_test_act_bf16_class(TestLog1p) -create_test_act_bf16_class(TestSquare) +create_test_act_bf16_class(TestSquare, check_pir=True) create_test_act_bf16_class(TestPow, check_prim=True) create_test_act_bf16_class(TestPow_API) create_test_act_bf16_class(TestSTanh) -create_test_act_bf16_class(TestSoftplus) -create_test_act_bf16_class(TestSoftsign) +create_test_act_bf16_class(TestSoftplus, check_pir=True) +create_test_act_bf16_class(TestSoftsign, check_pir=True) create_test_act_bf16_class(TestThresholdedRelu) -create_test_act_bf16_class(TestHardSigmoid) +create_test_act_bf16_class(TestHardSigmoid, check_pir=True) create_test_act_bf16_class(TestSwish) -create_test_act_bf16_class(TestHardSwish, check_prim=True) +create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True) create_test_act_bf16_class(TestMish) -create_test_act_bf16_class(TestLeakyRelu, check_prim=True) +create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True) create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) diff --git a/test/legacy_test/test_adaptive_avg_pool2d.py b/test/legacy_test/test_adaptive_avg_pool2d.py index 9c6c0c96287a4..137e943fa5e89 100644 --- a/test/legacy_test/test_adaptive_avg_pool2d.py +++ b/test/legacy_test/test_adaptive_avg_pool2d.py @@ -19,8 +19,8 @@ from test_attribute_var import UnittestBase import paddle -from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api def adaptive_start_index(index, input_size, output_size): @@ -113,37 +113,45 @@ def setUp(self): x=self.x_np, output_size=[None, 3], pool_type="avg" ) + @test_with_pir_api def test_static_graph(self): for use_cuda in ( [False, True] if core.is_compiled_with_cuda() else [False] ): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x = paddle.static.data( - name="x", shape=[2, 3, 7, 7], dtype="float32" - ) - out_1 = paddle.nn.functional.adaptive_avg_pool2d( - x=x, output_size=[3, 3] - ) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() - out_2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=5) + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + name="x", shape=[2, 3, 7, 7], dtype="float32" + ) - out_3 = paddle.nn.functional.adaptive_avg_pool2d( - x=x, output_size=[2, 5] - ) + out_1 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[3, 3] + ) - out_4 = paddle.nn.functional.adaptive_avg_pool2d( - x=x, output_size=[3, 3], data_format="NHWC" - ) + out_2 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=5 + ) - out_5 = paddle.nn.functional.adaptive_avg_pool2d( - x=x, output_size=[None, 3] - ) + out_3 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[2, 5] + ) + + out_4 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[3, 3], data_format="NHWC" + ) + + out_5 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[None, 3] + ) exe = paddle.static.Executor(place=place) [res_1, res_2, res_3, res_4, res_5] = exe.run( - base.default_main_program(), + main_program, feed={"x": self.x_np}, fetch_list=[out_1, out_2, out_3, out_4, out_5], ) @@ -232,38 +240,47 @@ def setUp(self): x=self.x_np, output_size=[None, 3], pool_type="avg" ) + @test_with_pir_api def test_static_graph(self): for use_cuda in ( [False, True] if core.is_compiled_with_cuda() else [False] ): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x = paddle.static.data( - name="x", shape=[2, 3, 7, 7], dtype="float32" - ) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() - adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=[3, 3]) - out_1 = adaptive_avg_pool(x=x) + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + name="x", shape=[2, 3, 7, 7], dtype="float32" + ) - adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=5) - out_2 = adaptive_avg_pool(x=x) + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D( + output_size=[3, 3] + ) + out_1 = adaptive_avg_pool(x=x) - adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=[2, 5]) - out_3 = adaptive_avg_pool(x=x) + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=5) + out_2 = adaptive_avg_pool(x=x) - adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D( - output_size=[3, 3], data_format="NHWC" - ) - out_4 = adaptive_avg_pool(x=x) + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D( + output_size=[2, 5] + ) + out_3 = adaptive_avg_pool(x=x) - adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D( - output_size=[None, 3] - ) - out_5 = adaptive_avg_pool(x=x) + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D( + output_size=[3, 3], data_format="NHWC" + ) + out_4 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D( + output_size=[None, 3] + ) + out_5 = adaptive_avg_pool(x=x) exe = paddle.static.Executor(place=place) [res_1, res_2, res_3, res_4, res_5] = exe.run( - base.default_main_program(), + main_program, feed={"x": self.x_np}, fetch_list=[out_1, out_2, out_3, out_4, out_5], ) diff --git a/test/legacy_test/test_adaptive_max_pool2d.py b/test/legacy_test/test_adaptive_max_pool2d.py index 90f0f12e9303f..3a0579cbcc1fb 100644 --- a/test/legacy_test/test_adaptive_max_pool2d.py +++ b/test/legacy_test/test_adaptive_max_pool2d.py @@ -159,6 +159,61 @@ def test_static_graph(self): np.testing.assert_allclose(res_5, self.res_5_np) + def test_static_graph_return_mask(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 7, 7], dtype="float32" + ) + + out_1 = paddle.nn.functional.adaptive_max_pool2d( + x=x, output_size=[3, 3], return_mask=True + ) + + out_2 = paddle.nn.functional.adaptive_max_pool2d( + x=x, output_size=5, return_mask=True + ) + + out_3 = paddle.nn.functional.adaptive_max_pool2d( + x=x, output_size=[2, 5], return_mask=True + ) + + # out_4 = paddle.nn.functional.adaptive_max_pool2d( + # x=x, output_size=[3, 3], data_format="NHWC"), return_mask=True + + out_5 = paddle.nn.functional.adaptive_max_pool2d( + x=x, output_size=[None, 3], return_mask=True + ) + + exe = paddle.static.Executor(place=place) + [ + res_1, + mask_1, + res_2, + mask_2, + res_3, + mask_3, + res_5, + mask_5, + ] = exe.run( + base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + self.assertEqual(res_1.shape, mask_1.shape) + + self.assertEqual(res_2.shape, mask_2.shape) + + self.assertEqual(res_3.shape, mask_3.shape) + + # self.assertEqual(res_4.shape, mask_4.shape) + + self.assertEqual(res_5.shape, mask_5.shape) + def test_dynamic_graph(self): for use_cuda in ( [False, True] if core.is_compiled_with_cuda() else [False] diff --git a/test/legacy_test/test_allclose_op.py b/test/legacy_test/test_allclose_op.py index 54e78867e7443..cb76671284e2c 100644 --- a/test/legacy_test/test_allclose_op.py +++ b/test/legacy_test/test_allclose_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestAllcloseOp(OpTest): @@ -174,12 +175,13 @@ def test_equal_nan(): class TestAllcloseOpFp16(unittest.TestCase): + @test_with_pir_api def test_fp16(self): x_data = np.random.rand(10, 10).astype('float16') y_data = np.random.rand(10, 10).astype('float16') with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data(shape=[10, 10], name='x', dtype='float16') - y = paddle.static.data(shape=[10, 10], name='x', dtype='float16') + y = paddle.static.data(shape=[10, 10], name='y', dtype='float16') out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08) if core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) diff --git a/test/legacy_test/test_angle_op.py b/test/legacy_test/test_angle_op.py index c4ec247b1677b..9e2ee04531b09 100644 --- a/test/legacy_test/test_angle_op.py +++ b/test/legacy_test/test_angle_op.py @@ -20,6 +20,7 @@ import paddle from paddle import static from paddle.base import core, dygraph +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -49,7 +50,7 @@ def setUp(self): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -58,6 +59,7 @@ def test_check_grad(self): user_defined_grads=[ angle_grad(self.x, np.ones_like(self.x) / self.x.size) ], + check_pir=True, ) @@ -93,7 +95,7 @@ def setUp(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place( @@ -103,6 +105,7 @@ def test_check_grad(self): user_defined_grads=[ angle_grad(self.x, np.ones_like(self.x) / self.x.size) ], + check_pir=True, ) @@ -119,7 +122,7 @@ def setUp(self): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -128,6 +131,7 @@ def test_check_grad(self): user_defined_grads=[ angle_grad(self.x, np.ones_like(self.x) / self.x.size) ], + check_pir=True, ) @@ -142,6 +146,7 @@ def test_dygraph(self): out_np = paddle.angle(x).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): diff --git a/test/legacy_test/test_arange.py b/test/legacy_test/test_arange.py index e71402518696b..e901a060c3233 100644 --- a/test/legacy_test/test_arange.py +++ b/test/legacy_test/test_arange.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api from paddle.static import Program, program_guard @@ -138,9 +139,12 @@ def test_static_errors(self): class TestArangeAPI(unittest.TestCase): + @test_with_pir_api def test_out(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x1 = paddle.arange(0, 5, 1, 'float32') place = ( @@ -151,9 +155,9 @@ def test_out(self): exe = paddle.static.Executor(place) out = exe.run(fetch_list=[x1]) - expected_data = np.arange(0, 5, 1).astype(np.float32) - self.assertEqual((out == expected_data).all(), True) - self.assertListEqual(list(x1.shape), [5]) + expected_data = np.arange(0, 5, 1).astype(np.float32) + self.assertEqual((out == expected_data).all(), True) + self.assertListEqual(list(x1.shape), [5]) paddle.disable_static(place) diff --git a/test/legacy_test/test_assign_op.py b/test/legacy_test/test_assign_op.py index 270fe45ffe742..50f9e5e054869 100644 --- a/test/legacy_test/test_assign_op.py +++ b/test/legacy_test/test_assign_op.py @@ -279,7 +279,9 @@ class TestAssignOpErrorApi(unittest.TestCase): @test_with_pir_api def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # The type of input must be Variable or numpy.ndarray. x1 = base.create_lod_tensor( np.array([[-1]]), [[1]], base.CPUPlace() @@ -293,7 +295,9 @@ def test_errors(self): @test_with_pir_api def test_type_error(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = [paddle.randn([3, 3]), paddle.randn([3, 3])] # not support to assign list(var) self.assertRaises(TypeError, paddle.assign, x) diff --git a/test/legacy_test/test_assign_value_op.py b/test/legacy_test/test_assign_value_op.py index 3bdc97f4247f3..6ff4282d9fc55 100644 --- a/test/legacy_test/test_assign_value_op.py +++ b/test/legacy_test/test_assign_value_op.py @@ -108,7 +108,7 @@ def test_assign(self): def test_pir_assign(self): with paddle.pir_utils.IrGuard(): main_program = paddle.pir.Program() - with paddle.pir.core.program_guard(main_program): + with paddle.static.program_guard(main_program): x = paddle.zeros(shape=[1], dtype=self.dtype) paddle.assign(self.value, output=x) diff --git a/test/legacy_test/test_auto_parallel_reshard_mppp.py b/test/legacy_test/test_auto_parallel_reshard_mppp.py index c98f96fc30c6e..0ef5403a7d2c7 100644 --- a/test/legacy_test/test_auto_parallel_reshard_mppp.py +++ b/test/legacy_test/test_auto_parallel_reshard_mppp.py @@ -210,7 +210,7 @@ def check_initialization_for_mppp(dist_startup_prog, rank_id): def check_allgather(dist_main_program): - allgather_out = "x@RESHARD_0" + allgather_out = "c_allgather@RESHARD_0.tmp_0" # "x@RESHARD_0" var_result = False op_result = False vars = dist_main_program.global_block().vars diff --git a/test/legacy_test/test_batch_norm_op.py b/test/legacy_test/test_batch_norm_op.py index cfbb33c2a2933..151753685cd34 100644 --- a/test/legacy_test/test_batch_norm_op.py +++ b/test/legacy_test/test_batch_norm_op.py @@ -28,6 +28,7 @@ from paddle import base from paddle.base import Program, core, program_guard from paddle.base.framework import grad_var_name +from paddle.pir_utils import test_with_pir_api _set_use_system_allocator(True) @@ -359,6 +360,113 @@ def check_with_place(self, place, data_layout, dtype, shape): atol=atol, ) + def check_with_place_without_scale_and_bias( + self, place, data_layout, dtype, shape + ): + epsilon = 0.00001 + if len(shape) == 2: + x_shape = shape + c = x_shape[1] + else: + n, h, w, c = shape[0], shape[1], shape[2], shape[3] + if data_layout == "NHWC": + x_shape = [n, h, w, c] + elif data_layout == "NCHW": + x_shape = [n, c, h, w] + else: + raise ValueError("Unknown data layout.") + scale_shape = [c] + + if dtype == np.uint16: + x_val = np.random.random_sample(x_shape).astype(np.float32) + else: + x_val = np.random.random_sample(x_shape).astype(dtype) + # generate some negative values to test case with relu fused + x_val = x_val - 0.5 + scale_val = np.ones(scale_shape).astype(np.float32) + bias_val = np.zeros(scale_shape).astype(np.float32) + + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + + if dtype == np.uint16: + y_out = _reference_testing( + x_val, scale_val, bias_val, mean, variance, epsilon, data_layout + ).astype(np.float32) + y_out = convert_float_to_uint16(y_out) + else: + y_out = _reference_testing( + x_val, scale_val, bias_val, mean, variance, epsilon, data_layout + ).astype(dtype) + if self.fuse_with_relu: + y_out = np.maximum(y_out, 0) + + if dtype == np.uint16: + x_val = convert_float_to_uint16(x_val) + + exe = paddle.static.Executor(place) + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + x_ = paddle.static.data( + name='x_val', shape=x_shape, dtype='float32' + ) + mean_ = paddle.static.data( + name='mean', shape=scale_shape, dtype='float32' + ) + variance_ = paddle.static.data( + name='variance', shape=scale_shape, dtype='float32' + ) + y_tensor = paddle.nn.functional.batch_norm( + x_, + mean_, + variance_, + None, + None, + False, + data_format=data_layout, + ) + y_tensor = exe.run( + main, + feed={'x_val': x_val, 'mean': mean, 'variance': variance}, + fetch_list=[y_tensor], + )[0] + + # When op is called without Executor then + # MKL-DNN Tensor is returned. For NHWC data layout + # dims will be in NCHW order as it is MKL-DNN way + # of memory descripting. So we need to convert NCHW + # dims into NHWC. + if data_layout == "NHWC" and self.use_mkldnn: + # Create executor to have MKL-DNN cache + # cleared after NHWC unit test + place = core.CPUPlace() + exe = base.Executor(place) + dims = y_tensor.shape() + c = dims.pop(1) + dims.append(c) + y_tensor._set_dims(dims) + + # check inference result + atol = 1e-3 + if dtype == np.uint16: + y_tensor = convert_uint16_to_float(y_tensor) + y_out = convert_uint16_to_float(y_out) + atol = 1e-2 + self.__assert_close( + y_tensor, + y_out, + "inference output are different at " + + str(place) + + ", " + + data_layout + + ", " + + str(np.dtype(dtype)) + + str(np.array(y_tensor)) + + str(y_out), + atol=atol, + ) + def test_check_output(self): places = [core.CPUPlace()] if core.is_compiled_with_cuda(): @@ -370,6 +478,12 @@ def test_check_output(self): place, data_format, self.dtype, [2, 3, 4, 5] ) self.check_with_place(place, data_format, self.dtype, [2, 3]) + self.check_with_place_without_scale_and_bias( + place, data_format, self.dtype, [2, 3, 4, 5] + ) + self.check_with_place_without_scale_and_bias( + place, data_format, self.dtype, [2, 3] + ) def init_kernel_type(self): pass @@ -857,6 +971,7 @@ def compute(x, is_test, trainable_statistics): y2 = compute(x, True, True) np.testing.assert_allclose(y1, y2, rtol=1e-05) + @test_with_pir_api def test_static(self): places = [base.CPUPlace()] if core.is_compiled_with_cuda(): @@ -866,7 +981,9 @@ def test_static(self): shape = [4, 10, 16, 16] def compute(x_np, is_test, trainable_statistics): - with program_guard(Program(), Program()): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, @@ -876,7 +993,7 @@ def compute(x_np, is_test, trainable_statistics): name='x', shape=x_np.shape, dtype=x_np.dtype ) y = bn(x) - exe.run(base.default_startup_program()) + exe.run(startup_program) r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] return r @@ -887,8 +1004,11 @@ def compute(x_np, is_test, trainable_statistics): class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase): + @test_with_pir_api def test_reservespace(self): - with program_guard(Program(), Program()): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): paddle.enable_static() x = np.random.random(size=(3, 10, 3, 7)).astype('float32') x = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype) diff --git a/test/legacy_test/test_batch_norm_op_v2.py b/test/legacy_test/test_batch_norm_op_v2.py index b53bfb9e73373..4ae4c609ea1de 100644 --- a/test/legacy_test/test_batch_norm_op_v2.py +++ b/test/legacy_test/test_batch_norm_op_v2.py @@ -18,7 +18,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestBatchNorm(unittest.TestCase): @@ -191,25 +192,93 @@ def compute_v3(x, is_test, trainable_statistics): ), trainable_statistics=trainable_statistics, ) - y = bn(paddle.to_tensor(x)) - return y.numpy() + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v3_1(x, is_test, trainable_statistics): + with base.dygraph.guard(p): + bn = paddle.nn.BatchNorm( + shape[1], + is_test=is_test, + param_attr=False, + bias_attr=False, + trainable_statistics=trainable_statistics, + ) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v3_2(x, is_test, trainable_statistics): + with base.dygraph.guard(p): + bn = paddle.nn.BatchNorm( + shape[1], + is_test=is_test, + param_attr=False, + bias_attr=base.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0), + trainable=False, + ), + trainable_statistics=trainable_statistics, + ) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v3_3(x, is_test, trainable_statistics): + with base.dygraph.guard(p): + bn = paddle.nn.BatchNorm( + shape[1], + is_test=is_test, + param_attr=base.ParamAttr( + initializer=paddle.nn.initializer.Constant(1.0), + trainable=False, + ), + bias_attr=False, + trainable_statistics=trainable_statistics, + ) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() def compute_v4(x): with base.dygraph.guard(p): bn = paddle.nn.BatchNorm2D( shape[1], weight_attr=False, bias_attr=False ) - y = bn(paddle.to_tensor(x)) - return y.numpy() + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() x = np.random.randn(*shape).astype("float32") y1 = compute_v1(x, False, False) y2 = compute_v2(x) - y3 = compute_v3(x, False, False) - y4 = compute_v4(x) + y3, g3 = compute_v3(x, False, False) + y3_1, g3_1 = compute_v3_1(x, False, False) + y3_2, g3_2 = compute_v3_2(x, False, False) + y3_3, g3_3 = compute_v3_3(x, False, False) + y4, g4 = compute_v4(x) np.testing.assert_allclose(y1, y2, rtol=1e-05) np.testing.assert_allclose(y3, y4, rtol=1e-05) - + np.testing.assert_allclose(y3_1, y4, rtol=1e-05) + np.testing.assert_allclose(y3_2, y4, rtol=1e-05) + np.testing.assert_allclose(y3_3, y4, rtol=1e-05) + np.testing.assert_allclose(g3, g4, rtol=1e-05) + np.testing.assert_allclose(g3_1, g4, rtol=1e-05) + np.testing.assert_allclose(g3_2, g4, rtol=1e-05) + np.testing.assert_allclose(g3_3, g4, rtol=1e-05) + + @test_with_pir_api def test_static(self): places = [base.CPUPlace()] if core.is_compiled_with_cuda(): @@ -219,7 +288,9 @@ def test_static(self): shape = [4, 10, 16, 16] def compute_v1(x_np, is_test, trainable_statistics): - with program_guard(Program(), Program()): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with base.program_guard(main_program, startup_program): bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, @@ -229,18 +300,20 @@ def compute_v1(x_np, is_test, trainable_statistics): name='x', shape=x_np.shape, dtype=x_np.dtype ) y = bn(x) - exe.run(base.default_startup_program()) + exe.run(startup_program) r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] return r def compute_v2(x_np): - with program_guard(Program(), Program()): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with base.program_guard(main_program, startup_program): bn = paddle.nn.BatchNorm2D(shape[1]) x = paddle.static.data( name='x', shape=x_np.shape, dtype=x_np.dtype ) y = bn(x) - exe.run(base.default_startup_program()) + exe.run(startup_program) r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] return r diff --git a/test/legacy_test/test_batch_sampler.py b/test/legacy_test/test_batch_sampler.py index 72ea1577beb53..750a916b3b29a 100644 --- a/test/legacy_test/test_batch_sampler.py +++ b/test/legacy_test/test_batch_sampler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import unittest import numpy as np @@ -22,6 +23,7 @@ RandomSampler, Sampler, SequenceSampler, + SubsetRandomSampler, WeightedRandomSampler, ) @@ -110,6 +112,28 @@ def test_with_generator_num_samples(self): assert tuple(sorted(rets)) == tuple(range(0, 50)) +class TestSubsetRandomSampler(unittest.TestCase): + def test_main(self): + indices = list(range(100)) + random.shuffle(indices) + indices = indices[:30] + sampler = SubsetRandomSampler(indices) + assert len(sampler) == len(indices) + + hints = {i: 0 for i in indices} + for index in iter(sampler): + hints[index] += 1 + for h in hints.values(): + assert h == 1 + + def test_raise(self): + try: + sampler = SubsetRandomSampler([]) + self.assertTrue(False) + except ValueError: + self.assertTrue(True) + + class TestBatchSampler(unittest.TestCase): def setUp(self): self.num_samples = 1000 diff --git a/test/legacy_test/test_bicubic_interp_v2_op.py b/test/legacy_test/test_bicubic_interp_v2_op.py index 4df985cccde32..5955739a62c10 100644 --- a/test/legacy_test/test_bicubic_interp_v2_op.py +++ b/test/legacy_test/test_bicubic_interp_v2_op.py @@ -21,6 +21,7 @@ from paddle import base from paddle.base import Program, core, program_guard from paddle.nn.functional import interpolate +from paddle.pir_utils import test_with_pir_api def create_test_case0(self): @@ -314,10 +315,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -355,14 +356,11 @@ def init_test_case(self): class TestBicubicInterpOpFP16(TestBicubicInterpOp): def test_check_output(self): - self.check_output(atol=1e-3) + self.check_output(atol=1e-3, check_pir=True) def test_check_grad(self): self.check_grad( - ['X'], - 'Out', - in_place=True, - max_relative_error=1e-2, + ['X'], 'Out', in_place=True, max_relative_error=1e-2, check_pir=True ) def init_test_case(self): @@ -481,10 +479,10 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output_np)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -583,6 +581,7 @@ def init_test_case(self): class TestBicubicInterpOpAPI(unittest.TestCase): + @test_with_pir_api def test_case(self): np.random.seed(200) x_data = np.random.random((2, 3, 6, 6)).astype("float32") @@ -591,15 +590,15 @@ def test_case(self): actual_size_data = np.array([12, 12]).astype("int32") scale_data = np.array([2.0]).astype("float32") - prog = base.Program() - startup_prog = base.Program() + prog = paddle.static.Program() + startup_prog = paddle.static.Program() place = ( base.CUDAPlace(0) if base.core.is_compiled_with_cuda() else base.CPUPlace() ) - with base.program_guard(prog, startup_prog): + with paddle.static.program_guard(prog, startup_prog): x = paddle.static.data( name="x", shape=[2, 3, 6, 6], dtype="float32" ) @@ -641,9 +640,9 @@ def test_case(self): ) exe = base.Executor(place) - exe.run(base.default_startup_program()) + exe.run(startup_prog) results = exe.run( - base.default_main_program(), + prog, feed={ "x": x_data, "dim": dim_data, diff --git a/test/legacy_test/test_bilinear_interp_v2_op.py b/test/legacy_test/test_bilinear_interp_v2_op.py index ea7ac00498953..126cdaaf5da40 100755 --- a/test/legacy_test/test_bilinear_interp_v2_op.py +++ b/test/legacy_test/test_bilinear_interp_v2_op.py @@ -304,10 +304,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -386,10 +386,12 @@ def init_test_case(self): class TestBilinearInterpOpFP16(TestBilinearInterpOp): def test_check_output(self): - self.check_output(atol=1e-3) + self.check_output(atol=1e-3, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True, max_relative_error=1e-2) + self.check_grad( + ['X'], 'Out', in_place=True, max_relative_error=1e-2, check_pir=True + ) def init_test_case(self): create_test_case0(self) @@ -513,10 +515,12 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output_np)} def test_check_output(self): - self.check_output(atol=1e-2) + self.check_output(atol=1e-2, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True, max_relative_error=1e-2) + self.check_grad( + ['X'], 'Out', in_place=True, max_relative_error=1e-2, check_pir=True + ) def init_test_case(self): create_test_case0(self) @@ -650,7 +654,9 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) + self.check_output_with_place( + place=core.CPUPlace(), atol=1, check_pir=True + ) def init_test_case(self): self.interp_method = 'bilinear' @@ -824,10 +830,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): self.interp_method = 'bilinear' diff --git a/test/legacy_test/test_bitwise_op.py b/test/legacy_test/test_bitwise_op.py index 21a7abe812ad7..eb3ec980f05fd 100644 --- a/test/legacy_test/test_bitwise_op.py +++ b/test/legacy_test/test_bitwise_op.py @@ -150,7 +150,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass @@ -258,7 +258,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass @@ -363,7 +363,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass diff --git a/test/legacy_test/test_broadcast_to_op.py b/test/legacy_test/test_broadcast_to_op.py index 331addd30909b..5e2bb7c1ed161 100644 --- a/test/legacy_test/test_broadcast_to_op.py +++ b/test/legacy_test/test_broadcast_to_op.py @@ -18,7 +18,8 @@ import paddle from paddle import base -from paddle.base import Program, program_guard +from paddle.pir_utils import test_with_pir_api +from paddle.static import Program, program_guard paddle.enable_static() @@ -40,36 +41,42 @@ def test_errors(self): # Test python API class TestBroadcastToAPI(unittest.TestCase): + # TODO: add test_with_pir_api + # base.backward.calc_gradient maybe not support pir + # AttributeError: 'paddle.base.libpaddle.pir.Program' object has no attribute '_appending_grad_times' def test_api(self): - input = np.random.random([12, 14]).astype("float32") - x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") - - positive_2 = paddle.tensor.fill_constant([1], "int32", 12) - expand_shape = paddle.static.data( - name="expand_shape", - shape=[2], - dtype="int32", - ) - - out_1 = paddle.broadcast_to(x, shape=[12, 14]) - out_2 = paddle.broadcast_to(x, shape=[positive_2, 14]) - out_3 = paddle.broadcast_to(x, shape=expand_shape) - - g0 = base.backward.calc_gradient(out_2, x) - - exe = base.Executor(place=base.CPUPlace()) - res_1, res_2, res_3 = exe.run( - base.default_main_program(), - feed={ - "x": input, - "expand_shape": np.array([12, 14]).astype("int32"), - }, - fetch_list=[out_1, out_2, out_3], - ) - np.testing.assert_array_equal(res_1, np.tile(input, (1, 1))) - np.testing.assert_array_equal(res_2, np.tile(input, (1, 1))) - np.testing.assert_array_equal(res_3, np.tile(input, (1, 1))) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.random.random([12, 14]).astype("float32") + x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") + + positive_2 = paddle.tensor.fill_constant([1], "int32", 12) + expand_shape = paddle.static.data( + name="expand_shape", + shape=[2], + dtype="int32", + ) + + out_1 = paddle.broadcast_to(x, shape=[12, 14]) + out_2 = paddle.broadcast_to(x, shape=[positive_2, 14]) + out_3 = paddle.broadcast_to(x, shape=expand_shape) + + g0 = base.backward.calc_gradient(out_2, x) + + exe = base.Executor(place=base.CPUPlace()) + res_1, res_2, res_3 = exe.run( + feed={ + "x": input, + "expand_shape": np.array([12, 14]).astype("int32"), + }, + fetch_list=[out_1, out_2, out_3], + ) + np.testing.assert_array_equal(res_1, np.tile(input, (1, 1))) + np.testing.assert_array_equal(res_2, np.tile(input, (1, 1))) + np.testing.assert_array_equal(res_3, np.tile(input, (1, 1))) + @test_with_pir_api def test_api_fp16_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) diff --git a/test/legacy_test/test_cholesky_op.py b/test/legacy_test/test_cholesky_op.py index 034cbb87366fa..fba0c337cfee6 100644 --- a/test/legacy_test/test_cholesky_op.py +++ b/test/legacy_test/test_cholesky_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.base.backward import _as_list @skip_check_grad_ci( @@ -58,7 +59,7 @@ def setUp(self): self.outputs = {"Out": output_data} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): places = [base.CPUPlace()] @@ -79,7 +80,38 @@ def func(self, place): root_t = paddle.transpose(root, self.trans_dims) x = paddle.matmul(x=root, y=root_t) + 1e-05 out = paddle.cholesky(x, upper=self.attrs["upper"]) - grad_check(root, out, x_init=root_data, place=place) + # check input arguments + root = _as_list(root) + out = _as_list(out) + + for v in root: + v.stop_gradient = False + v.persistable = True + for u in out: + u.stop_gradient = False + u.persistable = True + + # init variable in startup program + scope = base.executor.global_scope() + exe = base.Executor(place) + exe.run(base.default_startup_program()) + + x_init = _as_list(root_data) + # init inputs if x_init is not None + if x_init: + if len(x_init) != len(root): + raise ValueError( + 'len(x_init) (=%d) is not the same' + ' as len(x) (= %d)' % (len(x_init), len(root)) + ) + # init variable in main program + for var, arr in zip(root, x_init): + assert var.shape == arr.shape + feeds = {k.name: v for k, v in zip(root, x_init)} + exe.run(prog, feed=feeds, scope=scope) + grad_check( + root, out, x_init=x_init, place=place, program=prog, scope=scope + ) def init_config(self): self._upper = True diff --git a/test/legacy_test/test_cholesky_solve_op.py b/test/legacy_test/test_cholesky_solve_op.py index c1c9e4e7400bc..76f3e2e2a64eb 100644 --- a/test/legacy_test/test_cholesky_solve_op.py +++ b/test/legacy_test/test_cholesky_solve_op.py @@ -139,7 +139,7 @@ def setUp(self): # check Op forward result def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) # check Op grad def test_check_grad_normal(self): diff --git a/test/legacy_test/test_class_center_sample_op.py b/test/legacy_test/test_class_center_sample_op.py index da903b5a16689..ae54d0deaf850 100644 --- a/test/legacy_test/test_class_center_sample_op.py +++ b/test/legacy_test/test_class_center_sample_op.py @@ -19,7 +19,8 @@ from op_test import OpTest, paddle_static_guard import paddle -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def class_center_sample_numpy(label, classes_list, num_samples): @@ -118,7 +119,9 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=['SampledLocalClassCenter']) + self.check_output( + no_check_set=['SampledLocalClassCenter'], check_pir=True + ) class TestClassCenterSampleOpINT32(TestClassCenterSampleOp): @@ -160,9 +163,12 @@ def test_static(self): for place in self.places: self.check_static_result(place=place) + @test_with_pir_api def check_static_result(self, place): with paddle_static_guard(): - with program_guard(Program(), Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): label_np = np.random.randint( 0, self.num_classes, (self.batch_size,), dtype=self.dtype ) @@ -185,7 +191,6 @@ def check_static_result(self, place): ) exe = paddle.base.Executor(place) [remapped_label_res, sampled_class_index_res] = exe.run( - paddle.base.default_main_program(), feed={'label': label_np}, fetch_list=[remapped_label, sampled_class_index], ) diff --git a/test/legacy_test/test_clip_op.py b/test/legacy_test/test_clip_op.py index 6ac2b0a17e7ca..1fad87de2d1dc 100644 --- a/test/legacy_test/test_clip_op.py +++ b/test/legacy_test/test_clip_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api class TestClipOp(OpTest): @@ -266,16 +267,11 @@ class TestClipAPI(unittest.TestCase): def _executed_api(self, x, min=None, max=None): return paddle.clip(x, min, max) + @test_with_pir_api def test_clip(self): paddle.enable_static() data_shape = [1, 9, 9, 4] data = np.random.random(data_shape).astype('float32') - images = paddle.static.data( - name='image', shape=data_shape, dtype='float32' - ) - min = paddle.static.data(name='min', shape=[1], dtype='float32') - max = paddle.static.data(name='max', shape=[1], dtype='float32') - place = ( base.CUDAPlace(0) if base.core.is_compiled_with_cuda() @@ -283,23 +279,31 @@ def test_clip(self): ) exe = base.Executor(place) - out_1 = self._executed_api(images, min=min, max=max) - out_2 = self._executed_api(images, min=0.2, max=0.9) - out_3 = self._executed_api(images, min=0.3) - out_4 = self._executed_api(images, max=0.7) - out_5 = self._executed_api(images, min=min) - out_6 = self._executed_api(images, max=max) - out_7 = self._executed_api(images, max=-1.0) - out_8 = self._executed_api(images) - out_9 = self._executed_api( - paddle.cast(images, 'float64'), min=0.2, max=0.9 - ) - out_10 = self._executed_api( - paddle.cast(images * 10, 'int32'), min=2, max=8 - ) - out_11 = self._executed_api( - paddle.cast(images * 10, 'int64'), min=2, max=8 - ) + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + images = paddle.static.data( + name='image', shape=data_shape, dtype='float32' + ) + min = paddle.static.data(name='min', shape=[1], dtype='float32') + max = paddle.static.data(name='max', shape=[1], dtype='float32') + out_1 = self._executed_api(images, min=min, max=max) + out_2 = self._executed_api(images, min=0.2, max=0.9) + out_3 = self._executed_api(images, min=0.3) + out_4 = self._executed_api(images, max=0.7) + out_5 = self._executed_api(images, min=min) + out_6 = self._executed_api(images, max=max) + out_7 = self._executed_api(images, max=-1.0) + out_8 = self._executed_api(images) + out_9 = self._executed_api( + paddle.cast(images, 'float64'), min=0.2, max=0.9 + ) + out_10 = self._executed_api( + paddle.cast(images * 10, 'int32'), min=2, max=8 + ) + out_11 = self._executed_api( + paddle.cast(images * 10, 'int64'), min=2, max=8 + ) ( res1, @@ -314,7 +318,7 @@ def test_clip(self): res10, res11, ) = exe.run( - base.default_main_program(), + main, feed={ "image": data, "min": np.array([0.2]).astype('float32'), @@ -430,6 +434,7 @@ def test_errors(self): class TestClipOpFp16(unittest.TestCase): + @test_with_pir_api def test_fp16(self): paddle.enable_static() data_shape = [1, 9, 9, 4] diff --git a/test/legacy_test/test_collective_api_base.py b/test/legacy_test/test_collective_api_base.py index 669910ee0283a..8f6a382297a1f 100644 --- a/test/legacy_test/test_collective_api_base.py +++ b/test/legacy_test/test_collective_api_base.py @@ -189,7 +189,8 @@ def runtime_main(test_class, col_type): args["reduce_type"] = os.getenv("REDUCE_TYPE") args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0"))) args["dynamic_static_unified_comm"] = bool( - int(os.getenv("FLAGS_dynamic_static_unified_comm", "0")) + os.getenv("FLAGS_dynamic_static_unified_comm", "false").lower() + == "true" ) model.run_trainer(args) @@ -609,16 +610,23 @@ def convertbf16(origin): send_ptr2 = send_ptr2 + global_expert_count2[idx] result1 = [] result2 = [] + + def is_empyt_list(x): + if isinstance(x, list) and len(x) == 0: + return True + return False + for i in range(tot_expert): for arr in output1[i]: - if arr == []: + if is_empyt_list(arr): continue result1.append(arr) for i in range(tot_expert): for arr in output2[i]: - if arr == []: + if is_empyt_list(arr): continue result2.append(arr) + if result1 == []: output1 = np.array([]) else: diff --git a/test/legacy_test/test_compare_op.py b/test/legacy_test/test_compare_op.py index 19b2b4a2406e2..91dce088ef88e 100755 --- a/test/legacy_test/test_compare_op.py +++ b/test/legacy_test/test_compare_op.py @@ -20,7 +20,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def create_test_class(op_type, typename, callback, check_pir=False): @@ -39,7 +40,9 @@ def test_output(self): def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32') y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32') a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16') @@ -58,14 +61,14 @@ def test_errors(self): if _type_name == 'float16' and (not core.is_compiled_with_cuda()): continue - create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) - create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) - create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('less_than', _type_name, lambda _a, _b: _a < _b, True) + create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b, True) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b, True) create_test_class( 'greater_equal', _type_name, lambda _a, _b: _a >= _b, True ) create_test_class('equal', _type_name, lambda _a, _b: _a == _b, True) - create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) + create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b, True) def create_paddle_case(op_type, callback): @@ -79,9 +82,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.place = paddle.CUDAPlace(0) + @test_with_pir_api def test_api(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[4], dtype='int64') y = paddle.static.data(name='y', shape=[4], dtype='int64') op = eval("paddle.%s" % (self.op_type)) @@ -93,10 +99,13 @@ def test_api(self): ) self.assertEqual((res == self.real_result).all(), True) + @test_with_pir_api def test_api_float(self): if self.op_type == "equal": paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[4], dtype='int64') y = paddle.static.data(name='y', shape=[], dtype='int64') op = eval("paddle.%s" % (self.op_type)) @@ -290,9 +299,12 @@ def test_dynamic_api_bool(self): self.assertEqual((out.numpy() == self.real_result).all(), True) paddle.enable_static() + @test_with_pir_api def test_broadcast_api_1(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data( name='x', shape=[1, 2, 1, 3], dtype='int32' ) @@ -308,9 +320,12 @@ def test_broadcast_api_1(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_broadcast_api_2(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='int32') y = paddle.static.data( name='y', shape=[1, 2, 1, 3], dtype='int32' @@ -326,9 +341,12 @@ def test_broadcast_api_2(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_broadcast_api_3(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[5], dtype='int32') y = paddle.static.data(name='y', shape=[3, 1], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -342,9 +360,12 @@ def test_broadcast_api_3(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_zero_dim_api_1(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-3, 3, shape=[], dtype='int32') y = paddle.randint(-3, 3, shape=[], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -358,9 +379,12 @@ def test_zero_dim_api_1(self): real_result = callback(x_np, y_np) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_zero_dim_api_2(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') y = paddle.randint(-3, 3, shape=[], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -374,9 +398,12 @@ def test_zero_dim_api_2(self): real_result = callback(x_np, y_np) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_zero_dim_api_3(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-3, 3, shape=[], dtype='int32') y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -390,9 +417,12 @@ def test_zero_dim_api_3(self): real_result = callback(x_np, y_np) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_bool_api_4(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') y = paddle.static.data(name='y', shape=[3, 1], dtype='bool') op = eval("paddle.%s" % (self.op_type)) @@ -406,9 +436,12 @@ def test_bool_api_4(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_bool_broadcast_api_4(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') y = paddle.static.data(name='y', shape=[1], dtype='bool') op = eval("paddle.%s" % (self.op_type)) @@ -424,7 +457,9 @@ def test_bool_broadcast_api_4(self): def test_attr_name(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 4], dtype='int32') y = paddle.static.data(name='y', shape=[-1, 4], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -469,18 +504,20 @@ def test_check_output(self): globals()[cls_name] = TestCompareOpBF16Op -create_bf16_case('less_than', lambda _a, _b: _a < _b) -create_bf16_case('less_equal', lambda _a, _b: _a <= _b) -create_bf16_case('greater_than', lambda _a, _b: _a > _b) +create_bf16_case('less_than', lambda _a, _b: _a < _b, True) +create_bf16_case('less_equal', lambda _a, _b: _a <= _b, True) +create_bf16_case('greater_than', lambda _a, _b: _a > _b, True) create_bf16_case('greater_equal', lambda _a, _b: _a >= _b, True) create_bf16_case('equal', lambda _a, _b: _a == _b, True) -create_bf16_case('not_equal', lambda _a, _b: _a != _b) +create_bf16_case('not_equal', lambda _a, _b: _a != _b, True) class TestCompareOpError(unittest.TestCase): def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # The input x and y of compare_op must be Variable. x = paddle.static.data(name='x', shape=[-1, 1], dtype="float32") y = base.create_lod_tensor( @@ -490,9 +527,12 @@ def test_errors(self): class API_TestElementwise_Equal(unittest.TestCase): + @test_with_pir_api def test_api(self): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): label = paddle.assign(np.array([3, 3], dtype="int32")) limit = paddle.assign(np.array([3, 2], dtype="int32")) out = paddle.equal(x=label, y=limit) @@ -501,7 +541,9 @@ def test_api(self): (res,) = exe.run(fetch_list=[out]) self.assertEqual((res == np.array([True, False])).all(), True) - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): label = paddle.assign(np.array([3, 3], dtype="int32")) limit = paddle.assign(np.array([3, 3], dtype="int32")) out = paddle.equal(x=label, y=limit) @@ -510,9 +552,12 @@ def test_api(self): (res,) = exe.run(fetch_list=[out]) self.assertEqual((res == np.array([True, True])).all(), True) + @test_with_pir_api def test_api_fp16(self): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): label = paddle.to_tensor([3, 3], dtype="float16") limit = paddle.to_tensor([3, 2], dtype="float16") out = paddle.equal(x=label, y=limit) @@ -524,6 +569,7 @@ def test_api_fp16(self): class API_TestElementwise_Greater_Than(unittest.TestCase): + @test_with_pir_api def test_api_fp16(self): paddle.enable_static() with paddle.static.program_guard( @@ -540,17 +586,21 @@ def test_api_fp16(self): class TestCompareOpPlace(unittest.TestCase): + @test_with_pir_api def test_place_1(self): paddle.enable_static() place = paddle.CPUPlace() if core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) - label = paddle.assign(np.array([3, 3], dtype="int32")) - limit = paddle.assign(np.array([3, 2], dtype="int32")) - out = paddle.less_than(label, limit) - exe = base.Executor(place) - (res,) = exe.run(fetch_list=[out]) - self.assertEqual((res == np.array([False, False])).all(), True) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + label = paddle.assign(np.array([3, 3], dtype="int32")) + limit = paddle.assign(np.array([3, 2], dtype="int32")) + out = paddle.less_than(label, limit) + exe = base.Executor(place) + (res,) = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([False, False])).all(), True) def test_place_2(self): place = paddle.CPUPlace() diff --git a/test/legacy_test/test_compare_reduce_op.py b/test/legacy_test/test_compare_reduce_op.py index e281407c242b0..fdd08b2990cfe 100644 --- a/test/legacy_test/test_compare_reduce_op.py +++ b/test/legacy_test/test_compare_reduce_op.py @@ -32,7 +32,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'not_equal_all') Cls.__name__ = cls_name @@ -51,7 +51,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'not_shape_equal_all') Cls.__name__ = cls_name @@ -69,7 +69,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'equal_all') Cls.__name__ = cls_name @@ -89,7 +89,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'equal_all') Cls.__name__ = cls_name diff --git a/test/legacy_test/test_complex_op.py b/test/legacy_test/test_complex_op.py index 151ecfbdb6524..e0388b0c560d3 100644 --- a/test/legacy_test/test_complex_op.py +++ b/test/legacy_test/test_complex_op.py @@ -20,6 +20,7 @@ import paddle from paddle import static from paddle.base import dygraph +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -45,12 +46,13 @@ def setUp(self): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['X', 'Y'], 'Out', + check_pir=True, ) def test_check_grad_ignore_x(self): @@ -58,6 +60,7 @@ def test_check_grad_ignore_x(self): ['Y'], 'Out', no_grad_set=set('X'), + check_pir=True, ) def test_check_grad_ignore_y(self): @@ -65,6 +68,7 @@ def test_check_grad_ignore_y(self): ['X'], 'Out', no_grad_set=set('Y'), + check_pir=True, ) @@ -102,6 +106,7 @@ def test_dygraph(self): out_np = paddle.complex(x, y).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): diff --git a/test/legacy_test/test_complex_view_op.py b/test/legacy_test/test_complex_view_op.py index b747804ca65c5..c529e3950a9fb 100644 --- a/test/legacy_test/test_complex_view_op.py +++ b/test/legacy_test/test_complex_view_op.py @@ -20,6 +20,7 @@ import paddle from paddle import static from paddle.base import dygraph +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -43,7 +44,7 @@ def setUp(self): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -64,7 +65,7 @@ def setUp(self): self.python_api = paddle.as_real def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -84,6 +85,7 @@ def test_dygraph(self): out_np = paddle.as_complex(x).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): @@ -107,6 +109,7 @@ def test_dygraph(self): out_np = paddle.as_real(x).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): diff --git a/test/legacy_test/test_concat_op.py b/test/legacy_test/test_concat_op.py index efa87c3609570..eac5df80f8379 100644 --- a/test/legacy_test/test_concat_op.py +++ b/test/legacy_test/test_concat_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api class TestConcatOp(OpTest): @@ -591,76 +592,88 @@ def test_input_same_dtype(): class TestConcatAPI(unittest.TestCase): + @test_with_pir_api def test_base_api(self): paddle.enable_static() - x_1 = paddle.static.data( - shape=[None, 1, 4, 5], dtype='int32', name='x_1' - ) - paddle.concat([x_1, x_1], 0) - - input_2 = np.random.random([2, 1, 4, 5]).astype("int32") - input_3 = np.random.random([2, 2, 4, 5]).astype("int32") - x_2 = paddle.static.data(shape=[2, 1, 4, 5], dtype='int32', name='x_2') - x_3 = paddle.static.data(shape=[2, 2, 4, 5], dtype='int32', name='x_3') - positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1) - positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) - out_1 = paddle.concat([x_2, x_3], axis=1) - out_2 = paddle.concat([x_2, x_3], axis=positive_1_int32) - out_3 = paddle.concat([x_2, x_3], axis=positive_1_int64) - - exe = base.Executor(place=base.CPUPlace()) - [res_1, res_2, res_3] = exe.run( - base.default_main_program(), - feed={"x_1": input_2, "x_2": input_2, "x_3": input_3}, - fetch_list=[out_1, out_2, out_3], - ) - np.testing.assert_array_equal( - res_1, np.concatenate((input_2, input_3), axis=1) - ) - np.testing.assert_array_equal( - res_2, np.concatenate((input_2, input_3), axis=1) - ) - np.testing.assert_array_equal( - res_3, np.concatenate((input_2, input_3), axis=1) - ) + with paddle.static.program_guard(paddle.static.Program()): + x_1 = paddle.static.data( + shape=[None, 1, 4, 5], dtype='int32', name='x_1' + ) + paddle.concat([x_1, x_1], 0) + input_2 = np.random.random([2, 1, 4, 5]).astype("int32") + input_3 = np.random.random([2, 2, 4, 5]).astype("int32") + x_2 = paddle.static.data( + shape=[2, 1, 4, 5], dtype='int32', name='x_2' + ) + x_3 = paddle.static.data( + shape=[2, 2, 4, 5], dtype='int32', name='x_3' + ) + positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1) + positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) + out_1 = paddle.concat([x_2, x_3], axis=1) + out_2 = paddle.concat([x_2, x_3], axis=positive_1_int32) + out_3 = paddle.concat([x_2, x_3], axis=positive_1_int64) + + exe = base.Executor(place=base.CPUPlace()) + [res_1, res_2, res_3] = exe.run( + paddle.static.default_main_program(), + feed={"x_1": input_2, "x_2": input_2, "x_3": input_3}, + fetch_list=[out_1, out_2, out_3], + ) + np.testing.assert_array_equal( + res_1, np.concatenate((input_2, input_3), axis=1) + ) + np.testing.assert_array_equal( + res_2, np.concatenate((input_2, input_3), axis=1) + ) + np.testing.assert_array_equal( + res_3, np.concatenate((input_2, input_3), axis=1) + ) + + @test_with_pir_api def test_api(self): paddle.enable_static() - x_1 = paddle.static.data( - shape=[None, 1, 4, 5], dtype='int32', name='x_1' - ) - paddle.concat([x_1, x_1], 0) - - input_2 = np.random.random([2, 1, 4, 5]).astype("int32") - input_3 = np.random.random([2, 2, 4, 5]).astype("int32") - x_2 = paddle.static.data(shape=[2, 1, 4, 5], dtype='int32', name='x_2') - x_3 = paddle.static.data(shape=[2, 2, 4, 5], dtype='int32', name='x_3') - positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1) - positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) - negative_int64 = paddle.tensor.fill_constant([1], "int64", -3) - out_1 = paddle.concat(x=[x_2, x_3], axis=1) - out_2 = paddle.concat(x=[x_2, x_3], axis=positive_1_int32) - out_3 = paddle.concat(x=[x_2, x_3], axis=positive_1_int64) - out_4 = paddle.concat(x=[x_2, x_3], axis=negative_int64) - - exe = paddle.static.Executor(place=paddle.CPUPlace()) - [res_1, res_2, res_3, res_4] = exe.run( - paddle.static.default_main_program(), - feed={"x_1": input_2, "x_2": input_2, "x_3": input_3}, - fetch_list=[out_1, out_2, out_3, out_4], - ) - np.testing.assert_array_equal( - res_1, np.concatenate((input_2, input_3), axis=1) - ) - np.testing.assert_array_equal( - res_2, np.concatenate((input_2, input_3), axis=1) - ) - np.testing.assert_array_equal( - res_3, np.concatenate((input_2, input_3), axis=1) - ) - np.testing.assert_array_equal( - res_4, np.concatenate((input_2, input_3), axis=1) - ) + with paddle.static.program_guard(paddle.static.Program()): + x_1 = paddle.static.data( + shape=[None, 1, 4, 5], dtype='int32', name='x_1' + ) + paddle.concat([x_1, x_1], 0) + + input_2 = np.random.random([2, 1, 4, 5]).astype("int32") + input_3 = np.random.random([2, 2, 4, 5]).astype("int32") + x_2 = paddle.static.data( + shape=[2, 1, 4, 5], dtype='int32', name='x_2' + ) + x_3 = paddle.static.data( + shape=[2, 2, 4, 5], dtype='int32', name='x_3' + ) + positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1) + positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) + negative_int64 = paddle.tensor.fill_constant([1], "int64", -3) + out_1 = paddle.concat(x=[x_2, x_3], axis=1) + out_2 = paddle.concat(x=[x_2, x_3], axis=positive_1_int32) + out_3 = paddle.concat(x=[x_2, x_3], axis=positive_1_int64) + out_4 = paddle.concat(x=[x_2, x_3], axis=negative_int64) + + exe = paddle.static.Executor(place=paddle.CPUPlace()) + [res_1, res_2, res_3, res_4] = exe.run( + paddle.static.default_main_program(), + feed={"x_1": input_2, "x_2": input_2, "x_3": input_3}, + fetch_list=[out_1, out_2, out_3, out_4], + ) + np.testing.assert_array_equal( + res_1, np.concatenate((input_2, input_3), axis=1) + ) + np.testing.assert_array_equal( + res_2, np.concatenate((input_2, input_3), axis=1) + ) + np.testing.assert_array_equal( + res_3, np.concatenate((input_2, input_3), axis=1) + ) + np.testing.assert_array_equal( + res_4, np.concatenate((input_2, input_3), axis=1) + ) def test_imperative(self): in1 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -729,8 +742,8 @@ def setUp(self): def set_program(self, use_base_api): paddle.enable_static() if use_base_api: - self.program = base.Program() - with base.program_guard(self.program): + self.program = paddle.static.Program() + with paddle.static.program_guard(self.program): input = paddle.assign(self.x) tensor_array = paddle.tensor.create_array(dtype='float32') zero = paddle.tensor.fill_constant( diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index cec7664ae6cb6..76467328f7725 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -19,7 +19,7 @@ from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs sys.path.append("../dygraph_to_static") -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import compare_legacy_with_pir import paddle from paddle import base @@ -31,7 +31,7 @@ class TestCondInputOutput(unittest.TestCase): - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_return_single_var(self): """ pseudocode: @@ -78,7 +78,7 @@ def false_func(): np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05 ) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_return_0d_tensor(self): """ pseudocode: @@ -116,7 +116,7 @@ def false_func(): np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05) self.assertEqual(ret.shape, ()) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_0d_tensor_as_cond(self): """ pseudocode: @@ -217,7 +217,7 @@ def test_0d_tensor_dygraph(self): ) self.assertEqual(a.grad.shape, []) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_return_var_tuple(self): """ pseudocode: @@ -265,7 +265,7 @@ def false_func(): np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05 ) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_pass_and_modify_var(self): """ pseudocode: @@ -356,7 +356,7 @@ def false_func(): self.assertIsNone(out2) self.assertIsNone(out3) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_wrong_structure_exception(self): """ test returning different number of tensors cannot merge into output diff --git a/test/legacy_test/test_conv1d_layer.py b/test/legacy_test/test_conv1d_layer.py index e284c25568abf..48bd182c486b9 100644 --- a/test/legacy_test/test_conv1d_layer.py +++ b/test/legacy_test/test_conv1d_layer.py @@ -20,6 +20,7 @@ import paddle.base.dygraph as dg import paddle.nn.functional as F from paddle import base, nn +from paddle.pir_utils import test_with_pir_api class Conv1DTestCase(unittest.TestCase): @@ -99,13 +100,16 @@ def functional(self, place): w_var = paddle.static.data( "weight", self.weight_shape, dtype=self.dtype ) - b_var = paddle.static.data( - "bias", (self.num_filters,), dtype=self.dtype - ) + if not self.no_bias: + b_var = paddle.static.data( + "bias", (self.num_filters,), dtype=self.dtype + ) + else: + b_var = None y_var = F.conv1d( x_var, w_var, - b_var if not self.no_bias else None, + b_var, padding=self.padding, stride=self.stride, dilation=self.dilation, @@ -140,6 +144,7 @@ def paddle_nn_layer(self): y_np = y_var.numpy() return y_np + @test_with_pir_api def _test_equivalence(self, place): result1 = self.functional(place) with dg.guard(place): diff --git a/test/legacy_test/test_conv2d_layer.py b/test/legacy_test/test_conv2d_layer.py index 4290a7352afed..a347472bd2a87 100644 --- a/test/legacy_test/test_conv2d_layer.py +++ b/test/legacy_test/test_conv2d_layer.py @@ -218,8 +218,53 @@ def paddle_nn_layer(self): t1 = x_var.gradient() return y_np, t1 + def run_Conv2D_static(self, place): + paddle.seed(2023) + main = base.Program() + start = base.Program() + with base.unique_name.guard(): + with base.program_guard(main, start): + x_var = paddle.static.data( + "input", self.input.shape, dtype=self.dtype + ) + conv = nn.Conv2D( + self.num_channels, + self.num_filters, + self.filter_size, + padding=self.padding, + padding_mode=self.padding_mode, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + data_format=self.data_format, + ) + y_var = conv(x_var) + feed_dict = {"input": self.input} + exe = base.Executor(place) + exe.run(start) + (y_np,) = exe.run(main, feed=feed_dict, fetch_list=[y_var]) + return y_np + + def run_Conv2D_dygraph(self): + paddle.seed(2023) + x_var = paddle.to_tensor(self.input) + x_var.stop_gradient = False + conv = nn.Conv2D( + self.num_channels, + self.num_filters, + self.filter_size, + padding=self.padding, + padding_mode=self.padding_mode, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + data_format=self.data_format, + ) + y_var = conv(x_var) + y_np = y_var.numpy() + return y_np + def _test_equivalence(self, place): - place = base.CPUPlace() result1 = self.base_layer(place) result2 = self.functional(place) with dg.guard(place): @@ -227,13 +272,22 @@ def _test_equivalence(self, place): np.testing.assert_array_almost_equal(result1, result2) np.testing.assert_array_almost_equal(result2, result3) + def _test_equivalence_in_pir(self, place): + with paddle.pir_utils.IrGuard(): + result1 = self.run_Conv2D_static(place) + with dg.guard(place): + result2 = self.run_Conv2D_dygraph() + np.testing.assert_array_almost_equal(result1, result2) + def runTest(self): place = base.CPUPlace() self._test_equivalence(place) + self._test_equivalence_in_pir(place) if base.core.is_compiled_with_cuda(): place = base.CUDAPlace(0) self._test_equivalence(place) + self._test_equivalence_in_pir(place) class Conv2DErrorTestCase(Conv2DTestCase): diff --git a/test/legacy_test/test_conv2d_transpose_layer.py b/test/legacy_test/test_conv2d_transpose_layer.py index 78634d5124929..6f4a5bb3868c7 100644 --- a/test/legacy_test/test_conv2d_transpose_layer.py +++ b/test/legacy_test/test_conv2d_transpose_layer.py @@ -143,9 +143,12 @@ def functional(self, place): w_var = paddle.static.data( "weight", self.weight_shape, dtype=self.dtype ) - b_var = paddle.static.data( - "bias", (self.num_filters,), dtype=self.dtype - ) + if not self.no_bias: + b_var = paddle.static.data( + "bias", (self.num_filters,), dtype=self.dtype + ) + else: + b_var = None if self.output_padding != 0: output_size = None @@ -155,7 +158,7 @@ def functional(self, place): y_var = F.conv2d_transpose( x_var, w_var, - None if self.no_bias else b_var, + b_var, output_size=output_size, padding=self.padding, output_padding=self.output_padding, @@ -199,8 +202,6 @@ def paddle_nn_layer(self): return y_np def _test_equivalence(self, place): - place = base.CPUPlace() - result1 = self.base_layer(place) result2 = self.functional(place) @@ -210,13 +211,18 @@ def _test_equivalence(self, place): np.testing.assert_array_almost_equal(result1, result2) np.testing.assert_array_almost_equal(result2, result3) + def _test_pir_equivalence(self, place): + with paddle.pir_utils.IrGuard(): + result1 = self.functional(place) + with dg.guard(place): + result2 = self.paddle_nn_layer() + + np.testing.assert_array_almost_equal(result1, result2) + def runTest(self): place = base.CPUPlace() self._test_equivalence(place) - - if base.core.is_compiled_with_cuda(): - place = base.CUDAPlace(0) - self._test_equivalence(place) + self._test_pir_equivalence(place) class Conv2DTransposeErrorTestCase(Conv2DTransposeTestCase): diff --git a/test/legacy_test/test_conv2d_transpose_op.py b/test/legacy_test/test_conv2d_transpose_op.py index ef610d3af0516..339ef086d7b81 100644 --- a/test/legacy_test/test_conv2d_transpose_op.py +++ b/test/legacy_test/test_conv2d_transpose_op.py @@ -227,10 +227,15 @@ def test_check_output(self): if self.use_cudnn: place = core.CUDAPlace(0) self.check_output_with_place( - place, atol=1e-5, check_dygraph=(not self.use_mkldnn) + place, + atol=1e-5, + check_dygraph=(not self.use_mkldnn), + check_pir=True, ) else: - self.check_output(check_dygraph=(not self.use_mkldnn)) + self.check_output( + check_dygraph=(not self.use_mkldnn), check_pir=True + ) def test_check_grad_no_input(self): if self.need_check_grad: @@ -242,19 +247,28 @@ def test_check_grad_no_input(self): 'Output', max_relative_error=0.02, no_grad_set={'Input'}, + check_pir=True, ) else: - self.check_grad(['Filter'], 'Output', no_grad_set={'Input'}) + self.check_grad( + ['Filter'], 'Output', no_grad_set={'Input'}, check_pir=True + ) def test_check_grad_no_filter(self): if self.need_check_grad: if self.use_cudnn: place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['Input'], 'Output', no_grad_set={'Filter'} + place, + ['Input'], + 'Output', + no_grad_set={'Filter'}, + check_pir=True, ) else: - self.check_grad(['Input'], 'Output', no_grad_set={'Filter'}) + self.check_grad( + ['Input'], 'Output', no_grad_set={'Filter'}, check_pir=True + ) def test_check_grad(self): if self.need_check_grad: @@ -265,10 +279,14 @@ def test_check_grad(self): {'Input', 'Filter'}, 'Output', max_relative_error=0.02, + check_pir=True, ) else: self.check_grad( - {'Input', 'Filter'}, 'Output', max_relative_error=0.02 + {'Input', 'Filter'}, + 'Output', + max_relative_error=0.02, + check_pir=True, ) def init_test_case(self): @@ -781,10 +799,15 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, atol=0.02, check_dygraph=(not self.use_mkldnn) + place, + atol=0.02, + check_dygraph=(not self.use_mkldnn), + check_pir=True, ) else: - self.check_output(check_dygraph=(not self.use_mkldnn)) + self.check_output( + check_dygraph=(not self.use_mkldnn), check_pir=True + ) def test_check_grad_no_input(self): if self.need_check_grad: @@ -797,9 +820,12 @@ def test_check_grad_no_input(self): 'Output', max_relative_error=0.02, no_grad_set={'Input'}, + check_pir=True, ) else: - self.check_grad(['Filter'], 'Output', no_grad_set={'Input'}) + self.check_grad( + ['Filter'], 'Output', no_grad_set={'Input'}, check_pir=True + ) def test_check_grad_no_filter(self): if self.need_check_grad: @@ -812,9 +838,12 @@ def test_check_grad_no_filter(self): 'Output', max_relative_error=0.02, no_grad_set={'Filter'}, + check_pir=True, ) else: - self.check_grad(['Input'], 'Output', no_grad_set={'Filter'}) + self.check_grad( + ['Input'], 'Output', no_grad_set={'Filter'}, check_pir=True + ) def test_check_grad(self): if self.need_check_grad: @@ -826,10 +855,14 @@ def test_check_grad(self): {'Input', 'Filter'}, 'Output', max_relative_error=0.02, + check_pir=True, ) else: self.check_grad( - {'Input', 'Filter'}, 'Output', max_relative_error=0.02 + {'Input', 'Filter'}, + 'Output', + max_relative_error=0.02, + check_pir=True, ) @@ -965,7 +998,10 @@ def init_op_type(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, atol=0.02, check_dygraph=(not self.use_mkldnn) + place, + atol=0.02, + check_dygraph=(not self.use_mkldnn), + check_pir=True, ) def test_check_grad_no_input(self): @@ -978,6 +1014,7 @@ def test_check_grad_no_input(self): max_relative_error=0.02, no_grad_set={'Input'}, user_defined_grads=[numeric_grads], + check_pir=True, ) def test_check_grad_no_filter(self): @@ -990,6 +1027,7 @@ def test_check_grad_no_filter(self): max_relative_error=0.02, no_grad_set={'Filter'}, user_defined_grads=[numeric_grads], + check_pir=True, ) diff --git a/test/legacy_test/test_conv3d_layer.py b/test/legacy_test/test_conv3d_layer.py index da1a21edbc435..d514f56c2631a 100644 --- a/test/legacy_test/test_conv3d_layer.py +++ b/test/legacy_test/test_conv3d_layer.py @@ -137,13 +137,16 @@ def functional(self, place): w_var = paddle.static.data( "weight", self.weight_shape, dtype=self.dtype ) - b_var = paddle.static.data( - "bias", (self.num_filters,), dtype=self.dtype - ) + if not self.no_bias: + b_var = paddle.static.data( + "bias", (self.num_filters,), dtype=self.dtype + ) + else: + b_var = None y_var = F.conv3d( x_var, w_var, - None if self.no_bias else b_var, + b_var, padding=self.padding, stride=self.stride, dilation=self.dilation, @@ -181,7 +184,6 @@ def paddle_nn_layer(self): return y_np, t1 def _test_equivalence(self, place): - place = base.CPUPlace() result1 = self.base_layer(place) result2 = self.functional(place) with dg.guard(place): @@ -189,13 +191,22 @@ def _test_equivalence(self, place): np.testing.assert_array_almost_equal(result1, result2) np.testing.assert_array_almost_equal(result2, result3) + def _test_pir_equivalence(self, place): + with paddle.pir_utils.IrGuard(): + result1 = self.functional(place) + with dg.guard(place): + result2, g1 = self.paddle_nn_layer() + np.testing.assert_array_almost_equal(result1, result2) + def runTest(self): place = base.CPUPlace() self._test_equivalence(place) + self._test_pir_equivalence(place) if base.core.is_compiled_with_cuda(): place = base.CUDAPlace(0) self._test_equivalence(place) + self._test_pir_equivalence(place) class Conv3DErrorTestCase(Conv3DTestCase): diff --git a/test/legacy_test/test_conv3d_op.py b/test/legacy_test/test_conv3d_op.py index 14e7d9dc09930..26e58e026ebd4 100644 --- a/test/legacy_test/test_conv3d_op.py +++ b/test/legacy_test/test_conv3d_op.py @@ -208,7 +208,7 @@ def init_kernel_type(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, check_dygraph=(not self.use_mkldnn) + place, check_dygraph=(not self.use_mkldnn), check_pir=True ) def test_check_grad_no_filter(self): @@ -222,6 +222,7 @@ def test_check_grad_no_filter(self): no_grad_set={'Filter'}, check_dygraph=(not self.use_mkldnn), user_defined_grads=[numeric_grads], + check_pir=True, ) def test_check_grad_no_input(self): @@ -235,6 +236,7 @@ def test_check_grad_no_input(self): no_grad_set={'Input'}, check_dygraph=(not self.use_mkldnn), user_defined_grads=[numeric_grads], + check_pir=True, ) def test_check_grad(self): @@ -248,6 +250,7 @@ def test_check_grad(self): 'Output', user_defined_grads=[numeric_input_grads, numeric_fliter_grads], check_dygraph=(not self.use_mkldnn), + check_pir=True, ) cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP") @@ -448,7 +451,10 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() self.check_output_with_place( - place, atol=1e-5, check_dygraph=(not self.use_mkldnn) + place, + atol=1e-5, + check_dygraph=(not self.use_mkldnn), + check_pir=True, ) def test_check_grad(self): @@ -460,6 +466,7 @@ def test_check_grad(self): 'Output', max_relative_error=0.03, check_dygraph=(not self.use_mkldnn), + check_pir=True, ) def test_check_grad_no_filter(self): @@ -472,6 +479,7 @@ def test_check_grad_no_filter(self): max_relative_error=0.03, no_grad_set={'Filter'}, check_dygraph=(not self.use_mkldnn), + check_pir=True, ) def test_check_grad_no_input(self): @@ -484,6 +492,7 @@ def test_check_grad_no_input(self): max_relative_error=0.03, no_grad_set={'Input'}, check_dygraph=(not self.use_mkldnn), + check_pir=True, ) def init_test_case(self): @@ -599,7 +608,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) + self.check_output_with_place(place, atol=2e-2, check_pir=True) @unittest.skipIf( @@ -623,7 +632,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) + self.check_output_with_place(place, atol=2e-2, check_pir=True) @unittest.skipIf( @@ -647,7 +656,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) + self.check_output_with_place(place, atol=2e-2, check_pir=True) @unittest.skipIf( @@ -671,7 +680,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) + self.check_output_with_place(place, atol=2e-2, check_pir=True) @unittest.skipIf( @@ -695,7 +704,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) + self.check_output_with_place(place, atol=2e-2, check_pir=True) class TestCUDNNExhaustiveSearch(TestCUDNN): @@ -771,14 +780,18 @@ def has_cudnn(self): def test_check_output(self): place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() - self.check_output_with_place(place, atol=1e-5) + self.check_output_with_place(place, atol=1e-5, check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() self.check_grad_with_place( - place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03 + place, + {'Input', 'Filter'}, + 'Output', + max_relative_error=0.03, + check_pir=True, ) def test_check_grad_no_filter(self): @@ -791,6 +804,7 @@ def test_check_grad_no_filter(self): 'Output', max_relative_error=0.03, no_grad_set={'Filter'}, + check_pir=True, ) def test_check_grad_no_input(self): @@ -803,6 +817,7 @@ def test_check_grad_no_input(self): 'Output', max_relative_error=0.03, no_grad_set={'Input'}, + check_pir=True, ) def init_test_case(self): diff --git a/test/legacy_test/test_cosine_similarity_api.py b/test/legacy_test/test_cosine_similarity_api.py index 7fe78e42c7ab1..b563d5717eaba 100644 --- a/test/legacy_test/test_cosine_similarity_api.py +++ b/test/legacy_test/test_cosine_similarity_api.py @@ -18,14 +18,9 @@ import paddle import paddle.nn.functional as F -from paddle import nn -from paddle.base import ( - Executor, - Program, - core, - default_main_program, - program_guard, -) +from paddle import nn, static +from paddle.base import Executor, core +from paddle.pir_utils import test_with_pir_api class TestCosineSimilarityAPI(unittest.TestCase): @@ -42,10 +37,15 @@ def _get_numpy_out(self, x1, x2, axis=1, eps=1e-8): cos_sim = w12 / n12 return cos_sim + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with program_guard(Program(), Program()): + main_program = static.Program() + startup_program = static.Program() + with static.program_guard( + main_program=main_program, startup_program=startup_program + ): shape = [10, 15] axis = 1 eps = 1e-8 @@ -58,7 +58,6 @@ def check_static_result(self, place): result = F.cosine_similarity(x1, x2, axis=axis, eps=eps) exe = Executor(place) fetches = exe.run( - default_main_program(), feed={"x1": np_x1, "x2": np_x2}, fetch_list=[result], ) diff --git a/test/legacy_test/test_cross_entropy_loss.py b/test/legacy_test/test_cross_entropy_loss.py index 78901bb75bf1b..5c50ceab5dbe6 100644 --- a/test/legacy_test/test_cross_entropy_loss.py +++ b/test/legacy_test/test_cross_entropy_loss.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import Program, program_guard +from paddle.pir_utils import test_with_pir_api def label_smooth(label, C, epsilon, is_onehot=True): @@ -272,6 +273,7 @@ def test_softmax_with_cross_entropy(self): # soft_label test start # soft_label test 1 + @test_with_pir_api def test_cross_entropy_loss_soft_1d(self): self.numeric_stable_mode = False self.soft_label = True @@ -360,6 +362,7 @@ def test_cross_entropy_loss_soft_1d(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # soft_label test 2 + @test_with_pir_api def test_cross_entropy_loss_soft_1d_weight(self): self.numeric_stable_mode = False self.soft_label = True @@ -460,6 +463,7 @@ def test_cross_entropy_loss_soft_1d_weight(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # soft_label test 3 + @test_with_pir_api def test_cross_entropy_loss_soft_1d_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -544,6 +548,7 @@ def test_cross_entropy_loss_soft_1d_mean(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # soft_label test 4 + @test_with_pir_api def test_cross_entropy_loss_soft_1d_weight_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -634,6 +639,7 @@ def test_cross_entropy_loss_soft_1d_weight_mean(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # soft_label test 5 + @test_with_pir_api def test_cross_entropy_loss_soft_2d(self): def inner_cross_entropy_loss_soft_2d(soft_label): self.numeric_stable_mode = False @@ -739,6 +745,7 @@ def inner_cross_entropy_loss_soft_2d(soft_label): inner_cross_entropy_loss_soft_2d(False) # soft_label test 6 + @test_with_pir_api def test_cross_entropy_loss_soft_2d_weight_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -840,6 +847,7 @@ def test_cross_entropy_loss_soft_2d_weight_mean(self): # soft_label test end # label_smoothing test 1 + @test_with_pir_api def test_cross_entropy_loss_onehot_label_smoothing_1d(self): self.numeric_stable_mode = False self.soft_label = True @@ -937,6 +945,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_1d(self): paddle.enable_static() # label_smoothing test 2 + @test_with_pir_api def test_cross_entropy_loss_onehot_label_smoothing_1d_weight_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -1036,6 +1045,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_1d_weight_mean(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test 3 + @test_with_pir_api def test_cross_entropy_loss_onehot_label_smoothing_2d(self): self.numeric_stable_mode = False self.soft_label = True @@ -1143,6 +1153,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_2d(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test 4 + @test_with_pir_api def test_cross_entropy_loss_onehot_label_smoothing_2d_weight_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -1253,6 +1264,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_2d_weight_mean(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test 5 + @test_with_pir_api def test_cross_entropy_loss_integer_label_smoothing_1d(self): self.numeric_stable_mode = False self.soft_label = True @@ -1350,6 +1362,7 @@ def test_cross_entropy_loss_integer_label_smoothing_1d(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test 6 + @test_with_pir_api def test_cross_entropy_loss_integer_label_smoothing_1d_weight_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -1452,6 +1465,7 @@ def test_cross_entropy_loss_integer_label_smoothing_1d_weight_mean(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test 7 + @test_with_pir_api def test_cross_entropy_loss_integer_label_smoothing_2d(self): self.numeric_stable_mode = False self.soft_label = True @@ -1557,6 +1571,7 @@ def test_cross_entropy_loss_integer_label_smoothing_2d(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test 8 + @test_with_pir_api def test_cross_entropy_loss_integer_label_smoothing_2d_weight_mean(self): self.numeric_stable_mode = False self.soft_label = True @@ -1667,7 +1682,7 @@ def test_cross_entropy_loss_integer_label_smoothing_2d_weight_mean(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) # label_smoothing test end - + @test_with_pir_api def test_cross_entropy_loss_1d_with_mean_ignore(self): input_np = np.random.random([2, 4]).astype(self.dtype) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) @@ -1714,6 +1729,7 @@ def test_cross_entropy_loss_1d_with_mean_ignore(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_with_mean_ignore_negative(self): N = 100 C = 200 @@ -1763,6 +1779,7 @@ def test_cross_entropy_loss_1d_with_mean_ignore_negative(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_with_weight_mean_ignore(self): N = 100 C = 200 @@ -1846,6 +1863,7 @@ def test_cross_entropy_loss_1d_with_weight_mean_ignore_exceedlabel(self): np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_with_weight_mean(self): input_np = np.random.random([2, 4]).astype(self.dtype) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) @@ -1901,6 +1919,7 @@ def test_cross_entropy_loss_1d_with_weight_mean(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_with_weight_sum(self): input_np = np.random.random([100, 200]).astype(self.dtype) # N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1 @@ -1954,6 +1973,7 @@ def test_cross_entropy_loss_1d_with_weight_sum(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_with_weight_none(self): input_np = np.random.random([100, 200]).astype(self.dtype) # N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1 @@ -2011,6 +2031,7 @@ def test_cross_entropy_loss_1d_with_weight_none(self): np.testing.assert_allclose(static_ret, expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_with_weight_none_func(self): input_np = np.random.random([100, 200]).astype(self.dtype) # N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N @@ -2064,6 +2085,7 @@ def test_cross_entropy_loss_1d_with_weight_none_func(self): np.testing.assert_allclose(static_ret, expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_mean(self): input_np = np.random.random([100, 200]).astype(self.dtype) # N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1 @@ -2102,6 +2124,7 @@ def test_cross_entropy_loss_1d_mean(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_sum(self): input_np = np.random.random([100, 200]).astype(self.dtype) # N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1 @@ -2144,6 +2167,7 @@ def test_cross_entropy_loss_1d_sum(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_1d_none(self): input_np = np.random.random([100, 200]).astype(self.dtype) # N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1 @@ -2188,6 +2212,7 @@ def test_cross_entropy_loss_1d_none(self): np.testing.assert_allclose(static_ret, expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_with_weight_none(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype( self.dtype @@ -2250,6 +2275,7 @@ def test_cross_entropy_loss_2d_with_weight_none(self): np.testing.assert_allclose(static_ret, expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_with_weight_axis_change_mean(self): input_np = np.random.random(size=(2, 3, 2, 2)).astype( self.dtype @@ -2341,6 +2367,7 @@ def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self): )[0] np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_with_weight_mean(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype( self.dtype @@ -2400,6 +2427,7 @@ def test_cross_entropy_loss_2d_with_weight_mean(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_with_weight_sum(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype( self.dtype @@ -2460,6 +2488,7 @@ def test_cross_entropy_loss_2d_with_weight_sum(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_none(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype( self.dtype @@ -2513,6 +2542,7 @@ def test_cross_entropy_loss_2d_none(self): np.testing.assert_allclose(static_ret, expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_mean(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype( self.dtype @@ -2567,6 +2597,7 @@ def test_cross_entropy_loss_2d_mean(self): np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05) np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05) + @test_with_pir_api def test_cross_entropy_loss_2d_sum(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype( self.dtype diff --git a/test/legacy_test/test_cuda_graph.py b/test/legacy_test/test_cuda_graph.py index 58728ec476a2d..4e14e8b3c1df4 100644 --- a/test/legacy_test/test_cuda_graph.py +++ b/test/legacy_test/test_cuda_graph.py @@ -27,6 +27,10 @@ def can_use_cuda_graph(): return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestCUDAGraphInDygraphMode(unittest.TestCase): def setUp(self): if can_use_cuda_graph(): diff --git a/test/legacy_test/test_cuda_graph_partial_graph.py b/test/legacy_test/test_cuda_graph_partial_graph.py index 5737b6d7a08cb..e0cdf43f8627b 100644 --- a/test/legacy_test/test_cuda_graph_partial_graph.py +++ b/test/legacy_test/test_cuda_graph_partial_graph.py @@ -39,6 +39,10 @@ def forward(self, x): return x +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestSimpleModel(unittest.TestCase): def setUp(self): paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 0.0}) diff --git a/test/legacy_test/test_cuda_graph_partial_graph_static.py b/test/legacy_test/test_cuda_graph_partial_graph_static.py index 214834423b797..39d6ba382bf5e 100644 --- a/test/legacy_test/test_cuda_graph_partial_graph_static.py +++ b/test/legacy_test/test_cuda_graph_partial_graph_static.py @@ -39,6 +39,10 @@ def forward(self, x): return x +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestCudaGraphAttrAll(unittest.TestCase): def test_all_program(self): if not is_cuda_graph_supported(): diff --git a/test/legacy_test/test_cuda_graph_partial_graph_static_run.py b/test/legacy_test/test_cuda_graph_partial_graph_static_run.py index 2e301bdbd94da..41841c4204c23 100644 --- a/test/legacy_test/test_cuda_graph_partial_graph_static_run.py +++ b/test/legacy_test/test_cuda_graph_partial_graph_static_run.py @@ -45,6 +45,10 @@ def forward(self, x): return x +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestCudaGraphAttrAll(unittest.TestCase): def setUp(self): paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 0.0}) diff --git a/test/legacy_test/test_cuda_graph_static_mode.py b/test/legacy_test/test_cuda_graph_static_mode.py index 746a3db02c222..035b442729665 100644 --- a/test/legacy_test/test_cuda_graph_static_mode.py +++ b/test/legacy_test/test_cuda_graph_static_mode.py @@ -48,6 +48,10 @@ def build_program(main, startup, batch_size, class_num): return image, label, loss, lr +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestCUDAGraphInStaticMode(unittest.TestCase): def setUp(self): if can_use_cuda_graph(): diff --git a/test/legacy_test/test_cuda_graph_static_mode_error.py b/test/legacy_test/test_cuda_graph_static_mode_error.py index a718f1b7009bd..2c1b26c93db34 100644 --- a/test/legacy_test/test_cuda_graph_static_mode_error.py +++ b/test/legacy_test/test_cuda_graph_static_mode_error.py @@ -22,6 +22,10 @@ from paddle.device.cuda.graphs import CUDAGraph +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestCUDAGraphInFirstBatch(unittest.TestCase): def setUp(self): if can_use_cuda_graph(): diff --git a/test/legacy_test/test_cuda_graphed_layer.py b/test/legacy_test/test_cuda_graphed_layer.py new file mode 100644 index 0000000000000..5bfdd3c81f5c8 --- /dev/null +++ b/test/legacy_test/test_cuda_graphed_layer.py @@ -0,0 +1,90 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import nn +from paddle.device.cuda.cuda_graphed_layer import CUDAGraphedLayer + +seed = 102 + + +class Model(nn.Layer): + def __init__(self, in_size, out_size, dropout=0): + paddle.seed(seed) + super().__init__() + self.linear = nn.Linear(in_size, out_size) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.relu(x) + return x + + +class DropoutModel(nn.Layer): + def __init__(self, in_size, out_size, dropout=0.5): + paddle.seed(seed) + super().__init__() + self.linear = nn.Linear(in_size, out_size) + self.dropout_1 = paddle.nn.Dropout(dropout) + self.relu = nn.ReLU() + self.dropout_2 = paddle.nn.Dropout(dropout) + + def forward(self, x): + x = self.linear(x) + x = self.dropout_1(x) + x = self.relu(x) + x = self.dropout_2(x) + return x + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) +class TestSimpleModel(unittest.TestCase): + def train(self, model): + paddle.seed(seed) + + ans = [] + for _ in range(10): + x = paddle.randn([3, 10], dtype='float32') + x.stop_gradient = False + loss = model(x).mean() + loss.backward() + ans.append(x.grad.numpy()) + + return np.array(ans) + + def test_layer(self): + model = Model(10, 20) + cuda_graphed_model = CUDAGraphedLayer(Model(10, 20)) + + dropout_model = DropoutModel(10, 20) + cuda_graphed_dropout_model = CUDAGraphedLayer(DropoutModel(10, 20)) + + np.testing.assert_array_equal( + self.train(model), self.train(cuda_graphed_model) + ) + np.testing.assert_array_equal( + self.train(dropout_model), self.train(cuda_graphed_dropout_model) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_cuda_max_memory_allocated.py b/test/legacy_test/test_cuda_max_memory_allocated.py index 90e016921f8a2..969489fa8f925 100644 --- a/test/legacy_test/test_cuda_max_memory_allocated.py +++ b/test/legacy_test/test_cuda_max_memory_allocated.py @@ -61,10 +61,10 @@ def test_max_memory_allocated_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 max_memory_allocated(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): max_memory_allocated() diff --git a/test/legacy_test/test_cuda_max_memory_reserved.py b/test/legacy_test/test_cuda_max_memory_reserved.py index ac3b2b712e2ff..7f0a3f4da388f 100644 --- a/test/legacy_test/test_cuda_max_memory_reserved.py +++ b/test/legacy_test/test_cuda_max_memory_reserved.py @@ -61,10 +61,10 @@ def test_max_memory_reserved_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 max_memory_reserved(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): max_memory_reserved() diff --git a/test/legacy_test/test_cuda_memory_allocated.py b/test/legacy_test/test_cuda_memory_allocated.py index 3e4c258940659..192126c092a4b 100644 --- a/test/legacy_test/test_cuda_memory_allocated.py +++ b/test/legacy_test/test_cuda_memory_allocated.py @@ -46,10 +46,10 @@ def test_memory_allocated_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 memory_allocated(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): memory_allocated() diff --git a/test/legacy_test/test_cuda_memory_reserved.py b/test/legacy_test/test_cuda_memory_reserved.py index d639eab054ff5..8a02834f8fd3a 100644 --- a/test/legacy_test/test_cuda_memory_reserved.py +++ b/test/legacy_test/test_cuda_memory_reserved.py @@ -46,10 +46,10 @@ def test_memory_reserved_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 memory_reserved(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): memory_reserved() diff --git a/test/legacy_test/test_cummax_op.py b/test/legacy_test/test_cummax_op.py index 91df4866a75a6..89429cf347096 100644 --- a/test/legacy_test/test_cummax_op.py +++ b/test/legacy_test/test_cummax_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def cummax_dim2(arr, axis=None): @@ -91,11 +92,11 @@ def set_attrs(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(['x'], 'out') + self.check_grad(['x'], 'out', check_pir=True) class TestCummaxOpAxis1(TestCummaxOp): @@ -151,6 +152,7 @@ def run_cases(self): np.testing.assert_array_equal(z, y.numpy()) np.testing.assert_array_equal(ind, indices.numpy()) + @test_with_pir_api def run_static(self, use_gpu=False): with base.program_guard(base.Program()): data_np = np.random.random((100, 100)).astype(np.float32) @@ -163,20 +165,19 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) out = exe.run( feed={'x': data_np}, fetch_list=[ - y1.name, - indices1.name, - y2.name, - indices2.name, - y3.name, - indices3.name, - y4.name, - indices4.name, - y5.name, - indices5.name, + y1, + indices1, + y2, + indices2, + y3, + indices3, + y4, + indices4, + y5, + indices5, ], ) @@ -218,6 +219,7 @@ def test_errors(self): paddle.enable_static() with base.program_guard(base.Program()): + @test_with_pir_api def test_x_type(): data = [1, 2, 3] y, indices = paddle.cummax(data, axis=0) diff --git a/test/legacy_test/test_cummin_op.py b/test/legacy_test/test_cummin_op.py index 416e4c48f0fc0..d8e5512cbf9b4 100644 --- a/test/legacy_test/test_cummin_op.py +++ b/test/legacy_test/test_cummin_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def cummin_dim2(arr, axis=None): @@ -91,11 +92,11 @@ def set_attrs(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(['x'], 'out') + self.check_grad(['x'], 'out', check_pir=True) class TestCuinOpAxis1(TestCumminOp): @@ -151,6 +152,7 @@ def run_cases(self): np.testing.assert_array_equal(z, y.numpy()) np.testing.assert_array_equal(ind, indices.numpy()) + @test_with_pir_api def run_static(self, use_gpu=False): with base.program_guard(base.Program()): data_np = np.random.random((100, 100)).astype(np.float32) @@ -163,20 +165,19 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) out = exe.run( feed={'x': data_np}, fetch_list=[ - y1.name, - indices1.name, - y2.name, - indices2.name, - y3.name, - indices3.name, - y4.name, - indices4.name, - y5.name, - indices5.name, + y1, + indices1, + y2, + indices2, + y3, + indices3, + y4, + indices4, + y5, + indices5, ], ) @@ -218,6 +219,7 @@ def test_errors(self): paddle.enable_static() with base.program_guard(base.Program()): + @test_with_pir_api def test_x_type(): data = [1, 2, 3] y, indices = paddle.cummin(data, axis=0) diff --git a/test/legacy_test/test_cumprod_op.py b/test/legacy_test/test_cumprod_op.py index da3db1ee1ef6f..019095c2dea0f 100644 --- a/test/legacy_test/test_cumprod_op.py +++ b/test/legacy_test/test_cumprod_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(0) @@ -124,7 +125,7 @@ def test_check_output(self): for dim in range(-len(self.shape), len(self.shape)): for zero_num in self.zero_nums: self.prepare_inputs_outputs_attrs(dim, zero_num) - self.check_output() + self.check_output(check_pir=True) # test backward. def test_check_grad(self): @@ -133,13 +134,14 @@ def test_check_grad(self): self.prepare_inputs_outputs_attrs(dim, zero_num) self.init_grad_input_output(dim) if self.dtype == np.float64: - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) else: self.check_grad( ['X'], 'Out', user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], + check_pir=True, ) @@ -217,6 +219,7 @@ def setUp(self): self.place.append(paddle.CUDAPlace(0)) # test static graph api. + @test_with_pir_api def test_static_api(self): paddle.enable_static() diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index be733d989a93f..ee853bd553eb0 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -23,6 +23,7 @@ import paddle.inference as paddle_infer from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestCumsumOp(unittest.TestCase): @@ -53,7 +54,7 @@ def run_cases(self): np.testing.assert_array_equal(z, y.numpy()) def run_static(self, use_gpu=False): - with base.program_guard(base.Program()): + with paddle.static.program_guard(paddle.static.Program()): data_np = np.random.random((100, 100)).astype(np.float32) x = paddle.static.data('X', [100, 100]) y = paddle.cumsum(x) @@ -65,16 +66,16 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) + exe.run(paddle.static.default_startup_program()) out = exe.run( feed={'X': data_np}, fetch_list=[ - y.name, - y2.name, - y3.name, - y4.name, - y5.name, - y6.name, + y, + y2, + y3, + y4, + y5, + y6, ], ) @@ -89,20 +90,26 @@ def run_static(self, use_gpu=False): z = np.cumsum(data_np, axis=-2) np.testing.assert_allclose(z, out[5], rtol=1e-05) - def test_cpu(self): + def test_cpu_dygraph(self): paddle.disable_static(paddle.base.CPUPlace()) self.run_cases() paddle.enable_static() + @test_with_pir_api + def test_cpu_static(self): self.run_static() - def test_gpu(self): + def test_gpu_dygraph(self): if not base.core.is_compiled_with_cuda(): return paddle.disable_static(paddle.base.CUDAPlace(0)) self.run_cases() paddle.enable_static() + @test_with_pir_api + def test_gpu_static(self): + if not base.core.is_compiled_with_cuda(): + return self.run_static(use_gpu=True) def test_name(self): @@ -133,10 +140,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = self.dtype_ = np.float64 @@ -242,10 +249,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = self.dtype_ = np.float64 @@ -341,10 +348,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = np.float16 @@ -380,10 +387,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = self.dtype_ = np.float64 @@ -401,14 +408,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad( - ['X'], - 'Out', - check_prim=True, - ) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) cls_name = "{}_{}".format(parent.__name__, "Fp16") TestCumsumFP16Op.__name__ = cls_name @@ -445,12 +448,17 @@ def if_enable_cinn(self): def test_check_output(self): place = paddle.CUDAPlace(0) - self.check_output_with_place(place, check_prim=True) + self.check_output_with_place(place, check_prim=True, check_pir=True) def test_check_grad(self): place = paddle.CUDAPlace(0) self.check_grad_with_place( - place, ["X"], "Out", check_prim=True, numeric_grad_delta=0.05 + place, + ["X"], + "Out", + check_prim=True, + numeric_grad_delta=0.05, + check_pir=True, ) cls_name = "{}_{}".format(parent.__name__, "BF16") @@ -552,6 +560,7 @@ def test_static_and_infer(self): class TestCumSumOpFp16(unittest.TestCase): + @test_with_pir_api def test_fp16(self): paddle.enable_static() x_np = np.random.random((100, 100)).astype('float16') diff --git a/test/legacy_test/test_determinant_op.py b/test/legacy_test/test_determinant_op.py index f0066edf10424..2fe7217225f74 100644 --- a/test/legacy_test/test_determinant_op.py +++ b/test/legacy_test/test_determinant_op.py @@ -18,6 +18,7 @@ from op_test import OpTest import paddle +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -30,10 +31,10 @@ def setUp(self): self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['Input'], ['Out']) + self.check_grad(['Input'], ['Out'], check_pir=True) def init_data(self): np.random.seed(0) @@ -85,6 +86,7 @@ def setUp(self): self.x = np.random.random(self.shape).astype(np.float32) self.place = paddle.CPUPlace() + @test_with_pir_api def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -114,11 +116,13 @@ def setUp(self): self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): # the slog det's grad value is always huge - self.check_grad(['Input'], ['Out'], max_relative_error=0.1) + self.check_grad( + ['Input'], ['Out'], max_relative_error=0.1, check_pir=True + ) def init_data(self): np.random.seed(0) @@ -142,6 +146,7 @@ def setUp(self): self.x = np.random.random(self.shape).astype(np.float32) self.place = paddle.CPUPlace() + @test_with_pir_api def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_diag_embed.py b/test/legacy_test/test_diag_embed.py index 2f3869713f0e3..ab2955f9d4405 100644 --- a/test/legacy_test/test_diag_embed.py +++ b/test/legacy_test/test_diag_embed.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import unittest import numpy as np from op_test import OpTest, paddle_static_guard import paddle -import paddle.nn.functional as F from paddle import base from paddle.base import core @@ -26,7 +26,7 @@ class TestDiagEmbedOp(OpTest): def setUp(self): self.op_type = "diag_embed" - self.python_api = F.diag_embed + self.python_api = paddle.diag_embed self.init_config() self.outputs = {'Out': self.target} @@ -57,8 +57,8 @@ def test_case1(self): data1 = paddle.static.data( name='data1', shape=[2, 3, 4], dtype='float32' ) - out1 = F.diag_embed(data1) - out2 = F.diag_embed(data1, offset=1, dim1=-2, dim2=3) + out1 = paddle.diag_embed(data1) + out2 = paddle.diag_embed(data1, offset=1, dim1=-2, dim2=3) place = core.CPUPlace() exe = base.Executor(place) @@ -77,6 +77,11 @@ def test_case1(self): np.testing.assert_allclose(results[0], target1, rtol=1e-05) np.testing.assert_allclose(results[1], target2, rtol=1e-05) + def test_tensor_method(self): + paddle.disable_static() + x = paddle.arange(15).reshape((3, 5)).astype('float64') + self.assertTrue(inspect.ismethod(x.diag_embed)) + if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_diag_v2.py b/test/legacy_test/test_diag_v2.py index 2458a280ce039..bfbd0f65ef107 100644 --- a/test/legacy_test/test_diag_v2.py +++ b/test/legacy_test/test_diag_v2.py @@ -51,11 +51,11 @@ def init_input_output(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestDiagV2OpCase1(TestDiagV2Op): @@ -335,12 +335,12 @@ def setUp(self): def test_check_output(self): paddle.enable_static() place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): paddle.enable_static() place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') + self.check_grad_with_place(place, ['X'], 'Out', check_pir=True) if __name__ == "__main__": diff --git a/test/legacy_test/test_digamma_op.py b/test/legacy_test/test_digamma_op.py index 04bb768a5b179..d5b2ce37b6df1 100644 --- a/test/legacy_test/test_digamma_op.py +++ b/test/legacy_test/test_digamma_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base, static from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestDigammaOp(OpTest): @@ -42,10 +43,10 @@ def init_dtype_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestDigammaOpFp32(TestDigammaOp): @@ -53,7 +54,7 @@ def init_dtype_type(self): self.dtype = np.float32 def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestDigammaFP16Op(TestDigammaOp): @@ -87,10 +88,12 @@ def init_dtype_type(self): def test_check_output(self): # bfloat16 needs to set the parameter place - self.check_output_with_place(core.CUDAPlace(0)) + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) def test_check_grad_normal(self): - self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out') + self.check_grad_with_place( + core.CUDAPlace(0), ['X'], 'Out', check_pir=True + ) class TestDigammaAPI(unittest.TestCase): @@ -104,6 +107,7 @@ def setUp(self): self.places.append(paddle.CUDAPlace(0)) self._shape = [8, 3, 32, 32] + @test_with_pir_api def test_in_static_mode(self): def init_input_output(dtype): input = np.random.random(self._shape).astype(dtype) @@ -117,7 +121,7 @@ def init_input_output(dtype): out = paddle.digamma(x) exe = static.Executor(place) - out_value = exe.run(feed=input_dict, fetch_list=[out.name]) + out_value = exe.run(feed=input_dict, fetch_list=[out]) np.testing.assert_allclose(out_value[0], sc_res, rtol=1e-05) def test_in_dynamic_mode(self): diff --git a/test/legacy_test/test_dist_op.py b/test/legacy_test/test_dist_op.py index bd2e9c828e144..af7576df1dec0 100644 --- a/test/legacy_test/test_dist_op.py +++ b/test/legacy_test/test_dist_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -113,13 +114,11 @@ def get_reduce_dims(x, y): return x_grad, y_grad def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( - ["X", "Y"], - "Out", - user_defined_grads=self.gradient, + ["X", "Y"], "Out", user_defined_grads=self.gradient, check_pir=True ) @@ -244,10 +243,11 @@ def init_data_type(self): 'float32' if core.is_compiled_with_rocm() else 'float64' ) + @test_with_pir_api def test_api(self): self.init_data_type() - main_program = base.Program() - startup_program = base.Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with base.program_guard(main_program, startup_program): x = paddle.static.data( name='x', shape=[2, 3, 4, 5], dtype=self.data_type @@ -266,7 +266,7 @@ def test_api(self): ) exe = base.Executor(place) out = exe.run( - base.default_main_program(), + main_program, feed={'x': x_i, 'y': y_i}, fetch_list=[result], ) diff --git a/test/legacy_test/test_dist_train.py b/test/legacy_test/test_dist_train.py index c1d8e5426db35..668096e3bbe16 100644 --- a/test/legacy_test/test_dist_train.py +++ b/test/legacy_test/test_dist_train.py @@ -67,7 +67,7 @@ def _wait_ps_ready(self, pid): # on the /tmp directory until it was ready to process all the RPC call. os.stat("/tmp/paddle.%d.port" % pid) return - except os.error: + except OSError: start_left_time -= sleep_time def init_serv(self, place): diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py index 62a94832d1ae9..9133577bddb2e 100644 --- a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py +++ b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py @@ -96,14 +96,14 @@ def test_1_new_comm(self): run_test( clip_after_allreduce=True, max_global_norm=0.01, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_2_new_comm(self): run_test( clip_after_allreduce=False, max_global_norm=0.01, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py b/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py index f236be3a8d150..279c2dd101631 100644 --- a/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py +++ b/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py @@ -38,7 +38,7 @@ def test_gm_new_comm(self): clip_after_allreduce=True, max_global_norm=-1.0, gradient_merge_steps=2, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_gm_with_fp16_acc_grad_new_comm(self): @@ -47,7 +47,7 @@ def test_gm_with_fp16_acc_grad_new_comm(self): max_global_norm=-1.0, gradient_merge_steps=2, use_master_acc_grad=False, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/legacy_test/test_dot_op.py b/test/legacy_test/test_dot_op.py index 5ad707fadef22..3b1a216add6da 100644 --- a/test/legacy_test/test_dot_op.py +++ b/test/legacy_test/test_dot_op.py @@ -230,20 +230,22 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=0.125) + self.check_output_with_place(place, atol=0.125, check_pir=True) def test_check_grad_normal(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_grad_with_place(place, ['X', 'Y'], 'Out') + self.check_grad_with_place( + place, ['X', 'Y'], 'Out', check_pir=True + ) def test_check_grad_ingore_x(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_grad_with_place( - place, ['Y'], 'Out', no_grad_set=set("X") + place, ['Y'], 'Out', no_grad_set=set("X"), check_pir=True ) def test_check_grad_ingore_y(self): @@ -251,7 +253,7 @@ def test_check_grad_ingore_y(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_grad_with_place( - place, ['X'], 'Out', no_grad_set=set("Y") + place, ['X'], 'Out', no_grad_set=set("Y"), check_pir=True ) def init_input_output(self): @@ -302,7 +304,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_bfloat16_supported(place): - self.check_output_with_place(place, atol=0.5) + self.check_output_with_place(place, atol=0.5, check_pir=True) def test_check_grad_normal(self): if core.is_compiled_with_cuda(): @@ -313,6 +315,7 @@ def test_check_grad_normal(self): ['X', 'Y'], 'Out', user_defined_grads=[self.inputs['Y'], self.inputs['X']], + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -325,6 +328,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), user_defined_grads=[self.inputs['X']], + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -337,6 +341,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set("Y"), user_defined_grads=[self.inputs['Y']], + check_pir=True, ) def init_input_output(self): @@ -374,6 +379,7 @@ def test_check_grad_normal(self): self.y / self.y.shape[0], self.x / self.x.shape[0], ], + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -386,6 +392,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), user_defined_grads=[self.x / self.x.shape[0]], + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -398,6 +405,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set("Y"), user_defined_grads=[self.y / self.y.shape[0]], + check_pir=True, ) diff --git a/test/legacy_test/test_dropout_op.py b/test/legacy_test/test_dropout_op.py index 3ef733101a671..433b9eeff7056 100644 --- a/test/legacy_test/test_dropout_op.py +++ b/test/legacy_test/test_dropout_op.py @@ -26,6 +26,7 @@ from paddle.base.executor import scope_guard from paddle.decomposition import decompose from paddle.incubate.autograd import primapi +from paddle.pir_utils import test_with_pir_api def dropout_wapper( @@ -523,9 +524,11 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): input = paddle.static.data( name="input", shape=[-1, -1], dtype="float32" ) @@ -574,7 +577,6 @@ def check_static_result(self, place): training=False, mode='downscale_in_infer', ) - res10 = paddle.nn.functional.dropout(x=input, p=1.0, training=True) res11 = paddle.nn.functional.dropout(x=input, p=0.0) res12 = paddle.nn.functional.dropout( x=input, @@ -584,13 +586,8 @@ def check_static_result(self, place): mode='upscale_in_train', ) - res13 = paddle.nn.functional.dropout( - x=input, p=0.7, axis=1, training=True, mode='upscale_in_train' - ) - in_np = np.ones([40, 40]).astype("float32") res_np = in_np - res_np2 = np.zeros_like(in_np) exe = base.Executor(place) res_list = [ @@ -608,26 +605,39 @@ def check_static_result(self, place): ] for res in res_list: fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, fetch_list=[res], ) np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05) + + @test_with_pir_api + def check_static_result2(self, place): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): + input = paddle.static.data( + name="input", shape=[-1, -1], dtype="float32" + ) + res10 = paddle.nn.functional.dropout(x=input, p=1.0, training=True) + res13 = paddle.nn.functional.dropout( + x=input, p=0.7, axis=1, training=True, mode='upscale_in_train' + ) + in_np = np.ones([40, 40]).astype("float32") + res_np2 = np.zeros_like(in_np) + + exe = base.Executor(place) fetches2 = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, - fetch_list=[res10], + fetch_list=[res10, res13], ) np.testing.assert_allclose(fetches2[0], res_np2, rtol=1e-05) - fetches3 = exe.run( - base.default_main_program(), - feed={"input": in_np}, - fetch_list=[res13], - ) def test_static(self): for place in self.places: self.check_static_result(place=place) + self.check_static_result2(place=place) def test_dygraph(self): for place in self.places: @@ -769,6 +779,13 @@ def test_dtype(): self.assertRaises(TypeError, test_dtype) + @test_with_pir_api + def test_errors2(self): + paddle.enable_static() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + def test_pdtype(): # p should be int or float x2 = paddle.static.data( @@ -861,9 +878,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): input = paddle.static.data( name="input", shape=[2, 3, 4, 5], dtype="float32" ) @@ -881,7 +901,7 @@ def check_static_result(self, place): res_list = [res1, res2] for res in res_list: fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, fetch_list=[res], ) @@ -911,9 +931,12 @@ def test_dygraph(self): class TestDropout2DFAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): def test_xdim(): # dimentions of x should be 4 @@ -954,6 +977,7 @@ def test_dygraph(self): result.numpy(), result_np, rtol=1e-05 ) + @test_with_pir_api def test_static_fp16_with_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) @@ -986,9 +1010,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): input = paddle.static.data( name="input", shape=[2, 3, 4, 5, 6], dtype="float32" ) @@ -1006,7 +1033,7 @@ def check_static_result(self, place): res_list = [res1, res2] for res in res_list: fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, fetch_list=[res], ) @@ -1036,9 +1063,12 @@ def test_dygraph(self): class TestDropout3DFAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): def test_xdim(): # dimentions of x should be 5 @@ -1087,8 +1117,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + paddle.enable_static() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): input = paddle.static.data( name="input", shape=[40, 40], dtype="float32" ) @@ -1103,20 +1137,15 @@ def check_static_result(self, place): res_np3 = np.zeros_like(in_np) exe = base.Executor(place) - res_list = [res1, res2] - for res in res_list: - fetches = exe.run( - base.default_main_program(), - feed={"input": in_np}, - fetch_list=[res], - ) - np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05) + fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, - fetch_list=[res3], + fetch_list=[res1, res2, res3], ) - np.testing.assert_allclose(fetches[0], res_np3, rtol=1e-05) + np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05) + np.testing.assert_allclose(fetches[1], res_np, rtol=1e-05) + np.testing.assert_allclose(fetches[2], res_np3, rtol=1e-05) def test_static(self): for place in self.places: @@ -1155,6 +1184,13 @@ def test_Variable(): self.assertRaises(TypeError, test_Variable) + @test_with_pir_api + def test_errors2(self): + paddle.enable_static() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + def test_dtype(): # the input dtype of dropout must be float32 or float64 xr = paddle.static.data( @@ -1203,6 +1239,7 @@ def test_dygraph(self): result.numpy(), result_np, rtol=1e-05 ) + @test_with_pir_api def test_static_fp16_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) @@ -1362,9 +1399,9 @@ def api_case(self, x): def run_static(self, x): paddle.seed(2022) - main_program = Program() paddle.enable_static() - with program_guard(main_program): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): input = paddle.static.data(shape=x.shape, name='x', dtype='float32') out = self.api_case(input) sgd = paddle.optimizer.SGD(learning_rate=0.1) diff --git a/test/legacy_test/test_eager_run_program.py b/test/legacy_test/test_eager_run_program.py index 3014b791b47e1..be00a4f83c05c 100644 --- a/test/legacy_test/test_eager_run_program.py +++ b/test/legacy_test/test_eager_run_program.py @@ -152,7 +152,7 @@ def test_eager(self): ) _legacy_C_ops.run_program( - [x_t, y_t], [fake_var], [out_t], [scope], [fake_var], None, *attrs + [x_t, y_t], [fake_var], [out_t], [scope], None, *attrs ) loss = paddle.mean(out_t) diff --git a/test/legacy_test/test_eig_op.py b/test/legacy_test/test_eig_op.py index c5ba7262902c7..c6b57258fc820 100644 --- a/test/legacy_test/test_eig_op.py +++ b/test/legacy_test/test_eig_op.py @@ -183,7 +183,7 @@ def init_grad(self): def test_check_output(self): self.check_output_with_place_customized( - checker=self.checker, place=core.CPUPlace() + checker=self.checker, place=core.CPUPlace(), check_pir=True ) def test_check_grad(self): @@ -193,6 +193,7 @@ def test_check_grad(self): ['Eigenvalues', 'Eigenvectors'], user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_w, self.grad_v], + check_pir=True, ) @@ -319,6 +320,7 @@ def test_check_grad(self): test_type = 'float64' paddle.set_device("cpu") + np.random.seed(1024) input_np = np.random.random(test_shape).astype(test_type) real_w, real_v = np.linalg.eig(input_np) diff --git a/test/legacy_test/test_eigh_op.py b/test/legacy_test/test_eigh_op.py index 12042b89a17f0..004d1e164a8cd 100644 --- a/test/legacy_test/test_eigh_op.py +++ b/test/legacy_test/test_eigh_op.py @@ -18,6 +18,7 @@ from op_test import OpTest import paddle +from paddle.pir_utils import test_with_pir_api def valid_eigh_result(A, eigh_value, eigh_vector, uplo): @@ -92,7 +93,7 @@ def init_input(self): # self.check_output(no_check_set=['Eigenvectors']) def test_grad(self): - self.check_grad(["X"], ["Eigenvalues"]) + self.check_grad(["X"], ["Eigenvalues"], check_pir=True) class TestEighUPLOCase(TestEighOp): @@ -183,6 +184,7 @@ def check_static_complex_result(self): ) valid_eigh_result(self.complex_symm, actual_w, actual_v, self.UPLO) + @test_with_pir_api def test_in_static_mode(self): paddle.enable_static() self.check_static_float_result() diff --git a/test/legacy_test/test_eigvals_op.py b/test/legacy_test/test_eigvals_op.py index 6f3f126b2db3e..c54a4070be3a4 100644 --- a/test/legacy_test/test_eigvals_op.py +++ b/test/legacy_test/test_eigvals_op.py @@ -37,6 +37,7 @@ class TestEigvalsOp(OpTest): def setUp(self): np.random.seed(0) paddle.enable_static() + self.python_api = paddle.linalg.eigvals self.op_type = "eigvals" self.set_dtype() self.set_input_dims() @@ -67,7 +68,7 @@ def set_input_data(self): def test_check_output(self): self.__class__.no_need_check_grad = True self.check_output_with_place_customized( - checker=self.verify_output, place=core.CPUPlace() + checker=self.verify_output, place=core.CPUPlace(), check_pir=True ) def verify_output(self, outs): @@ -326,13 +327,13 @@ def test_cases(self): def test_error(self): paddle.disable_static() x = paddle.to_tensor([1]) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.linalg.eigvals(x) self.input_dims = [1, 2, 3, 4] self.set_input_data() x = paddle.to_tensor(self.input_data) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.linalg.eigvals(x) diff --git a/test/legacy_test/test_eigvalsh_op.py b/test/legacy_test/test_eigvalsh_op.py index 654702f856188..7ca95874b3e11 100644 --- a/test/legacy_test/test_eigvalsh_op.py +++ b/test/legacy_test/test_eigvalsh_op.py @@ -18,6 +18,7 @@ from op_test import OpTest import paddle +from paddle.pir_utils import test_with_pir_api def compare_result(actual, expected): @@ -72,10 +73,10 @@ def init_input(self): def test_check_output(self): # Vectors in posetive or negative is equivalent - self.check_output(no_check_set=['Eigenvectors']) + self.check_output(no_check_set=['Eigenvectors'], check_pir=True) def test_grad(self): - self.check_grad(["X"], ["Eigenvalues"]) + self.check_grad(["X"], ["Eigenvalues"], check_pir=True) class TestEigvalshUPLOCase(TestEigvalshOp): @@ -166,6 +167,7 @@ def check_static_complex_result(self): expected_w = np.linalg.eigvalsh(self.complex_symm) compare_result(actual_w[0], expected_w) + @test_with_pir_api def test_in_static_mode(self): paddle.enable_static() self.check_static_float_result() diff --git a/test/legacy_test/test_einsum_v2.py b/test/legacy_test/test_einsum_v2.py index 4009518329d0e..81a46e52add6c 100644 --- a/test/legacy_test/test_einsum_v2.py +++ b/test/legacy_test/test_einsum_v2.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api os.environ['FLAGS_new_einsum'] = "1" @@ -382,7 +383,7 @@ def check_output_equal(self, actual, expect, rtol=1.0e-5, atol=1.0e-8): rtol=rtol, atol=atol, err_msg=error_msg.format( - paddle.get_device(), expect, actual, self.__class__.__name__ + self._get_place(False), expect, actual, self.__class__.__name__ ), ) @@ -465,6 +466,7 @@ def test_sums(self): self.check_output("i,ij->", y, x) self.check_output("ij,i->", x, y) + @test_with_pir_api def test_static_graph(self): paddle.enable_static() base = paddle.base @@ -523,11 +525,12 @@ def setUp(self): def tearDown(self): paddle.disable_static() + @test_with_pir_api def test_shape(self): A = paddle.static.data(name='x', shape=[-1]) B = paddle.static.data(name='y', shape=[384]) C = paddle.einsum('i,d->id', A, B) - self.assertEqual(C.shape, (-1, 384)) + self.assertEqual(tuple(C.shape), (-1, 384)) @unittest.skipIf( diff --git a/test/legacy_test/test_elementwise_add_op.py b/test/legacy_test/test_elementwise_add_op.py index 34e5e264aa3d9..097f046392af6 100644 --- a/test/legacy_test/test_elementwise_add_op.py +++ b/test/legacy_test/test_elementwise_add_op.py @@ -744,16 +744,16 @@ def init_input_output(self): self.out = self.x + self.y def test_check_output(self): - self.check_output(check_pir=False) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_pir=False) + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): - self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=False) + self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=True) def test_check_grad_ingore_y(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=False) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True) class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp): @@ -772,7 +772,11 @@ def test_static_add(self): b = paddle.full([4, 5, 6], True, dtype='bool') c = a + b self.assertTrue(c.dtype == core.VarDesc.VarType.FP32) - paddle.enable_static() + with paddle.pir_utils.IrGuard(): + a = 1.5 + b = paddle.full([4, 5, 6], True, dtype='bool') + c = a + b + self.assertTrue(c.dtype == core.DataType.FLOAT32) def test_dygraph_add(self): paddle.disable_static() diff --git a/test/legacy_test/test_elementwise_mod_op.py b/test/legacy_test/test_elementwise_mod_op.py index bb9348b358ebd..ba6a75c9e6ac8 100644 --- a/test/legacy_test/test_elementwise_mod_op.py +++ b/test/legacy_test/test_elementwise_mod_op.py @@ -45,9 +45,9 @@ def setUp(self): def test_check_output(self): if self.attrs['axis'] == -1: - self.check_output() + self.check_output(check_pir=True) else: - self.check_output() + self.check_output(check_pir=True) def init_input_output(self): self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) @@ -102,9 +102,9 @@ def init_input_output(self): def test_check_output(self): if self.attrs['axis'] == -1: - self.check_output() + self.check_output(check_pir=True) else: - self.check_output() + self.check_output(check_pir=True) @unittest.skipIf( @@ -121,9 +121,9 @@ def init_input_output(self): def test_check_output(self): if self.attrs['axis'] == -1: - self.check_output() + self.check_output(check_pir=True) else: - self.check_output() + self.check_output(check_pir=True) class TestElementwiseModFP16Op_ZeroDim1(TestElementwiseModFP16Op): @@ -181,7 +181,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def init_dtype(self): self.dtype = np.uint16 diff --git a/test/legacy_test/test_elementwise_sub_op.py b/test/legacy_test/test_elementwise_sub_op.py index 8a466e77c0ec7..29185c1844bf4 100644 --- a/test/legacy_test/test_elementwise_sub_op.py +++ b/test/legacy_test/test_elementwise_sub_op.py @@ -23,6 +23,7 @@ from paddle import base from paddle.base import core from paddle.base.layer_helper import LayerHelper +from paddle.pir_utils import test_with_pir_api class TestElementwiseOp(OpTest): @@ -903,8 +904,9 @@ def test_name(self): y_1 = self._executed_api(x, y, name='subtract_res') self.assertEqual(('subtract_res' in y_1.name), True) + @test_with_pir_api def test_declarative(self): - with base.program_guard(base.Program()): + with paddle.static.program_guard(paddle.static.Program()): def gen_data(): return { @@ -917,7 +919,10 @@ def gen_data(): z = self._executed_api(x, y) place = base.CPUPlace() exe = base.Executor(place) - z_value = exe.run(feed=gen_data(), fetch_list=[z.name]) + if paddle.framework.in_pir_mode(): + z_value = exe.run(feed=gen_data(), fetch_list=[z]) + else: + z_value = exe.run(feed=gen_data(), fetch_list=[z.name]) z_expected = np.array([1.0, -2.0, 2.0]) self.assertEqual((z_value == z_expected).all(), True) diff --git a/test/legacy_test/test_erfinv_op.py b/test/legacy_test/test_erfinv_op.py index 3108f8520d532..e9eb1d668ada8 100644 --- a/test/legacy_test/test_erfinv_op.py +++ b/test/legacy_test/test_erfinv_op.py @@ -44,7 +44,7 @@ def init_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -52,6 +52,7 @@ def test_check_grad(self): 'Out', user_defined_grads=[self.gradient], user_defined_grad_outputs=self.grad_out, + check_pir=True, ) @@ -143,15 +144,11 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place( - place, - ['X'], - 'Out', - ) + self.check_grad_with_place(place, ['X'], 'Out', check_pir=True) if __name__ == "__main__": diff --git a/test/legacy_test/test_expand_as_v2_op.py b/test/legacy_test/test_expand_as_v2_op.py index 13aa6863b9bd6..6b11c2f8dee99 100755 --- a/test/legacy_test/test_expand_as_v2_op.py +++ b/test/legacy_test/test_expand_as_v2_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestExpandAsBasic(OpTest): @@ -48,10 +49,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestExpandAs_ZeroDim1(TestExpandAsBasic): @@ -104,11 +105,11 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True ) @@ -242,7 +243,7 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output)} def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): pass @@ -261,26 +262,28 @@ def test_errors(self): # Test python API class TestExpandAsV2API(unittest.TestCase): + @test_with_pir_api def test_api(self): - input1 = np.random.random([12, 14]).astype("float32") - input2 = np.random.random([2, 12, 14]).astype("float32") - x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") - - y = paddle.static.data( - name='target_tensor', - shape=[2, 12, 14], - dtype="float32", - ) - - out_1 = paddle.expand_as(x, y=y) - - exe = base.Executor(place=base.CPUPlace()) - res_1 = exe.run( - base.default_main_program(), - feed={"x": input1, "target_tensor": input2}, - fetch_list=[out_1], - ) - np.testing.assert_array_equal(res_1[0], np.tile(input1, (2, 1, 1))) + with paddle.static.program_guard(paddle.static.Program()): + input1 = np.random.random([12, 14]).astype("float32") + input2 = np.random.random([2, 12, 14]).astype("float32") + x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") + + y = paddle.static.data( + name='target_tensor', + shape=[2, 12, 14], + dtype="float32", + ) + + out_1 = paddle.expand_as(x, y=y) + + exe = base.Executor(place=base.CPUPlace()) + res_1 = exe.run( + paddle.static.default_main_program(), + feed={"x": input1, "target_tensor": input2}, + fetch_list=[out_1], + ) + np.testing.assert_array_equal(res_1[0], np.tile(input1, (2, 1, 1))) if __name__ == "__main__": diff --git a/test/legacy_test/test_expand_v2_op.py b/test/legacy_test/test_expand_v2_op.py index 988043d472e25..5e37043bd4f98 100644 --- a/test/legacy_test/test_expand_v2_op.py +++ b/test/legacy_test/test_expand_v2_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api # Situation 1: shape is a list(without tensor) @@ -313,35 +314,35 @@ def test_errors(self): # Test python API class TestExpandV2API(unittest.TestCase): + @test_with_pir_api def test_api(self): - input = np.random.random([12, 14]).astype("float32") - x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") - - positive_2 = paddle.tensor.fill_constant([1], "int32", 12) - expand_shape = paddle.static.data( - name="expand_shape", - shape=[2], - dtype="int32", - ) - - out_1 = paddle.expand(x, shape=[12, 14]) - out_2 = paddle.expand(x, shape=[positive_2, 14]) - out_3 = paddle.expand(x, shape=expand_shape) - - g0 = base.backward.calc_gradient(out_2, x) + with paddle.static.program_guard(paddle.static.Program()): + input = np.random.random([12, 14]).astype("float32") + x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") + + positive_2 = paddle.tensor.fill_constant([1], "int32", 12) + expand_shape = paddle.static.data( + name="expand_shape", + shape=[2], + dtype="int32", + ) - exe = base.Executor(place=base.CPUPlace()) - res_1, res_2, res_3 = exe.run( - base.default_main_program(), - feed={ - "x": input, - "expand_shape": np.array([12, 14]).astype("int32"), - }, - fetch_list=[out_1, out_2, out_3], - ) - np.testing.assert_array_equal(res_1, np.tile(input, (1, 1))) - np.testing.assert_array_equal(res_2, np.tile(input, (1, 1))) - np.testing.assert_array_equal(res_3, np.tile(input, (1, 1))) + out_1 = paddle.expand(x, shape=[12, 14]) + out_2 = paddle.expand(x, shape=[positive_2, 14]) + out_3 = paddle.expand(x, shape=expand_shape) + + exe = base.Executor(place=base.CPUPlace()) + res_1, res_2, res_3 = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + "expand_shape": np.array([12, 14]).astype("int32"), + }, + fetch_list=[out_1, out_2, out_3], + ) + np.testing.assert_array_equal(res_1, np.tile(input, (1, 1))) + np.testing.assert_array_equal(res_2, np.tile(input, (1, 1))) + np.testing.assert_array_equal(res_3, np.tile(input, (1, 1))) class TestExpandInferShape(unittest.TestCase): diff --git a/test/legacy_test/test_exponential_op.py b/test/legacy_test/test_exponential_op.py index 2eac134124c0a..1df9276590a0f 100644 --- a/test/legacy_test/test_exponential_op.py +++ b/test/legacy_test/test_exponential_op.py @@ -411,7 +411,7 @@ def config(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place_customized( - checker=self.verify_output, place=place + checker=self.verify_output, place=place, check_pir=True ) def verify_output(self, outs): diff --git a/test/legacy_test/test_fill_constant_op.py b/test/legacy_test/test_fill_constant_op.py index 7ea153d627cbd..d898567291a99 100644 --- a/test/legacy_test/test_fill_constant_op.py +++ b/test/legacy_test/test_fill_constant_op.py @@ -524,8 +524,8 @@ def test_shape_type(): # The shape dtype of fill_constant_op must be int32 or int64. # test_shape_tensor_dtype: with paddle.pir_utils.IrGuard(): - new_ir_program = paddle.static.Program() - with paddle.static.program_guard(new_ir_program): + pir_program = paddle.static.Program() + with paddle.static.program_guard(pir_program): shape = paddle.static.data( name="shape_tensor", shape=[2], dtype="int32" ) diff --git a/test/legacy_test/test_flatten_contiguous_range_op.py b/test/legacy_test/test_flatten_contiguous_range_op.py index 83354d87b705b..82ba03f559efc 100644 --- a/test/legacy_test/test_flatten_contiguous_range_op.py +++ b/test/legacy_test/test_flatten_contiguous_range_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestFlattenOp(OpTest): @@ -461,6 +462,7 @@ class TestStaticFlattenPythonAPI(unittest.TestCase): def execute_api(self, x, start_axis=0, stop_axis=-1): return paddle.flatten(x, start_axis, stop_axis) + @test_with_pir_api def test_static_api(self): paddle.enable_static() np_x = np.random.rand(2, 3, 4, 4).astype('float32') @@ -481,6 +483,7 @@ class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): def execute_api(self, x, start_axis=0, stop_axis=-1): return paddle.flatten(x, start_axis, stop_axis) + @test_with_pir_api def test_static_api(self): paddle.enable_static() main_prog = paddle.static.Program() @@ -489,7 +492,7 @@ def test_static_api(self): name="x", shape=[-1, 3, -1, -1], dtype='float32' ) out = self.execute_api(x, start_axis=2, stop_axis=3) - self.assertTrue((-1, 3, -1) == out.shape) + self.assertTrue((-1, 3, -1) == tuple(out.shape)) class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI): diff --git a/test/legacy_test/test_flip.py b/test/legacy_test/test_flip.py index 4e5cc58ad3312..e4f729ded8234 100644 --- a/test/legacy_test/test_flip.py +++ b/test/legacy_test/test_flip.py @@ -100,10 +100,10 @@ def init_attrs(self): self.attrs = {"axis": self.axis} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_cinn=True) + self.check_grad(["X"], "Out", check_cinn=True, check_pir=True) def init_test_case(self): self.in_shape = (6, 4, 2, 3) @@ -167,12 +167,16 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, check_cinn=True) + self.check_output_with_place( + place, check_cinn=True, check_pir=True + ) def test_check_grad(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_grad_with_place(place, ["X"], "Out", check_cinn=True) + self.check_grad_with_place( + place, ["X"], "Out", check_cinn=True, check_pir=True + ) cls_name = "{}_{}".format(parent.__name__, "FP16OP") TestFlipFP16.__name__ = cls_name @@ -202,12 +206,12 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) if core.is_bfloat16_supported(place): - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) if core.is_bfloat16_supported(place): - self.check_grad_with_place(place, ["X"], "Out") + self.check_grad_with_place(place, ["X"], "Out", check_pir=True) cls_name = "{}_{}".format(parent.__name__, "BF16OP") TestFlipBF16.__name__ = cls_name diff --git a/test/legacy_test/test_fmax_op.py b/test/legacy_test/test_fmax_op.py index bc5272134f238..1ca313857c03a 100644 --- a/test/legacy_test/test_fmax_op.py +++ b/test/legacy_test/test_fmax_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class ApiFMaxTest(unittest.TestCase): @@ -43,6 +44,7 @@ def setUp(self): self.np_expected3 = np.fmax(self.input_a, self.input_c) self.np_expected4 = np.fmax(self.input_b, self.input_c) + @test_with_pir_api def test_static_api(self): """test_static_api""" paddle.enable_static() @@ -145,11 +147,11 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" @@ -158,6 +160,7 @@ def test_check_grad_ingore_x(self): 'Out', max_relative_error=0.005, no_grad_set=set("X"), + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -167,6 +170,7 @@ def test_check_grad_ingore_y(self): 'Out', max_relative_error=0.005, no_grad_set=set('Y'), + check_pir=True, ) @@ -190,11 +194,11 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" @@ -203,6 +207,7 @@ def test_check_grad_ingore_x(self): 'Out', max_relative_error=0.005, no_grad_set=set("X"), + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -212,6 +217,7 @@ def test_check_grad_ingore_y(self): 'Out', max_relative_error=0.005, no_grad_set=set('Y'), + check_pir=True, ) @@ -234,11 +240,11 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) @unittest.skipIf( @@ -263,11 +269,11 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X', 'Y'], 'Out') + self.check_grad_with_place(place, ['X', 'Y'], 'Out', check_pir=True) if __name__ == "__main__": diff --git a/test/legacy_test/test_fmin_op.py b/test/legacy_test/test_fmin_op.py index 88d4b8252f3d1..0a828c9c97395 100644 --- a/test/legacy_test/test_fmin_op.py +++ b/test/legacy_test/test_fmin_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -45,6 +46,7 @@ def setUp(self): self.np_expected3 = np.fmin(self.input_a, self.input_c) self.np_expected4 = np.fmin(self.input_b, self.input_c) + @test_with_pir_api def test_static_api(self): """test_static_api""" paddle.enable_static() @@ -147,11 +149,11 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" @@ -160,6 +162,7 @@ def test_check_grad_ingore_x(self): 'Out', max_relative_error=0.005, no_grad_set=set("X"), + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -169,6 +172,7 @@ def test_check_grad_ingore_y(self): 'Out', max_relative_error=0.005, no_grad_set=set('Y'), + check_pir=True, ) @@ -192,11 +196,11 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" @@ -205,6 +209,7 @@ def test_check_grad_ingore_x(self): 'Out', max_relative_error=0.005, no_grad_set=set("X"), + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -214,6 +219,7 @@ def test_check_grad_ingore_y(self): 'Out', max_relative_error=0.005, no_grad_set=set('Y'), + check_pir=True, ) @@ -236,11 +242,11 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) @unittest.skipIf( @@ -265,11 +271,11 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X', 'Y'], 'Out') + self.check_grad_with_place(place, ['X', 'Y'], 'Out', check_pir=True) if __name__ == "__main__": diff --git a/test/legacy_test/test_frac_api.py b/test/legacy_test/test_frac_api.py index 26bc74225e54b..1d401066cee2f 100644 --- a/test/legacy_test/test_frac_api.py +++ b/test/legacy_test/test_frac_api.py @@ -18,7 +18,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def ref_frac(x): @@ -40,15 +41,13 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_api_static(self): paddle.enable_static() - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): input = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) out = paddle.frac(input) - place = base.CPUPlace() - if base.core.is_compiled_with_cuda(): - place = base.CUDAPlace(0) - exe = base.Executor(place) + exe = base.Executor(self.place) (res,) = exe.run(feed={'X': self.x_np}, fetch_list=[out]) out_ref = ref_frac(self.x_np) np.testing.assert_allclose(out_ref, res, rtol=1e-05) @@ -101,6 +100,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_error(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_full_like_op.py b/test/legacy_test/test_full_like_op.py index 5cbcc3f5c78aa..2126f532d9149 100644 --- a/test/legacy_test/test_full_like_op.py +++ b/test/legacy_test/test_full_like_op.py @@ -22,6 +22,7 @@ from paddle.base import core from paddle.base.framework import convert_np_dtype_to_dtype_ from paddle.framework import in_pir_mode +from paddle.pir_utils import test_with_pir_api from paddle.static import Program, program_guard @@ -45,11 +46,12 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None): class TestFullOp(unittest.TestCase): """Test fill_any_like op(whose API is full_like) for attr out.""" + @test_with_pir_api def test_attr_tensor_API(self): paddle.enable_static() - startup_program = Program() - train_program = Program() - with program_guard(train_program, startup_program): + startup_program = paddle.static.Program() + train_program = paddle.static.Program() + with paddle.static.program_guard(train_program, startup_program): fill_value = 2.0 input = paddle.static.data( name='input', dtype='float32', shape=[2, 3] diff --git a/test/legacy_test/test_full_op.py b/test/legacy_test/test_full_op.py index 74e928e58a52a..0281d41252a27 100644 --- a/test/legacy_test/test_full_op.py +++ b/test/legacy_test/test_full_op.py @@ -19,60 +19,63 @@ import paddle from paddle import base from paddle.base import Program, program_guard +from paddle.pir_utils import test_with_pir_api # Test python API class TestFullAPI(unittest.TestCase): + @test_with_pir_api def test_api(self): - positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) + with paddle.static.program_guard(paddle.static.Program()): + positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) - positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) - shape_tensor_int32 = paddle.static.data( - name="shape_tensor_int32", shape=[2], dtype="int32" - ) + positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) + shape_tensor_int32 = paddle.static.data( + name="shape_tensor_int32", shape=[2], dtype="int32" + ) - shape_tensor_int64 = paddle.static.data( - name="shape_tensor_int64", shape=[2], dtype="int64" - ) + shape_tensor_int64 = paddle.static.data( + name="shape_tensor_int64", shape=[2], dtype="int64" + ) - out_1 = paddle.full(shape=[1, 2], dtype="float32", fill_value=1.1) + out_1 = paddle.full(shape=[1, 2], dtype="float32", fill_value=1.1) - out_2 = paddle.full( - shape=[1, positive_2_int32], dtype="float32", fill_value=1.1 - ) + out_2 = paddle.full( + shape=[1, positive_2_int32], dtype="float32", fill_value=1.1 + ) - out_3 = paddle.full( - shape=[1, positive_2_int64], dtype="float32", fill_value=1.1 - ) + out_3 = paddle.full( + shape=[1, positive_2_int64], dtype="float32", fill_value=1.1 + ) - out_4 = paddle.full( - shape=shape_tensor_int32, dtype="float32", fill_value=1.2 - ) + out_4 = paddle.full( + shape=shape_tensor_int32, dtype="float32", fill_value=1.2 + ) - out_5 = paddle.full( - shape=shape_tensor_int64, dtype="float32", fill_value=1.1 - ) + out_5 = paddle.full( + shape=shape_tensor_int64, dtype="float32", fill_value=1.1 + ) - out_6 = paddle.full( - shape=shape_tensor_int64, dtype=np.float32, fill_value=1.1 - ) + out_6 = paddle.full( + shape=shape_tensor_int64, dtype=np.float32, fill_value=1.1 + ) - val = paddle.tensor.fill_constant( - shape=[1], dtype=np.float32, value=1.1 - ) - out_7 = paddle.full( - shape=shape_tensor_int64, dtype=np.float32, fill_value=val - ) + val = paddle.tensor.fill_constant( + shape=[1], dtype=np.float32, value=1.1 + ) + out_7 = paddle.full( + shape=shape_tensor_int64, dtype=np.float32, fill_value=val + ) - exe = base.Executor(place=base.CPUPlace()) - res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( - base.default_main_program(), - feed={ - "shape_tensor_int32": np.array([1, 2]).astype("int32"), - "shape_tensor_int64": np.array([1, 2]).astype("int64"), - }, - fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], - ) + exe = base.Executor(place=base.CPUPlace()) + res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( + paddle.static.default_main_program(), + feed={ + "shape_tensor_int32": np.array([1, 2]).astype("int32"), + "shape_tensor_int64": np.array([1, 2]).astype("int64"), + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], + ) np.testing.assert_array_equal( res_1, np.full([1, 2], 1.1, dtype="float32") diff --git a/test/legacy_test/test_fused_attention_op_api.py b/test/legacy_test/test_fused_attention_op_api.py index decba63d496cd..1570c0b0dd733 100644 --- a/test/legacy_test/test_fused_attention_op_api.py +++ b/test/legacy_test/test_fused_attention_op_api.py @@ -18,7 +18,7 @@ import paddle from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention -from paddle.static import Program +from paddle.pir_utils import test_with_pir_api def fc(x, weight): @@ -553,9 +553,12 @@ def run_static(self): ln_2_bias, ) + @test_with_pir_api def test_static_api(self): paddle.enable_static() - with paddle.static.program_guard(Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): ( out, qkv_weight, diff --git a/test/legacy_test/test_fused_feedforward_op.py b/test/legacy_test/test_fused_feedforward_op.py index 6ac4c8b09f63f..7c2d21b0f8027 100644 --- a/test/legacy_test/test_fused_feedforward_op.py +++ b/test/legacy_test/test_fused_feedforward_op.py @@ -23,6 +23,7 @@ from paddle.nn.layer import transformer from paddle.nn.layer.common import Dropout, Linear from paddle.nn.layer.norm import LayerNorm +from paddle.pir_utils import test_with_pir_api class TestFusedFFNOp(OpTest): @@ -211,85 +212,155 @@ def getShape(self): class APITestStaticFusedFFN(unittest.TestCase): - def test_static(self): - paddle.enable_static() - default_main_program().random_seed = 42 - dtype = "float32" - layer_norm_dtype = "float32" - batch_size = 1 - d_model = 8 - dim_feedforward = 8 - - x = paddle.static.data( - name='x', shape=[batch_size, d_model, dim_feedforward], dtype=dtype - ) - linear1_weight = paddle.static.data( - name='linear1_weight', shape=[d_model, dim_feedforward], dtype=dtype - ) - linear1_bias = paddle.static.data( - name='linear1_bias', shape=[dim_feedforward] - ) - linear2_weight = paddle.static.data( - name='linear2_weight', shape=[dim_feedforward, d_model], dtype=dtype - ) - linear2_bias = paddle.static.data(name='linear2_bias', shape=[d_model]) - ln1_scale = paddle.static.data(name='ln1_scale', shape=[d_model]) - ln1_bias = paddle.static.data(name='ln1_scale', shape=[d_model]) - ln2_scale = paddle.static.data(name='ln2_scale', shape=[d_model]) - ln2_bias = paddle.static.data(name='ln2_scale', shape=[d_model]) - - fused_out = incubate_f.fused_feedforward( - x, - linear1_weight, - linear2_weight, - linear1_bias, - linear2_bias, - ln1_scale, - ln1_bias, - ln2_scale, - ln2_bias, - 0.0, - 0.0, - activation="relu", - pre_layer_norm=False, - ) + def setUp(self): + self.dtype = "float32" + self.layer_norm_dtype = "float32" + self.batch_size = 1 + self.d_model = 8 + self.dim_feedforward = 8 - # base ffn - linear1_out = F.linear(x, linear1_weight, linear1_bias) - act_out = F.relu(linear1_out) - dropout1_out = F.dropout(x=act_out, p=0.0, training=False) - linear2_out = F.linear(dropout1_out, linear2_weight, linear2_bias) - dropout2_out = x + F.dropout(x=linear2_out, p=0.0, training=False) - ln_out = F.layer_norm( - dropout2_out, - normalized_shape=[d_model], - weight=ln2_scale, - bias=ln2_bias, - ) + def run_fused_feedforward( + self, + x_data, + linear1_weight_data, + linear1_bias_data, + linear2_weight_data, + linear2_bias_data, + ln1_scale_data, + ln1_bias_data, + ln2_scale_data, + ln2_bias_data, + ): + main = paddle.static.Program() + startup = paddle.static.Program() + main.random_seed = 42 + with paddle.static.program_guard(main, startup): + x = paddle.static.data( + name='x', + shape=[self.batch_size, self.d_model, self.dim_feedforward], + dtype=self.dtype, + ) + linear1_weight = paddle.static.data( + name='linear1_weight', + shape=[self.d_model, self.dim_feedforward], + dtype=self.dtype, + ) + linear1_bias = paddle.static.data( + name='linear1_bias', shape=[self.dim_feedforward] + ) + linear2_weight = paddle.static.data( + name='linear2_weight', + shape=[self.dim_feedforward, self.d_model], + dtype=self.dtype, + ) + linear2_bias = paddle.static.data( + name='linear2_bias', shape=[self.d_model] + ) + ln1_scale = paddle.static.data( + name='ln1_scale', shape=[self.d_model] + ) + ln1_bias = paddle.static.data(name='ln1_bias', shape=[self.d_model]) + ln2_scale = paddle.static.data( + name='ln2_scale', shape=[self.d_model] + ) + ln2_bias = paddle.static.data(name='ln2_bias', shape=[self.d_model]) + + fused_out = incubate_f.fused_feedforward( + x, + linear1_weight, + linear2_weight, + linear1_bias, + linear2_bias, + ln1_scale, + ln1_bias, + ln2_scale, + ln2_bias, + 0.0, + 0.0, + activation="relu", + pre_layer_norm=False, + ) - exe = paddle.static.Executor(paddle.CUDAPlace(0)) + exe = paddle.static.Executor(paddle.CUDAPlace(0)) - x_data = np.random.random( - (batch_size, d_model, dim_feedforward) - ).astype(dtype) - linear1_weight_data = np.random.random( - (d_model, dim_feedforward) - ).astype(dtype) - linear1_bias_data = np.zeros(dim_feedforward).astype(dtype) - linear2_weight_data = np.random.random( - (dim_feedforward, d_model) - ).astype(dtype) - linear2_bias_data = np.zeros(d_model).astype(dtype) + fetch = exe.run( + feed={ + 'x': x_data, + 'linear1_weight': linear1_weight_data, + 'linear1_bias': linear1_bias_data, + 'linear2_weight': linear2_weight_data, + 'linear2_bias': linear2_bias_data, + 'ln1_scale': ln1_scale_data, + 'ln1_bias': ln1_bias_data, + 'ln2_scale': ln2_scale_data, + 'ln2_bias': ln2_bias_data, + }, + fetch_list=[fused_out], + ) - ln1_scale_data = np.ones(d_model).astype(layer_norm_dtype) - ln1_bias_data = np.zeros(d_model).astype(layer_norm_dtype) - ln2_scale_data = np.ones(d_model).astype(layer_norm_dtype) - ln2_bias_data = np.zeros(d_model).astype(layer_norm_dtype) + return fetch + + def run_base_ffn( + self, + x_data, + linear1_weight_data, + linear1_bias_data, + linear2_weight_data, + linear2_bias_data, + ln1_scale_data, + ln1_bias_data, + ln2_scale_data, + ln2_bias_data, + ): + main = paddle.static.Program() + startup = paddle.static.Program() + main.random_seed = 42 + with paddle.static.program_guard(main, startup): + x = paddle.static.data( + name='x', + shape=[self.batch_size, self.d_model, self.dim_feedforward], + dtype=self.dtype, + ) + linear1_weight = paddle.static.data( + name='linear1_weight', + shape=[self.d_model, self.dim_feedforward], + dtype=self.dtype, + ) + linear1_bias = paddle.static.data( + name='linear1_bias', shape=[self.dim_feedforward] + ) + linear2_weight = paddle.static.data( + name='linear2_weight', + shape=[self.dim_feedforward, self.d_model], + dtype=self.dtype, + ) + linear2_bias = paddle.static.data( + name='linear2_bias', shape=[self.d_model] + ) + ln1_scale = paddle.static.data( + name='ln1_scale', shape=[self.d_model] + ) + ln1_bias = paddle.static.data(name='ln1_bias', shape=[self.d_model]) + ln2_scale = paddle.static.data( + name='ln2_scale', shape=[self.d_model] + ) + ln2_bias = paddle.static.data(name='ln2_bias', shape=[self.d_model]) + + # base ffn + linear1_out = F.linear(x, linear1_weight, linear1_bias) + act_out = F.relu(linear1_out) + dropout1_out = F.dropout(x=act_out, p=0.0, training=False) + linear2_out = F.linear(dropout1_out, linear2_weight, linear2_bias) + dropout2_out = x + F.dropout(x=linear2_out, p=0.0, training=False) + ln_out = F.layer_norm( + dropout2_out, + normalized_shape=[self.d_model], + weight=ln2_scale, + bias=ln2_bias, + ) - res_list = [fused_out, ln_out] - real_res = [] + exe = paddle.static.Executor(paddle.CUDAPlace(0)) - for res in res_list: fetch = exe.run( feed={ 'x': x_data, @@ -302,15 +373,63 @@ def test_static(self): 'ln2_scale': ln2_scale_data, 'ln2_bias': ln2_bias_data, }, - fetch_list=[res], + fetch_list=[ln_out], ) - real_res.append(fetch) + + return fetch + + @test_with_pir_api + def test_static(self): + paddle.enable_static() + + x_data = np.random.random( + (self.batch_size, self.d_model, self.dim_feedforward) + ).astype(self.dtype) + linear1_weight_data = np.random.random( + (self.d_model, self.dim_feedforward) + ).astype(self.dtype) + linear1_bias_data = np.zeros(self.dim_feedforward).astype(self.dtype) + linear2_weight_data = np.random.random( + (self.dim_feedforward, self.d_model) + ).astype(self.dtype) + linear2_bias_data = np.zeros(self.d_model).astype(self.dtype) + + ln1_scale_data = np.ones(self.d_model).astype(self.layer_norm_dtype) + ln1_bias_data = np.zeros(self.d_model).astype(self.layer_norm_dtype) + ln2_scale_data = np.ones(self.d_model).astype(self.layer_norm_dtype) + ln2_bias_data = np.zeros(self.d_model).astype(self.layer_norm_dtype) + + fused_feedforward_res = self.run_fused_feedforward( + x_data, + linear1_weight_data, + linear1_bias_data, + linear2_weight_data, + linear2_bias_data, + ln1_scale_data, + ln1_bias_data, + ln2_scale_data, + ln2_bias_data, + ) + + base_ffn_res = self.run_base_ffn( + x_data, + linear1_weight_data, + linear1_bias_data, + linear2_weight_data, + linear2_bias_data, + ln1_scale_data, + ln1_bias_data, + ln2_scale_data, + ln2_bias_data, + ) + np.testing.assert_allclose( - real_res[0], real_res[1], rtol=1e-05, atol=0.001 + fused_feedforward_res, base_ffn_res, rtol=1e-05, atol=0.001 ) class TestFusedFFNOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() with paddle.static.program_guard( diff --git a/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py b/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py index 7cfc0da1ebe47..b4b2471d95da9 100644 --- a/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py +++ b/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py @@ -56,7 +56,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_dygraph=False) class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu): diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index 3a27faf99cb6b..7d1dea17e20eb 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestGatherNdOpWithEmptyIndex(OpTest): @@ -561,6 +562,7 @@ def test_check_grad(self): # Test Python API class TestGatherNdOpAPI(unittest.TestCase): + @test_with_pir_api def test_case1(self): x1 = paddle.static.data( name='x1', shape=[-1, 30, 40, 50, 60], dtype='float32' @@ -570,6 +572,7 @@ def test_case1(self): ) output1 = paddle.gather_nd(x1, index1) + @test_with_pir_api def test_case2(self): x2 = paddle.static.data( name='x2', shape=[-1, 30, 40, 50], dtype='float32' @@ -579,6 +582,7 @@ def test_case2(self): ) output2 = paddle.gather_nd(x2, index2) + @test_with_pir_api def test_case3(self): x3 = paddle.static.data(name='x3', shape=[-1, 3, 4, 5], dtype='float32') index3 = paddle.static.data( @@ -589,6 +593,7 @@ def test_case3(self): # Test Raise Index Error class TestGatherNdOpRaise(unittest.TestCase): + @test_with_pir_api def test_check_raise(self): def check_raise_is_test(): try: @@ -638,16 +643,15 @@ def test_index_dtype(): class TestGatherNdAPI2(unittest.TestCase): + @test_with_pir_api def test_static(self): with base.program_guard(base.Program(), base.Program()): data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64') - data1.desc.set_need_check_feed(False) index = paddle.static.data('index', shape=[-1, 1], dtype='int32') - index.desc.set_need_check_feed(False) out = paddle.gather_nd(data1, index) place = base.CPUPlace() exe = base.Executor(place) - input = np.array([[1, 2], [3, 4], [5, 6]]) + input = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64') index_1 = np.array([[1]]).astype('int32') (result,) = exe.run( feed={"data1": input, "index": index_1}, fetch_list=[out] @@ -655,6 +659,7 @@ def test_static(self): expected_output = np.array([[3, 4]]) np.testing.assert_allclose(result, expected_output, rtol=1e-05) + @test_with_pir_api def test_static_fp16_with_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) @@ -671,11 +676,9 @@ def test_static_fp16_with_gpu(self): x = paddle.static.data( name="x", shape=[2, 3, 2], dtype="float16" ) - x.desc.set_need_check_feed(False) idx = paddle.static.data( name="index", shape=[1, 2], dtype="int32" ) - idx.desc.set_need_check_feed(False) y = paddle.gather_nd(x, idx) diff --git a/test/legacy_test/test_gather_op.py b/test/legacy_test/test_gather_op.py index e845875394be6..3ebb2de7b8560 100644 --- a/test/legacy_test/test_gather_op.py +++ b/test/legacy_test/test_gather_op.py @@ -21,6 +21,7 @@ from paddle import base from paddle.base.dygraph.base import switch_to_static_graph from paddle.framework import core +from paddle.pir_utils import test_with_pir_api def gather_numpy(x, index, axis): @@ -41,10 +42,10 @@ def setUp(self): self.if_enable_cinn() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def config(self): """ @@ -114,11 +115,11 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True ) @@ -299,10 +300,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=0.5) + self.check_grad(['X'], 'Out', numeric_grad_delta=0.5, check_pir=True) def config(self): """ @@ -328,10 +329,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) def config(self): """ @@ -418,23 +419,23 @@ def config_dtype(self): class API_TestGather(unittest.TestCase): + @test_with_pir_api def test_out1(self): with base.program_guard(base.Program(), base.Program()): data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64') - data1.desc.set_need_check_feed(False) - index = paddle.static.data('index', shape=[-1, 1], dtype='int32') - index.desc.set_need_check_feed(False) + index = paddle.static.data('index', shape=[-1, 1], dtype='int64') out = paddle.gather(data1, index) place = base.CPUPlace() exe = base.Executor(place) - input = np.array([[1, 2], [3, 4], [5, 6]]) - index_1 = np.array([1, 2]) + input = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64') + index_1 = np.array([1, 2]).astype('int64') (result,) = exe.run( feed={"data1": input, "index": index_1}, fetch_list=[out] ) expected_output = np.array([[3, 4], [5, 6]]) np.testing.assert_allclose(result, expected_output, rtol=1e-05) + @test_with_pir_api def test_out2(self): with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() @@ -602,11 +603,22 @@ def test_axis_maxsize(): class TestCheckOutType(unittest.TestCase): + @test_with_pir_api def test_out_type(self): data = paddle.static.data(shape=[16, 10], dtype='int64', name='x') index = paddle.static.data(shape=[4], dtype='int64', name='index') out = paddle.gather(data, index) - self.assertTrue(out.dtype == core.VarDesc.VarType.INT64) + self.assertTrue( + out.dtype == core.VarDesc.VarType.INT64 + or out.dtype == core.DataType.INT64 + ) + + def test_pir_out_type(self): + with paddle.pir_utils.IrGuard(): + data = paddle.static.data(shape=[16, 10], dtype='int64', name='x') + index = paddle.static.data(shape=[4], dtype='int64', name='index') + out = paddle.gather(data, index) + self.assertTrue(out.dtype == core.DataType.INT64) if __name__ == "__main__": diff --git a/test/legacy_test/test_gradient_clip.py b/test/legacy_test/test_gradient_clip.py index 96c5de1bfe3a3..556302578d994 100644 --- a/test/legacy_test/test_gradient_clip.py +++ b/test/legacy_test/test_gradient_clip.py @@ -179,9 +179,7 @@ def check_clip_result(self, out, out_clip): v, rtol=1e-05, atol=1e-08, - err_msg='gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'.format( - u, v, u - v - ), + err_msg=f'gradient clip by global norm has wrong results!, \nu={u}\nv={v}\ndiff={u - v}', ) # test whether the output is right when use 'set_gradient_clip' diff --git a/test/legacy_test/test_graph_send_recv_op.py b/test/legacy_test/test_graph_send_recv_op.py index 31269a82b3faa..403f8b6ff3824 100644 --- a/test/legacy_test/test_graph_send_recv_op.py +++ b/test/legacy_test/test_graph_send_recv_op.py @@ -18,6 +18,7 @@ from op_test import OpTest import paddle +from paddle.pir_utils import test_with_pir_api def graph_send_recv_wrapper( @@ -49,10 +50,12 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient]) + self.check_grad( + ['X'], 'Out', user_defined_grads=[self.gradient], check_pir=True + ) class TestGraphSendRecvMinOp(OpTest): @@ -77,10 +80,12 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient]) + self.check_grad( + ['X'], 'Out', user_defined_grads=[self.gradient], check_pir=True + ) class TestGraphSendRecvSumOp(OpTest): @@ -103,10 +108,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestGraphSendRecvMeanOp(OpTest): @@ -131,10 +136,10 @@ def setUp(self): self.outputs = {'Out': out, 'Dst_count': dst_count} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) def compute_graph_send_recv_for_sum_mean(inputs, attributes): @@ -216,6 +221,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): class API_GraphSendRecvOpTest(unittest.TestCase): + @test_with_pir_api def test_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -343,6 +349,7 @@ def test_set_outsize_gpu(self): np_res_set_outsize, res_set_outsize, rtol=1e-05, atol=1e-06 ) + @test_with_pir_api def test_out_size_tensor_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -378,6 +385,7 @@ def test_out_size_tensor_static(self): class API_GeometricSendURecvTest(unittest.TestCase): + @test_with_pir_api def test_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -489,6 +497,7 @@ def test_set_outsize_gpu(self): np_res_set_outsize, res_set_outsize, rtol=1e-05, atol=1e-06 ) + @test_with_pir_api def test_out_size_tensor_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_graph_send_ue_recv_op.py b/test/legacy_test/test_graph_send_ue_recv_op.py index cc8f6e188a863..d5d3c18436308 100644 --- a/test/legacy_test/test_graph_send_ue_recv_op.py +++ b/test/legacy_test/test_graph_send_ue_recv_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def get_broadcast_shape(shp1, shp2): @@ -314,10 +315,10 @@ def set_config(self): self.message_op = 'ADD' def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) class TestSumCase1(TestGraphSendUERecvSumOp): @@ -420,10 +421,10 @@ def set_config(self): self.message_op = 'ADD' def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) class TestMeanCase1(TestGraphSendUERecvMeanOp): @@ -526,13 +527,14 @@ def set_config(self): self.message_op = 'ADD' def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['X', 'Y'], 'Out', user_defined_grads=self.gradients, + check_pir=True, ) @@ -636,13 +638,14 @@ def set_config(self): self.message_op = 'ADD' def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['X', 'Y'], 'Out', user_defined_grads=self.gradients, + check_pir=True, ) @@ -1013,6 +1016,7 @@ def test_reshape_lhs_rhs(self): ), ) + @test_with_pir_api def test_out_size_tensor_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_graph_send_uv_op.py b/test/legacy_test/test_graph_send_uv_op.py index ad32cbeea3952..45162ce0b346f 100644 --- a/test/legacy_test/test_graph_send_uv_op.py +++ b/test/legacy_test/test_graph_send_uv_op.py @@ -18,6 +18,7 @@ from op_test import OpTest import paddle +from paddle.pir_utils import test_with_pir_api def compute_graph_send_uv(inputs, attributes): @@ -63,10 +64,10 @@ def setUp(self): self.outputs = {'out': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['x', 'y'], 'out') + self.check_grad(['x', 'y'], 'out', check_pir=True) def set_config(self): self.x = np.random.random((10, 20)).astype("float64") @@ -194,6 +195,7 @@ def test_compute_all_dygraph(self): ), ) + @test_with_pir_api def test_compute_all_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_hypot.py b/test/legacy_test/test_hypot.py new file mode 100644 index 0000000000000..66a049038eb5a --- /dev/null +++ b/test/legacy_test/test_hypot.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import base +from paddle.base import core + +paddle.enable_static() + + +class TestHypotAPI(unittest.TestCase): + def setUp(self): + self.x_shape = [10, 10] + self.y_shape = [10, 1] + self.x_np = np.random.uniform(-10, 10, self.x_shape).astype(np.float32) + self.y_np = np.random.uniform(-10, 10, self.y_shape).astype(np.float32) + + def test_static_graph(self): + paddle.enable_static() + startup_program = base.Program() + train_program = base.Program() + with base.program_guard(startup_program, train_program): + x = paddle.static.data( + name='input1', dtype='float32', shape=self.x_shape + ) + y = paddle.static.data( + name='input2', dtype='float32', shape=self.y_shape + ) + out = paddle.hypot(x, y) + + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + exe = base.Executor(place) + res = exe.run( + base.default_main_program(), + feed={'input1': self.x_np, 'input2': self.y_np}, + fetch_list=[out], + ) + np_out = np.hypot(self.x_np, self.y_np) + np.testing.assert_allclose(res[0], np_out, atol=1e-5, rtol=1e-5) + paddle.disable_static() + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + result = paddle.hypot(x, y) + np.testing.assert_allclose( + np.hypot(self.x_np, self.y_np), result.numpy(), rtol=1e-05 + ) + + paddle.enable_static() + + def test_error(self): + x = paddle.to_tensor(self.x_np) + y = 3.8 + self.assertRaises(TypeError, paddle.hypot, x, y) + self.assertRaises(TypeError, paddle.hypot, y, x) + + +class TestHypotAPIBroadCast(TestHypotAPI): + def setUp(self): + self.x_np = np.arange(6).astype(np.float32) + self.y_np = np.array([20]).astype(np.float32) + self.x_shape = [6] + self.y_shape = [1] + + +class TestHypotAPI3(TestHypotAPI): + def setUp(self): + self.x_shape = [] + self.y_shape = [] + self.x_np = np.random.uniform(-10, 10, self.x_shape).astype(np.float32) + self.y_np = np.random.uniform(-10, 10, self.y_shape).astype(np.float32) + + +class TestHypotAPI4(TestHypotAPI): + def setUp(self): + self.x_shape = [1] + self.y_shape = [1] + self.x_np = np.random.uniform(-10, 10, self.x_shape).astype(np.float32) + self.y_np = np.random.uniform(-10, 10, self.y_shape).astype(np.float32) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_i0_op.py b/test/legacy_test/test_i0_op.py index 4ff7514752e0b..84694fa71ff08 100644 --- a/test/legacy_test/test_i0_op.py +++ b/test/legacy_test/test_i0_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(100) paddle.seed(100) @@ -40,10 +41,12 @@ class TestI0API(unittest.TestCase): def setUp(self): self.x = np.array(self.DATA).astype(self.DTYPE) + self.out_ref = output_i0(self.x) self.place = [paddle.CPUPlace()] if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) + @test_with_pir_api def test_api_static(self): def run(place): paddle.enable_static() @@ -58,8 +61,7 @@ def run(place): feed={"x": self.x}, fetch_list=[out], ) - out_ref = output_i0(self.x) - np.testing.assert_allclose(res[0], out_ref, rtol=1e-5) + np.testing.assert_allclose(res[0], self.out_ref, rtol=1e-5) paddle.disable_static() for place in self.place: @@ -130,13 +132,14 @@ def init_config(self): self.target = output_i0(self.inputs['x']) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['x'], 'out', user_defined_grads=[ref_i0_grad(self.case, 1 / self.case.size)], + check_pir=True, ) diff --git a/test/legacy_test/test_i0e_op.py b/test/legacy_test/test_i0e_op.py index 692587bf86dbf..76a87504102d2 100644 --- a/test/legacy_test/test_i0e_op.py +++ b/test/legacy_test/test_i0e_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(100) paddle.seed(100) @@ -44,10 +45,12 @@ class TestI0eAPI(unittest.TestCase): def setUp(self): self.x = np.array(self.DATA).astype(self.DTYPE) + self.out_ref = output_i0e(self.x) self.place = [paddle.CPUPlace()] if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) + @test_with_pir_api def test_api_static(self): def run(place): paddle.enable_static() @@ -62,8 +65,7 @@ def run(place): feed={"x": self.x}, fetch_list=[y], ) - out_ref = output_i0e(self.x) - np.testing.assert_allclose(out_ref, res[0], rtol=1e-5) + np.testing.assert_allclose(self.out_ref, res[0], rtol=1e-5) paddle.disable_static() for place in self.place: @@ -134,13 +136,14 @@ def init_config(self): self.target = output_i0e(self.inputs['x']) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['x'], 'out', user_defined_grads=[ref_i0e_grad(self.case, 1 / self.case.size)], + check_pir=True, ) diff --git a/test/legacy_test/test_i1_op.py b/test/legacy_test/test_i1_op.py index 0bb76a9bbf6ef..c82c65d2c01d1 100644 --- a/test/legacy_test/test_i1_op.py +++ b/test/legacy_test/test_i1_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(42) paddle.seed(42) @@ -49,6 +50,7 @@ def setUp(self): self.place.append(paddle.CUDAPlace(0)) def test_api_static(self): + @test_with_pir_api def run(place): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -120,7 +122,7 @@ def setUp(self): # 测试前向输出结果 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) # 测试反向梯度输出 def test_check_grad(self): @@ -133,6 +135,7 @@ def test_check_grad(self): 1 / self.case.size, ) ], + check_pir=True, ) def init_config(self): diff --git a/test/legacy_test/test_i1e_op.py b/test/legacy_test/test_i1e_op.py index 94f9b625dc95b..f2692f9b94aa2 100644 --- a/test/legacy_test/test_i1e_op.py +++ b/test/legacy_test/test_i1e_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(42) paddle.seed(42) @@ -49,6 +50,7 @@ def setUp(self): self.place.append(paddle.CUDAPlace(0)) def test_api_static(self): + @test_with_pir_api def run(place): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -120,7 +122,7 @@ def setUp(self): # 测试前向输出结果 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) # 测试反向梯度输出 def test_check_grad(self): @@ -133,6 +135,7 @@ def test_check_grad(self): 1 / self.case.size, ) ], + check_pir=True, ) # 生成随机的输入数据并计算对应输出 diff --git a/test/legacy_test/test_increment.py b/test/legacy_test/test_increment.py index 4887564e9b9bb..3055ffe1bdcf3 100755 --- a/test/legacy_test/test_increment.py +++ b/test/legacy_test/test_increment.py @@ -18,9 +18,11 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api class TestIncrement(unittest.TestCase): + @test_with_pir_api def test_api(self): with base.program_guard(base.Program(), base.Program()): input = paddle.tensor.fill_constant( @@ -41,6 +43,7 @@ def test_api(self): class TestInplaceApiWithDataTransform(unittest.TestCase): + @test_with_pir_api def test_increment(self): if base.core.is_compiled_with_cuda(): paddle.enable_static() diff --git a/test/legacy_test/test_index_add_op.py b/test/legacy_test/test_index_add_op.py index cf6a3d03245d2..969131346e53e 100644 --- a/test/legacy_test/test_index_add_op.py +++ b/test/legacy_test/test_index_add_op.py @@ -18,7 +18,8 @@ from op_test import OpTest, convert_float_to_uint16 import paddle -from paddle.base import Program, core +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def compute_index_add_ref( @@ -93,10 +94,10 @@ def init_dtype_type(self): self.add_value_shape = (3, 3) def test_check_output(self): - self.check_output(atol=1e-2) + self.check_output(atol=1e-2, check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'AddValue'], 'Out') + self.check_grad(['X', 'AddValue'], 'Out', check_pir=True) class TestIndexAddFP16Op(TestIndexAddOp): @@ -156,10 +157,12 @@ def init_dtype_type(self): self.dtype = np.uint16 def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['X', 'AddValue'], 'Out') + self.check_grad_with_place( + self.place, ['X', 'AddValue'], 'Out', check_pir=True + ) class TestIndexAddAPI(unittest.TestCase): @@ -290,15 +293,16 @@ def run_static(self, device): "Index": self.index_np, "AddValue": self.add_value_np, }, - fetch_list=[out.name], + fetch_list=[out], return_numpy=False, ) return res + @test_with_pir_api def test_static(self): paddle.enable_static() for device in self.place: - with paddle.static.program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): out = self.run_static(device) ref_out = compute_index_add_ref( self.axis, diff --git a/test/legacy_test/test_index_fill.py b/test/legacy_test/test_index_fill.py new file mode 100644 index 0000000000000..ffb80f02b016e --- /dev/null +++ b/test/legacy_test/test_index_fill.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from itertools import combinations + +import numpy as np + +import paddle +from paddle.base import Program + +paddle.enable_static() + + +def compute_index_fill_ref(x, axis, index, value): + perm = list(range(len(x.shape))) + perm[0] = axis + perm[axis] = 0 + + out = np.transpose(x, perm) + out[index] = value + out = np.transpose(out, perm) + return out + + +class TestIndexFillAPIBase(unittest.TestCase): + def setUp(self): + self.init_setting() + self.modify_setting() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.index_np = np.array(self.combs[np.random.randint(0, 252)]).astype( + self.index_type + ) + + self.place = ['cpu'] + if self.dtype_np == 'float16': + self.place = [] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def init_setting(self): + self.dtype_np = 'float64' + self.index_type = 'int64' + self.x_shape = (20, 40) + self.index_size = (5,) + self.axis = 0 + self.value = -1 + self.combs = list(combinations(list(range(10)), self.index_size[0])) + + def modify_setting(self): + pass + + def test_static_graph(self): + paddle.enable_static() + for place in self.place: + with paddle.static.program_guard(Program()): + x = paddle.static.data( + name="x", shape=self.x_shape, dtype=self.dtype_np + ) + index = paddle.static.data( + name="index", shape=self.index_size, dtype=self.index_type + ) + out = paddle.index_fill(x, index, self.axis, self.value) + exe = paddle.static.Executor(place=place) + feed_list = {"x": self.x_np, "index": self.index_np} + pd_res = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + )[0] + ref_res = compute_index_fill_ref( + self.x_np, self.axis, self.index_np, self.value + ) + np.testing.assert_allclose(ref_res, pd_res) + + def test_dygraph(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + x_pd = paddle.to_tensor(self.x_np) + index_pd = paddle.to_tensor(self.index_np) + pd_res = paddle.index_fill(x_pd, index_pd, self.axis, self.value) + ref_res = compute_index_fill_ref( + self.x_np, self.axis, self.index_np, self.value + ) + np.testing.assert_allclose(ref_res, pd_res) + + def test_errors(self): + data_np = np.random.random((10, 10)).astype(np.float32) + index = paddle.to_tensor([0, 2]) + + def test_index_not_tensor(): + res = paddle.index_fill(data_np, [0, 2], axis=-1, value=-1) + + self.assertRaises(ValueError, test_index_not_tensor) + + def test_value_shape(): + res = paddle.index_fill( + data_np, index, axis=-1, value=paddle.to_tensor([-1, -4]) + ) + + self.assertRaises(ValueError, test_value_shape) + + def test_axis_range(): + res = paddle.index_fill(data_np, index, axis=4, value=-1) + + self.assertRaises(ValueError, test_axis_range) + + +class TestIndexFillAPI1(TestIndexFillAPIBase): + def modify_setting(self): + self.dtype_np = 'int64' + self.index_type = 'int32' + self.x_shape = (10, 15, 10) + self.axis = 1 + + +class TestIndexFillAPI2(TestIndexFillAPIBase): + def modify_setting(self): + self.dtype_np = 'bool' + self.index_type = 'int32' + self.x_shape = (10, 15, 10) + self.axis = 1 + self.value = True + + +class TestIndexFillAPI3(TestIndexFillAPIBase): + def modify_setting(self): + self.dtype_np = 'float16' + self.x_shape = (10, 15, 10) + self.axis = 1 + self.value = 0.5 diff --git a/test/legacy_test/test_index_put_op.py b/test/legacy_test/test_index_put_op.py index 9ab02298f94dd..3d988462194cc 100644 --- a/test/legacy_test/test_index_put_op.py +++ b/test/legacy_test/test_index_put_op.py @@ -18,7 +18,6 @@ import numpy as np import paddle -from paddle.base import Program def compute_index_put_ref(x_np, indices_np, value_np, accumulate=False): @@ -115,8 +114,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = False def setPlace(self): @@ -144,10 +143,11 @@ def test_dygraph_forward(self): ) np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + # @test_with_pir_api def test_static_forward(self): paddle.enable_static() for place in self.place: - with paddle.static.program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data( name="x", shape=self.x_shape, dtype=self.dtype_pd ) @@ -194,14 +194,13 @@ def test_static_forward(self): feed_list.update({"indice" + str(i): self.indices_np[i]}) feed_list.update({"value": self.value_np}) pd_res = exe.run( - paddle.static.default_main_program(), feed=feed_list, fetch_list=[out], - )[0] + ) ref_res = compute_index_put_ref( self.x_np, self.indices_np, self.value_np, self.accumulate ) - np.testing.assert_allclose(ref_res, pd_res, atol=1e-7) + np.testing.assert_allclose(ref_res, pd_res[0], atol=1e-7) class TestIndexPutAPI0(TestIndexPutAPIBase): @@ -211,8 +210,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = True @@ -223,8 +222,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) self.value_shape = (16, 16) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = False @@ -235,8 +234,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) self.value_shape = (16, 16) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = True @@ -247,8 +246,8 @@ def init_dtype_type(self): self.x_shape = (110, 94) self.indices_shapes = [(110, 94)] self.value_shape = (5170,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = False @@ -259,8 +258,8 @@ def init_dtype_type(self): self.x_shape = (110, 94) self.indices_shapes = [(110, 94)] self.value_shape = (5170,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = True @@ -271,8 +270,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (16, 16, 56) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = False @@ -283,8 +282,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (16, 16, 56) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = True @@ -295,8 +294,8 @@ def init_dtype_type(self): self.x_shape = (110, 94) self.indices_shapes = [(110,)] self.value_shape = (55, 94) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = False @@ -307,8 +306,8 @@ def init_dtype_type(self): self.x_shape = (110, 94) self.indices_shapes = [(110,)] self.value_shape = (55, 94) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = True @@ -319,8 +318,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (56,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = False @@ -331,8 +330,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (56,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = True @@ -343,8 +342,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (1,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = False @@ -355,8 +354,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (1,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = True @@ -367,8 +366,8 @@ def init_dtype_type(self): self.x_shape = (44, 94) self.indices_shapes = [(44,)] self.value_shape = (94,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = False @@ -379,8 +378,8 @@ def init_dtype_type(self): self.x_shape = (44, 94) self.indices_shapes = [(44,)] self.value_shape = (94,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = True @@ -391,8 +390,8 @@ def init_dtype_type(self): self.x_shape = (44, 94) self.indices_shapes = [(44,)] self.value_shape = (1,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = False @@ -403,8 +402,8 @@ def init_dtype_type(self): self.x_shape = (44, 94) self.indices_shapes = [(44,)] self.value_shape = (1,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.bool + self.dtype_pd = "float64" + self.index_type_pd = "bool" self.accumulate = True @@ -415,8 +414,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float64" + self.index_type_pd = "int32" self.accumulate = False @@ -427,8 +426,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float64" + self.index_type_pd = "int32" self.accumulate = True @@ -439,8 +438,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float32 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float32" + self.index_type_pd = "int32" self.accumulate = False @@ -451,8 +450,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float32 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float32" + self.index_type_pd = "int32" self.accumulate = True @@ -463,8 +462,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float16 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float16" + self.index_type_pd = "int32" self.accumulate = False @@ -475,8 +474,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float16 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float16" + self.index_type_pd = "int32" self.accumulate = True @@ -487,8 +486,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.int32 - self.index_type_pd = paddle.int32 + self.dtype_pd = "int32" + self.index_type_pd = "int32" self.accumulate = False @@ -499,8 +498,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.int32 - self.index_type_pd = paddle.int32 + self.dtype_pd = "int32" + self.index_type_pd = "int32" self.accumulate = True @@ -511,8 +510,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.int64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "int64" + self.index_type_pd = "int32" self.accumulate = False @@ -523,8 +522,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.int64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "int64" + self.index_type_pd = "int32" self.accumulate = True @@ -535,8 +534,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.bool - self.index_type_pd = paddle.int32 + self.dtype_pd = "bool" + self.index_type_pd = "int32" self.accumulate = False @@ -547,8 +546,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.bool - self.index_type_pd = paddle.int32 + self.dtype_pd = "bool" + self.index_type_pd = "int32" self.accumulate = True @@ -559,8 +558,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (16, 16, 56) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float64" + self.index_type_pd = "int32" self.accumulate = False @@ -571,8 +570,8 @@ def init_dtype_type(self): self.x_shape = (110, 42, 56, 56) self.indices_shapes = ((16, 16), (16, 16), (1, 16)) self.value_shape = (16, 16, 56) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float64" + self.index_type_pd = "int32" self.accumulate = True @@ -583,8 +582,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.bool - self.index_type_pd = paddle.int32 + self.dtype_pd = "bool" + self.index_type_pd = "int32" self.accumulate = False self.is_all_false = True @@ -596,8 +595,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.bool - self.index_type_pd = paddle.int32 + self.dtype_pd = "bool" + self.index_type_pd = "int32" self.accumulate = True self.is_all_false = True @@ -618,8 +617,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = False def setPlace(self): @@ -656,8 +655,8 @@ def init_dtype_type(self): self.x_shape = (100, 110) self.indices_shapes = [(21,), (21,)] self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.dtype_pd = "float64" + self.index_type_pd = "int64" self.accumulate = True @@ -674,10 +673,10 @@ def test_backward(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - value = paddle.ones(shape=[4], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + value = paddle.ones(shape=[4], dtype="float64") + x = paddle.ones(shape=[16, 21], dtype="float64") + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype="int64") + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype="int64") value.stop_gradient = False x.stop_gradient = False out = paddle.index_put(x, (ix1, ix2), value, False) @@ -719,10 +718,10 @@ def test_backward_scalarval(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - value = paddle.ones(shape=[1], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + value = paddle.ones(shape=[1], dtype="float64") + x = paddle.ones(shape=[16, 21], dtype="float64") + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype="int64") + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype="int64") value.stop_gradient = False x.stop_gradient = False out = paddle.index_put(x, (ix1, ix2), value, False) @@ -760,10 +759,10 @@ def test_backward_broadcastvalue(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - value = paddle.ones(shape=[2], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) - ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value = paddle.ones(shape=[2], dtype="float64") + x = paddle.ones(shape=[16, 21], dtype="float64") + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64") + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64") value.stop_gradient = False x.stop_gradient = False out = paddle.index_put(x, (ix1, ix2), value, False) @@ -805,10 +804,10 @@ def test_backward_broadcastvalue1(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - value = paddle.ones(shape=[1, 2], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) - ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value = paddle.ones(shape=[1, 2], dtype="float64") + x = paddle.ones(shape=[16, 21], dtype="float64") + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64") + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64") value.stop_gradient = False x.stop_gradient = False out = paddle.index_put(x, (ix1, ix2), value, False) @@ -850,10 +849,10 @@ def test_backward_broadcastvalue2(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - value = paddle.ones(shape=[2, 1], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) - ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value = paddle.ones(shape=[2, 1], dtype="float64") + x = paddle.ones(shape=[16, 21], dtype="float64") + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64") + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64") value.stop_gradient = False x.stop_gradient = False out = paddle.index_put(x, (ix1, ix2), value, False) @@ -895,9 +894,9 @@ def test_backward_all_false_bool_indice(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - value = paddle.ones(shape=[2, 1], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix = paddle.zeros(shape=[16, 21], dtype=paddle.bool) + value = paddle.ones(shape=[2, 1], dtype="float64") + x = paddle.ones(shape=[16, 21], dtype="float64") + ix = paddle.zeros(shape=[16, 21], dtype="bool") value.stop_gradient = False x.stop_gradient = False @@ -935,6 +934,7 @@ def test_backward_all_false_bool_indice(self): atol=1e-7, ) + # @test_with_pir_api def test_backward_in_static(self): paddle.enable_static() exe = paddle.static.Executor() @@ -952,8 +952,16 @@ def test_backward_in_static(self): z = paddle.index_put(y, (index,), value) l = z.sum() - paddle.static.append_backward(l) - res = exe.run(fetch_list=[z, x.grad_name, value.grad_name]) + if paddle.framework.in_pir_mode(): + grads = paddle.autograd.ir_backward.grad(l, [x, value]) + x_grad = grads[0] + value_grad = grads[1] + else: + paddle.static.append_backward(l) + x_grad = x.grad_name + value_grad = value.grad_name + + res = exe.run(fetch_list=[z, x_grad, value_grad]) expected_z = np.ones((4, 2, 5)) expected_z[[0, 1, 3]] = np.ones((5,)) @@ -976,14 +984,14 @@ def init_dtype_type(self): self.x_shape = (110, 42, 32, 56) self.indices_shapes = ((16, 16), (16, 16)) self.value_shape = (16, 16, 56) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float64" + self.index_type_pd = "int32" self.accumulate = False self.mixed_indices = True self.index_type_np1 = np.bool_ self.indices_shapes1 = [(32,)] - self.index_type_pd1 = paddle.bool + self.index_type_pd1 = "bool" class TestIndexPutAPIMixedIndices1(TestIndexPutAPIBase): @@ -993,14 +1001,14 @@ def init_dtype_type(self): self.x_shape = (110, 42, 32, 56) self.indices_shapes = ((16, 16), (16, 16)) self.value_shape = (16, 16, 56) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int32 + self.dtype_pd = "float64" + self.index_type_pd = "int32" self.accumulate = True self.mixed_indices = True self.index_type_np1 = np.bool_ self.indices_shapes1 = [(32,)] - self.index_type_pd1 = paddle.bool + self.index_type_pd1 = "bool" if __name__ == '__main__': diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index e3f1de1048e11..cac243f5e8682 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -56,7 +56,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c + var_d) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{1} != wrapper_version_snapshot:{0}", + "received tensor_version:1 != wrapper_version_snapshot:0", ): loss.backward() @@ -171,7 +171,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{1} != wrapper_version_snapshot:{0}", + "received tensor_version:1 != wrapper_version_snapshot:0", ): loss.backward() @@ -250,6 +250,72 @@ def test_backward_success_2(self): np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a) +class TestDygraphInplaceMaskedFill(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.masked_fill(var, self.mask, self.value) + + def inplace_api_processing(self, var): + return paddle.masked_fill_(var, self.mask, self.value) + + def init_data(self): + self.dtype = "float32" + self.input_var_numpy = np.random.uniform(-5, 5, [30, 3]) + self.value = np.random.uniform(-10, 10) + self.value = paddle.to_tensor(self.value, dtype=self.dtype) + self.mask = np.random.randint(0, 2, [30, 3]).astype('bool') + self.mask = paddle.to_tensor(self.mask, dtype='bool') + + def test_forward_version(self): + with paddle.base.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 2) + + inplace_var[0] = 2 + self.assertEqual(var.inplace_version, 3) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 5) + + def test_backward_error(self): + # It raises an error because the inplace operator will result + # in incorrect gradient computation. + with paddle.base.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + + # Here, the gradient computation will use the value of var_b + var_c = var_b**2 + self.inplace_api_processing(var_b) + + loss = paddle.nn.functional.relu(var_c) + with self.assertRaisesRegex( + RuntimeError, + f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + ): + loss.backward() + + +class TestDygraphInplaceMaskedFill2(TestDygraphInplaceMaskedFill): + def non_inplace_api_processing(self, var): + return paddle.masked_fill(var, self.mask, self.value) + + def inplace_api_processing(self, var): + return paddle.masked_fill_(var, self.mask, self.value) + + def init_data(self): + self.dtype = "float32" + self.input_var_numpy = np.random.uniform(-5, 5, [30, 3]) + self.value = np.random.uniform(-10, 10) + self.value = paddle.to_tensor(self.value, dtype=self.dtype) + self.mask = np.random.randint(0, 2, [30, 1]).astype('bool') + self.mask = paddle.to_tensor(self.mask, dtype='bool') + + class TestDygraphInplaceWithContinuous(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1]) @@ -834,6 +900,59 @@ def test_error(self): self.assertRaises(ValueError, paddle.gcd_, x, y) +class TestDygraphInplaceHypot(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.random.randint(2, size=200) + self.input_var_numpy = self.input_var_numpy.reshape([10, 20]) + self.dtype = "float32" + self.y = paddle.randn(shape=[10, 20], dtype="float32") + + def inplace_api_processing(self, var): + return paddle.hypot_(var, self.y) + + def non_inplace_api_processing(self, var): + return paddle.hypot(var, self.y) + + def test_errors(self): + x = 3.0 + self.assertRaises(TypeError, paddle.hypot_, x, self.y) + self.assertRaises(TypeError, paddle.hypot_, self.y, x) + + def test_forward_version(self): + with paddle.base.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 3) + + inplace_var[0] = 2.0 + self.assertEqual(var.inplace_version, 4) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 7) + + def test_backward_error(self): + # It raises an error because the inplace operator will result + # in incorrect gradient computation. + with paddle.base.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + # Here, the gradient computation will use the value of var_b + var_c = var_b**2 + self.inplace_api_processing(var_b) + var_c = paddle.cast(var_c, "float32") + + loss = paddle.nn.functional.relu(var_c) + with self.assertRaisesRegex( + RuntimeError, + f"received tensor_version:{3} != wrapper_version_snapshot:{0}", + ): + loss.backward() + + class TestDygraphInplaceNanToNum(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.array( @@ -886,7 +1005,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{3} != wrapper_version_snapshot:{0}", + "received tensor_version:3 != wrapper_version_snapshot:0", ): loss.backward() @@ -975,7 +1094,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + "received tensor_version:2 != wrapper_version_snapshot:0", ): loss.backward() @@ -1051,7 +1170,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + "received tensor_version:2 != wrapper_version_snapshot:0", ): loss.backward() @@ -1347,7 +1466,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + "received tensor_version:2 != wrapper_version_snapshot:0", ): loss.backward() @@ -1389,7 +1508,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + "received tensor_version:2 != wrapper_version_snapshot:0", ): loss.backward() @@ -1482,5 +1601,20 @@ def test_forward_version(self): self.assertEqual(var.inplace_version, 2) +class TestDygraphInplaceIndexFill(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.random.random((20, 40)) + self.dtype = "float32" + self.axis = 0 + self.index = paddle.to_tensor([0, 2]) + self.value = -1 + + def inplace_api_processing(self, var): + return paddle.index_fill_(var, self.index, self.axis, self.value) + + def non_inplace_api_processing(self, var): + return paddle.index_fill(var, self.index, self.axis, self.value) + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_input_spec.py b/test/legacy_test/test_input_spec.py index 47c461a2a1eab..a1e8c5e852295 100644 --- a/test/legacy_test/test_input_spec.py +++ b/test/legacy_test/test_input_spec.py @@ -200,7 +200,7 @@ def check_result(self, specs, path): np.testing.assert_allclose(dy_out, pred_out, rtol=1e-05) # @to_static by InputSpec - net = paddle.jit.to_static(net, input_spec=specs) + net = paddle.jit.to_static(net, input_spec=specs, full_graph=True) st_out = net(self.x, *specs[1:]) np.testing.assert_allclose(dy_out, st_out, rtol=1e-05) @@ -217,7 +217,7 @@ def test_spec_compatible(self): net = NetWithNonTensorSpec(self.in_num, self.out_num) specs = [self.x_spec, False, "bn", -10] - net = paddle.jit.to_static(net, input_spec=specs) + net = paddle.jit.to_static(net, input_spec=specs, full_graph=True) net.eval() path = os.path.join(self.temp_dir.name, './net_twice') @@ -288,7 +288,7 @@ def test_non_tensor_with_prune(self): np.testing.assert_allclose(dy_out, pred_out, rtol=1e-05) # @to_static by InputSpec - net = paddle.jit.to_static(net, input_spec=specs) + net = paddle.jit.to_static(net, input_spec=specs, full_graph=True) st_out, _ = net(self.x, self.y, *specs[2:]) np.testing.assert_allclose(dy_out, st_out, rtol=1e-05) @@ -351,7 +351,9 @@ def tearDown(self): def test_run(self): net = NegSpecNet() net = paddle.jit.to_static( - net, input_spec=[paddle.static.InputSpec(shape=[-1, 10])] + net, + input_spec=[paddle.static.InputSpec(shape=[-1, 10])], + full_graph=True, ) x = paddle.randn([2, 10]) out = net(x) diff --git a/test/legacy_test/test_instance_norm_op_v2.py b/test/legacy_test/test_instance_norm_op_v2.py index 3b58eccb456e4..bb6d1d2d9111c 100644 --- a/test/legacy_test/test_instance_norm_op_v2.py +++ b/test/legacy_test/test_instance_norm_op_v2.py @@ -209,7 +209,7 @@ def setUp(self): self.python_api = instance_norm_wrapper self.public_python_api = instance_norm_wrapper self.check_prim = ( - False if os.getenv("FLAGS_enable_new_ir_in_executor") else True + False if os.getenv("FLAGS_enable_pir_in_executor") else True ) def test_check_output(self): @@ -315,7 +315,7 @@ def setUp(self): 'data_format': self.data_format, } self.check_prim = ( - False if os.getenv("FLAGS_enable_new_ir_in_executor") else True + False if os.getenv("FLAGS_enable_pir_in_executor") else True ) def init_value(self): diff --git a/test/legacy_test/test_inverse_op.py b/test/legacy_test/test_inverse_op.py index 8e578746226ac..bc111bac1e781 100644 --- a/test/legacy_test/test_inverse_op.py +++ b/test/legacy_test/test_inverse_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestInverseOp(OpTest): @@ -40,10 +41,10 @@ def setUp(self): self.outputs = {'Output': inverse} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_grad(self): - self.check_grad(['Input'], 'Output') + self.check_grad(['Input'], 'Output', check_pir=True) class TestInverseOpBatched(TestInverseOp): @@ -60,7 +61,9 @@ def config(self): self.python_api = paddle.tensor.math.inverse def test_grad(self): - self.check_grad(['Input'], 'Output', max_relative_error=1e-6) + self.check_grad( + ['Input'], 'Output', max_relative_error=1e-6, check_pir=True + ) class TestInverseOpFP32(TestInverseOp): @@ -70,7 +73,9 @@ def config(self): self.python_api = paddle.tensor.math.inverse def test_grad(self): - self.check_grad(['Input'], 'Output', max_relative_error=1e-2) + self.check_grad( + ['Input'], 'Output', max_relative_error=1e-2, check_pir=True + ) class TestInverseOpBatchedFP32(TestInverseOpFP32): @@ -95,7 +100,9 @@ def setUp(self): self.places.append(base.CUDAPlace(0)) def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input = paddle.static.data( name="input", shape=[4, 4], dtype="float64" ) @@ -105,7 +112,7 @@ def check_static_result(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input": input_np}, fetch_list=[result], ) @@ -113,6 +120,7 @@ def check_static_result(self, place): fetches[0], np.linalg.inv(input_np), rtol=1e-05 ) + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) @@ -161,7 +169,9 @@ def setUp(self): self.places.append(base.CUDAPlace(0)) def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input = paddle.static.data( name="input", shape=[4, 4], dtype="float64" ) @@ -172,7 +182,7 @@ def check_static_result(self, place): exe = base.Executor(place) try: fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input": input_np}, fetch_list=[result], ) @@ -181,6 +191,7 @@ def check_static_result(self, place): except ValueError as ex: print("The mat is singular") + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) diff --git a/test/legacy_test/test_isclose_op.py b/test/legacy_test/test_isclose_op.py index db7d8e6c49b54..c9803bf441149 100644 --- a/test/legacy_test/test_isclose_op.py +++ b/test/legacy_test/test_isclose_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestIscloseOp(OpTest): @@ -52,7 +53,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestIscloseOpException(TestIscloseOp): @@ -114,6 +115,7 @@ def set_args(self): class TestIscloseStatic(unittest.TestCase): + @test_with_pir_api def test_api_case(self): paddle.enable_static() x_data = np.random.rand(10, 10) @@ -122,9 +124,9 @@ def test_api_case(self): if paddle.base.core.is_compiled_with_cuda(): places.append(paddle.base.CUDAPlace(0)) for place in places: - with paddle.static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): x = paddle.static.data( name='x', shape=[10, 10], dtype='float64' ) @@ -134,7 +136,7 @@ def test_api_case(self): result = paddle.isclose(x, y) exe = paddle.base.Executor(place) fetches = exe.run( - paddle.base.default_main_program(), + main, feed={"x": x_data, "y": y_data}, fetch_list=[result], ) @@ -209,17 +211,20 @@ def test_equal_nan(): class TestIscloseOpFp16(unittest.TestCase): + @test_with_pir_api def test_fp16(self): x_data = np.random.rand(10, 10).astype('float16') y_data = np.random.rand(10, 10).astype('float16') - with paddle.static.program_guard(paddle.static.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): x = paddle.static.data(shape=[10, 10], name='x', dtype='float16') y = paddle.static.data(shape=[10, 10], name='y', dtype='float16') out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08) if core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) + exe.run(startup) out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out]) @@ -235,7 +240,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) class TestIscloseOpFloat32(TestIscloseOp): @@ -256,10 +261,11 @@ def set_args(self): self.equal_nan = False def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestIscloseOpCp64(unittest.TestCase): + @test_with_pir_api def test_cp64(self): x_data = ( np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10) @@ -267,18 +273,21 @@ def test_cp64(self): y_data = ( np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10) ).astype(np.complex64) - with paddle.static.program_guard(paddle.static.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): x = paddle.static.data(shape=[10, 10], name='x', dtype=np.complex64) y = paddle.static.data(shape=[10, 10], name='y', dtype=np.complex64) out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08) if core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) + exe.run(startup) out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out]) class TestIscloseOpCp128(unittest.TestCase): + @test_with_pir_api def test_cp128(self): x_data = ( np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10) @@ -286,7 +295,9 @@ def test_cp128(self): y_data = ( np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10) ).astype(np.complex128) - with paddle.static.program_guard(paddle.static.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): x = paddle.static.data( shape=[10, 10], name='x', dtype=np.complex128 ) @@ -297,7 +308,7 @@ def test_cp128(self): if core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) + exe.run(startup) out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out]) @@ -319,7 +330,7 @@ def set_args(self): self.equal_nan = False def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestIscloseOpLargeDimInput(TestIscloseOp): @@ -332,6 +343,15 @@ def set_args(self): self.equal_nan = False +class TestIscloseOpDoubleTol(TestIscloseOp): + def set_args(self): + self.input = np.array([1.0, 1e-9]).astype("float64") + self.other = np.array([1.0, 1e-10]).astype("float64") + self.rtol = np.array([1e-13]).astype("float64") + self.atol = np.array([1e-14]).astype("float64") + self.equal_nan = False + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_label_smooth_functional.py b/test/legacy_test/test_label_smooth_functional.py index 81f868c83c895..484f003c45497 100644 --- a/test/legacy_test/test_label_smooth_functional.py +++ b/test/legacy_test/test_label_smooth_functional.py @@ -20,6 +20,7 @@ import paddle.base.dygraph as dg import paddle.nn.functional as F from paddle import base +from paddle.pir_utils import test_with_pir_api class LabelSmoothTestCase(unittest.TestCase): @@ -88,6 +89,7 @@ def paddle_dygraph_layer(self): y_np = y_var.numpy() return y_np + @test_with_pir_api def _test_equivalence(self, place): place = base.CPUPlace() result1 = self.base_layer(place) diff --git a/test/legacy_test/test_label_smooth_op.py b/test/legacy_test/test_label_smooth_op.py index 763cbd676c169..6fdffc1cafa6b 100644 --- a/test/legacy_test/test_label_smooth_op.py +++ b/test/legacy_test/test_label_smooth_op.py @@ -45,10 +45,10 @@ def init_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out") + self.check_grad(["X"], "Out", check_pir=True) @unittest.skipIf( @@ -77,11 +77,11 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ["X"], "Out") + self.check_grad_with_place(place, ["X"], "Out", check_pir=True) class TestLabelSmoothFP16OP(TestLabelSmoothOp): diff --git a/test/legacy_test/test_linear_interp_v2_op.py b/test/legacy_test/test_linear_interp_v2_op.py index f748da3e6f849..fff39a95302c0 100755 --- a/test/legacy_test/test_linear_interp_v2_op.py +++ b/test/legacy_test/test_linear_interp_v2_op.py @@ -192,12 +192,12 @@ def setUp(self): def test_check_output(self): if platform.system() == "Linux": - self.check_output(atol=1e-7) + self.check_output(atol=1e-7, check_pir=True) else: - self.check_output(atol=1e-5) + self.check_output(atol=1e-5, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -314,6 +314,15 @@ def setUp(self): self.attrs['scale'] = self.scale self.outputs = {'Out': output_np} + def test_check_output(self): + if platform.system() == "Linux": + self.check_output(atol=1e-7, check_pir=False) + else: + self.check_output(atol=1e-5, check_pir=False) + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) + class TestLinearInterpOpAPI2_0(unittest.TestCase): def test_case(self): @@ -339,14 +348,11 @@ def test_case(self): class TestLinearInterpOpFP16(TestLinearInterpOp): def test_check_output(self): - self.check_output(atol=1e-3) + self.check_output(atol=1e-3, check_pir=True) def test_check_grad(self): self.check_grad( - ['X'], - 'Out', - in_place=True, - max_relative_error=1e-2, + ['X'], 'Out', in_place=True, max_relative_error=1e-2, check_pir=True ) def init_test_case(self): @@ -416,12 +422,17 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-2) + self.check_output_with_place(place, atol=1e-2, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', in_place=True, max_relative_error=1e-2 + place, + ['X'], + 'Out', + in_place=True, + max_relative_error=1e-2, + check_pir=True, ) def init_test_case(self): @@ -475,9 +486,13 @@ def setUp(self): def test_check_output(self): if platform.system() == "Linux": - self.check_output_with_place(place=core.CPUPlace(), atol=1e-7) + self.check_output_with_place( + place=core.CPUPlace(), atol=1e-7, check_pir=True + ) else: - self.check_output_with_place(place=core.CPUPlace(), atol=1e-5) + self.check_output_with_place( + place=core.CPUPlace(), atol=1e-5, check_pir=True + ) def init_test_case(self): self.interp_method = 'linear' diff --git a/test/legacy_test/test_linspace.py b/test/legacy_test/test_linspace.py index d45463cd3d826..e113bd4b76fb0 100644 --- a/test/legacy_test/test_linspace.py +++ b/test/legacy_test/test_linspace.py @@ -43,7 +43,7 @@ def _set_data(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestLinspaceOpReverseCase(TestLinspaceOpCommonCase): @@ -56,7 +56,7 @@ def _set_data(self): self.outputs = {'Out': np.arange(10, -1, -1).astype(self.dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestLinspaceOpNumOneCase(TestLinspaceOpCommonCase): @@ -69,7 +69,7 @@ def _set_data(self): self.outputs = {'Out': np.array([10], dtype=self.dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestLinspaceOpCommonCaseFP16(TestLinspaceOpCommonCase): @@ -111,7 +111,7 @@ def _set_data(self): } def test_check_output(self): - return self.check_output_with_place(core.CUDAPlace(0)) + return self.check_output_with_place(core.CUDAPlace(0), check_pir=True) class TestLinspaceOpReverseCaseBF16(TestLinspaceOpCommonCaseBF16): diff --git a/test/legacy_test/test_listen_and_serv_op.py b/test/legacy_test/test_listen_and_serv_op.py index 121490c8ae4fc..0c9b55f2c3e8d 100644 --- a/test/legacy_test/test_listen_and_serv_op.py +++ b/test/legacy_test/test_listen_and_serv_op.py @@ -149,7 +149,7 @@ def _wait_ps_ready(self, pid): # on the /tmp directory until it was ready to process all the RPC call. os.stat("/tmp/paddle.%d.port" % pid) return - except os.error: + except OSError: start_left_time -= sleep_time def test_rpc_interfaces(self): diff --git a/test/legacy_test/test_log_softmax.py b/test/legacy_test/test_log_softmax.py index 1c11d1096f0b4..316627be88892 100644 --- a/test/legacy_test/test_log_softmax.py +++ b/test/legacy_test/test_log_softmax.py @@ -20,6 +20,7 @@ import paddle import paddle.nn.functional as F from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(10) @@ -63,10 +64,12 @@ def set_attrs(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) class TestLogSoftmaxOp_ZeroDim(TestLogSoftmaxOp): @@ -83,10 +86,10 @@ def setUp(self): self.attrs = {'axis': -1} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out']) + self.check_grad(['X'], ['Out'], check_pir=True) class TestLogSoftmaxShape(TestLogSoftmaxOp): @@ -104,10 +107,10 @@ def set_attrs(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(atol=1e-3) + self.check_output(atol=1e-3, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], max_relative_error=1e-2) + self.check_grad(['X'], ['Out'], max_relative_error=1e-2, check_pir=True) class TestLogSoftmaxShapeFP16OP(TestLogSoftmaxFP16OP): @@ -143,7 +146,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) @@ -152,6 +155,7 @@ def test_check_grad(self): ['X'], ['Out'], user_defined_grads=[self.x_grad], + check_pir=True, ) @@ -171,6 +175,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def check_api(self, axis=-1): ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x) @@ -190,6 +195,7 @@ def check_api(self, axis=-1): np.testing.assert_allclose(y.numpy(), ref_out, rtol=1e-05) paddle.enable_static() + @test_with_pir_api def test_check_api(self): for axis in [-1, 1]: self.check_api(axis) @@ -205,6 +211,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def check_api(self, axis=-1, dtype=None): x = self.x.copy() if dtype is not None: @@ -223,6 +230,7 @@ def check_api(self, axis=-1, dtype=None): np.testing.assert_allclose(y.numpy(), ref_out, rtol=1e-05) paddle.enable_static() + @test_with_pir_api def test_check_api(self): for axis in [-1, 1]: self.check_api(axis) diff --git a/test/legacy_test/test_logcumsumexp_op.py b/test/legacy_test/test_logcumsumexp_op.py index 373548f679b88..c3363ae721b65 100644 --- a/test/legacy_test/test_logcumsumexp_op.py +++ b/test/legacy_test/test_logcumsumexp_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int] = None): @@ -145,7 +146,9 @@ def run_imperative(self): np.testing.assert_allclose(z, y.numpy(), rtol=1e-05) def run_static(self, use_gpu=False): - with base.program_guard(base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): data_np = np.random.random((5, 4)).astype(np.float32) x = paddle.static.data('X', [5, 4]) y = paddle.logcumsumexp(x) @@ -156,15 +159,15 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) out = exe.run( + main, feed={'X': data_np}, fetch_list=[ - y.name, - y2.name, - y3.name, - y4.name, - y5.name, + y, + y2, + y3, + y4, + y5, ], ) @@ -178,6 +181,7 @@ def run_static(self, use_gpu=False): z = np_logcumsumexp(data_np, axis=-2) np.testing.assert_allclose(z, out[4], rtol=1e-05) + @test_with_pir_api def test_cpu(self): paddle.disable_static(paddle.base.CPUPlace()) self.run_imperative() @@ -185,6 +189,7 @@ def test_cpu(self): self.run_static() + @test_with_pir_api def test_gpu(self): if not base.core.is_compiled_with_cuda(): return @@ -194,14 +199,18 @@ def test_gpu(self): self.run_static(use_gpu=True) + # @test_with_pir_api def test_name(self): with base.program_guard(base.Program()): x = paddle.static.data('x', [3, 4]) y = paddle.logcumsumexp(x, name='out') self.assertTrue('out' in y.name) + @test_with_pir_api def test_type_error(self): - with base.program_guard(base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): with self.assertRaises(TypeError): data_np = np.random.random((100, 100), dtype=np.int32) x = paddle.static.data('X', [100, 100], dtype='int32') @@ -209,8 +218,7 @@ def test_type_error(self): place = base.CUDAPlace(0) exe = base.Executor(place) - exe.run(base.default_startup_program()) - out = exe.run(feed={'X': data_np}, fetch_list=[y.name]) + out = exe.run(main, feed={'X': data_np}, fetch_list=[y]) def logcumsumexp_wrapper( @@ -232,7 +240,7 @@ def setUp(self): self.outputs = {'Out': np_logcumsumexp(input, **attrs)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -245,6 +253,7 @@ def test_check_grad(self): **self.attrs ) ], + check_pir=True, ) def input_and_attrs(self): @@ -295,6 +304,7 @@ def check_main(self, x_np, dtype, axis=None): paddle.enable_static() return y_np, x_g_np + @test_with_pir_api def test_main(self): if not paddle.is_compiled_with_cuda(): return @@ -332,7 +342,7 @@ def test_check_output(self): place = core.CUDAPlace(0) place = core.CUDAPlace(0) self.check_output_with_place_customized( - checker=self.verify_output, place=place + checker=self.verify_output, place=place, check_pir=True ) def verify_output(self, outs): @@ -352,7 +362,12 @@ def verify_output(self, outs): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', numeric_grad_delta=0.5, max_relative_error=0.5 + place, + ['X'], + 'Out', + numeric_grad_delta=0.5, + max_relative_error=0.5, + check_pir=True, ) diff --git a/test/legacy_test/test_logical_op.py b/test/legacy_test/test_logical_op.py index 98e15878cdfb6..81dec36e2f698 100755 --- a/test/legacy_test/test_logical_op.py +++ b/test/legacy_test/test_logical_op.py @@ -67,6 +67,7 @@ } +# @test_with_pir_api def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True): paddle.enable_static() startup_program = Program() diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py index b2f2e21af25ee..641fc68e1832d 100644 --- a/test/legacy_test/test_logit_op.py +++ b/test/legacy_test/test_logit_op.py @@ -58,10 +58,12 @@ def set_attrs(self): self.eps = 1e-8 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) class TestLogitOpFp32(TestLogitOp): @@ -71,10 +73,12 @@ def set_attrs(self): self.eps = 1e-8 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) class TestLogitOpFp16(TestLogitOp): @@ -84,10 +88,12 @@ def set_attrs(self): self.eps = 1e-8 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) @unittest.skipIf( @@ -115,7 +121,7 @@ def set_attrs(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -125,6 +131,7 @@ def test_check_grad(self): ['X'], ['Out'], user_defined_grads=[self.x_grad], + check_pir=True, ) diff --git a/test/legacy_test/test_logspace.py b/test/legacy_test/test_logspace.py index 9edd4aef71788..857a6411b869f 100644 --- a/test/legacy_test/test_logspace.py +++ b/test/legacy_test/test_logspace.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestLogspaceOpCommonCase(OpTest): @@ -39,7 +40,7 @@ def init_data(self): self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestLogspaceFP16Op(TestLogspaceOpCommonCase): @@ -87,7 +88,7 @@ def init_data(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) class TestLogspaceOpReverseCase(TestLogspaceOpCommonCase): @@ -143,6 +144,7 @@ def init_data(self): class TestLogspaceAPI(unittest.TestCase): + @test_with_pir_api def test_variable_input1(self): paddle.enable_static() prog = paddle.static.Program() @@ -170,6 +172,7 @@ def test_variable_input2(self): self.assertEqual((out.numpy() == np_res).all(), True) paddle.enable_static() + @test_with_pir_api def test_dtype(self): paddle.enable_static() prog = paddle.static.Program() diff --git a/test/legacy_test/test_lookup_table_v2_bf16_op.py b/test/legacy_test/test_lookup_table_v2_bf16_op.py index 04362fc7cffa0..44a2f1881b086 100644 --- a/test/legacy_test/test_lookup_table_v2_bf16_op.py +++ b/test/legacy_test/test_lookup_table_v2_bf16_op.py @@ -27,6 +27,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestLookupTableV2BF16Op(TestLookupTableBF16Op): @@ -125,10 +126,12 @@ def setUp(self): self.prog, feed={'x': self.ids}, fetch_list=['emb_weight', self.emb] ) + @test_with_pir_api def test_embedding_weights(self): result = convert_uint16_to_float(self.result[0]) np.testing.assert_array_equal(self.w_fp32, result) + @test_with_pir_api def test_lookup_results(self): lookup_result = convert_uint16_to_float(self.result[1]) lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type) diff --git a/test/legacy_test/test_lookup_table_v2_op.py b/test/legacy_test/test_lookup_table_v2_op.py index ad708eb137bb1..20a9c05e91bce 100644 --- a/test/legacy_test/test_lookup_table_v2_op.py +++ b/test/legacy_test/test_lookup_table_v2_op.py @@ -21,9 +21,11 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api class TestStaticGraphSupportMultipleInt(unittest.TestCase): + @test_with_pir_api def test_main(self): dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64'] if paddle.in_dynamic_mode(): diff --git a/test/legacy_test/test_lr_scheduler.py b/test/legacy_test/test_lr_scheduler.py index ba1f712dce2fd..3db40ea291342 100644 --- a/test/legacy_test/test_lr_scheduler.py +++ b/test/legacy_test/test_lr_scheduler.py @@ -214,6 +214,204 @@ def _test_dygraph(self, place, kwargs): self.assertEqual(scheduler.last_lr, scheduler1.last_lr) +def cosine_annealing_warm_restarts_lr(epoch_num, v_l): + if epoch_num is None and v_l['last_epoch'] < 0: + epoch_num = 0 + + cur_lr = ( + v_l['eta_min'] + + (v_l['base_lr'] - v_l['eta_min']) + * (1 + math.cos(math.pi * v_l['T_cur'] / v_l['T_i'])) + / 2 + ) + + if v_l['last_epoch'] == -1: + cur_lr = v_l['base_lr'] + + if epoch_num is None: + epoch_num = v_l['last_epoch'] + 1 + v_l['T_cur'] = v_l['T_cur'] + 1 + if v_l['T_cur'] >= v_l['T_i']: + v_l['T_cur'] = v_l['T_cur'] - v_l['T_i'] + v_l['T_i'] = v_l['T_i'] * v_l['T_mult'] + else: + if epoch_num < 0: + raise ValueError( + f"Expected non-negative epoch, but got {epoch_num}" + ) + if epoch_num >= v_l['T_0']: + if v_l['T_mult'] == 1: + v_l['T_cur'] = epoch_num % v_l['T_0'] + else: + n = int( + math.log( + (epoch_num / v_l['T_0'] * (v_l['T_mult'] - 1) + 1), + v_l['T_mult'], + ) + ) + v_l['T_cur'] = epoch_num - v_l['T_0'] * ( + v_l['T_mult'] ** n - 1 + ) / (v_l['T_mult'] - 1) + v_l['T_i'] = v_l['T_0'] * v_l['T_mult'] ** (n) + else: + v_l['T_i'] = v_l['T_0'] + v_l['T_cur'] = epoch_num + v_l['last_epoch'] = math.floor(epoch_num) + + return cur_lr + + +class TestCosineAnnealingWarmRestarts(unittest.TestCase): + def test_CosineRestartsLR(self): + # check value of T_0 + with self.assertRaises(ValueError): + paddle.optimizer.lr.CosineAnnealingWarmRestarts( + learning_rate=0.5, + T_0=-1, + T_mult=1, + ) + # check type of T_0 + with self.assertRaises(ValueError): + paddle.optimizer.lr.CosineAnnealingWarmRestarts( + learning_rate=0.5, + T_0=1.0, + T_mult=1, + ) + # check value of T_mult + with self.assertRaises(ValueError): + paddle.optimizer.lr.CosineAnnealingWarmRestarts( + learning_rate=0.5, + T_0=1, + T_mult=-1, + ) + # check type of T_mult + with self.assertRaises(ValueError): + paddle.optimizer.lr.CosineAnnealingWarmRestarts( + learning_rate=0.5, + T_0=1, + T_mult=1.0, + ) + + places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + + for place in places: + for T_0 in [1, 2, 3]: + kwargs = { + 'learning_rate': 0.5, + 'T_0': T_0, + 'T_mult': 2, + 'eta_min': 0, + 'last_epoch': -1, + 'verbose': False, + } + paddle.enable_static() + self._test_static(place, kwargs) + paddle.disable_static(place) + self._test_dygraph(place, kwargs) + paddle.enable_static() + + def _test_static(self, place, kwargs): + paddle.enable_static() + v_l = { + 'base_lr': kwargs['learning_rate'], + 'T_0': kwargs['T_0'], + 'T_i': kwargs['T_0'], + 'T_mult': kwargs['T_mult'], + 'eta_min': kwargs['eta_min'], + 'T_cur': -1, + 'last_epoch': -1, + } + scheduler = paddle.optimizer.lr.CosineAnnealingWarmRestarts(**kwargs) + adam = paddle.optimizer.Adam(learning_rate=scheduler) + + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[3, 4, 5]) + loss = paddle.mean(x) + adam.minimize(loss) + lr_var = adam._global_learning_rate() + test_prog = main_prog.clone() + + exe = paddle.static.Executor(place) + exe.run(start_prog) + + for epoch in range(5): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={'x': np.random.randn(3, 4, 5).astype('float32')}, + fetch_list=lr_var.name, + ) + expected_lr = np.array( + cosine_annealing_warm_restarts_lr(epoch, v_l) + ).astype(out[0].dtype) + self.assertEqual(out[0], expected_lr) + scheduler.step(epoch) + + for epoch in range(5): + for batch_id in range(2): + out = exe.run( + test_prog, + feed={'x': np.random.randn(3, 4, 5).astype('float32')}, + fetch_list=lr_var.name, + ) + expected_lr = np.array( + cosine_annealing_warm_restarts_lr(epoch_num=None, v_l=v_l) + ).astype(out[0].dtype) + self.assertEqual(out[0], expected_lr) + scheduler.step() + + def _test_dygraph(self, place, kwargs): + paddle.disable_static(place) + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + v_l = { + 'base_lr': kwargs['learning_rate'], + 'T_0': kwargs['T_0'], + 'T_i': kwargs['T_0'], + 'T_mult': kwargs['T_mult'], + 'eta_min': kwargs['eta_min'], + 'T_cur': -1, + 'last_epoch': -1, + } + + scheduler = paddle.optimizer.lr.CosineAnnealingWarmRestarts(**kwargs) + adam = paddle.optimizer.Adam( + learning_rate=scheduler, parameters=linear.parameters() + ) + + for epoch in range(10): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + adam.step() + adam.clear_grad() + current_lr = adam.get_lr() + expected_lr = cosine_annealing_warm_restarts_lr(epoch, v_l) + self.assertEqual(current_lr, expected_lr) + scheduler.step(epoch) + + for epoch in range(10): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + adam.step() + adam.clear_grad() + current_lr = scheduler.get_lr() + expected_lr = cosine_annealing_warm_restarts_lr( + epoch_num=None, v_l=v_l + ) + self.assertEqual(current_lr, expected_lr) + scheduler.step() + + def noam_lr(epoch_num, d_model, warmup_steps, learning_rate=1.0, verbose=False): if epoch_num == 0: a = 1 diff --git a/test/legacy_test/test_masked_fill.py b/test/legacy_test/test_masked_fill.py new file mode 100644 index 0000000000000..ec511f9b680e4 --- /dev/null +++ b/test/legacy_test/test_masked_fill.py @@ -0,0 +1,328 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import convert_float_to_uint16 + +import paddle +from paddle import base +from paddle.base import core + + +def np_masked_fill(x, mask, value): + if not np.isscalar(value): + value = value[0] + + x, mask = np.broadcast_arrays(x, mask) + result = np.copy(x) + for idx, m in np.ndenumerate(mask): + if m: + result[idx] = value + return result + + +paddle.enable_static() + + +class TestMaskedFillAPI(unittest.TestCase): + def setUp(self): + self.init() + + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype="bool" + ) + + self.value_np = np.random.randn(1).astype(self.dtype) + self.out_np = np_masked_fill(self.x_np, self.mask_np, self.value_np) + + def init(self): + self.x_shape = (50, 3) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = False + + def test_static_graph(self): + paddle.enable_static() + startup_program = base.Program() + train_program = base.Program() + with base.program_guard(startup_program, train_program): + x = paddle.static.data( + name='x', dtype=self.dtype, shape=self.x_shape + ) + mask = paddle.static.data( + name='mask', dtype='bool', shape=self.mask_shape + ) + value = paddle.static.data( + name='value', dtype=self.dtype, shape=self.value_np.shape + ) + out = paddle.masked_fill(x, mask, value) + + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + exe = base.Executor(place) + res = exe.run( + base.default_main_program(), + feed={ + 'x': self.x_np, + 'mask': self.mask_np, + 'value': self.value_np, + }, + fetch_list=[out], + ) + np.testing.assert_allclose( + res[0], self.out_np, atol=1e-5, rtol=1e-5 + ) + paddle.disable_static() + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np, dtype=self.dtype) + mask = paddle.to_tensor(self.mask_np).astype('bool') + if self.scalar_value: + value = self.value_np[0] + else: + value = paddle.to_tensor(self.value_np, dtype=self.dtype) + result = paddle.masked_fill(x, mask, value) + np.testing.assert_allclose(self.out_np, result.numpy(), rtol=1e-05) + + paddle.enable_static() + + +class TestMaskedFillAPI1(TestMaskedFillAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPI2(TestMaskedFillAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPI3(TestMaskedFillAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.scalar_value = True + + +class TestMaskedFillGrad(unittest.TestCase): + def setUp(self): + self.typelist = ['float32', 'float64', 'int32', 'int64'] + self.places = [base.CPUPlace()] + if base.core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.dtype = "float32" + + def test_backward(self): + expected_np = np.array( + [[2, 1, 1], [2, 1, 1], [2, 1, 1], [2, 1, 1]] + ).astype('float32') + expected_y_grad = np.array( + [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]] + ).astype('float32') + expected_v_grad = np.array(8).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.to_tensor(np.array(1).astype(self.dtype)) + x = paddle.ones((4, 3), dtype=self.dtype) + mask = paddle.to_tensor(np.array([0, 1, 1]).astype("bool")) + x.stop_gradient = False + v.stop_gradient = False + y = x * 2 + y.retain_grads() + ny = y.masked_fill(mask=mask, value=v) + loss = ny.sum() + loss.backward() + + self.assertEqual( + (ny.numpy().astype('float32') == expected_np).all(), True + ) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_y_grad).all(), + True, + ) + self.assertEqual( + (v.grad.numpy().astype('float32') == expected_v_grad).all(), + True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16API1(TestMaskedFillAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16API2(TestMaskedFillAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16API3(TestMaskedFillAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.scalar_value = True + + +class TestMaskedFillAPIBroadcast(TestMaskedFillAPI): + def init(self): + self.x_shape = (3, 40) + self.mask_shape = (3, 1) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast2(TestMaskedFillAPI): + def init(self): + self.x_shape = (3, 3) + self.mask_shape = (1, 3) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast3(TestMaskedFillAPI): + def init(self): + self.x_shape = (120,) + self.mask_shape = (300, 120) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast4(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 40) + self.mask_shape = (40,) + self.dtype = "float32" + self.scalar_value = False + + +class TestMaskedFillAPIBroadcast5(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 40) + self.mask_shape = (40,) + self.dtype = "float32" + self.scalar_value = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16APIBroadcast(TestMaskedFillAPI): + def init(self): + self.x_shape = (3, 40) + self.mask_shape = (3, 1) + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16APIBroadcast2(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 40) + self.dtype = "float16" + self.scalar_value = False + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedFillFP16APIBroadcast3(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 40) + self.dtype = "float16" + self.scalar_value = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestMaskedFillBF16(TestMaskedFillAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 1) + self.dtype = "uint16" + self.scalar_value = False + + def setUp(self): + self.init() + + self.x_np = convert_float_to_uint16( + np.random.random(self.x_shape).astype("float32") + ) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype="bool" + ) + + self.value_np = convert_float_to_uint16( + np.random.randn(1).astype("float32") + ) + self.out_np = np_masked_fill(self.x_np, self.mask_np, self.value_np) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestMaskedFillBF16APIBroadcast2(TestMaskedFillBF16): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 3) + self.dtype = "uint16" + self.scalar_value = False + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index e9d2ee096d7dd..95a6be11bf501 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -14,13 +14,411 @@ import inspect import unittest +import warnings + +import numpy as np import paddle +from paddle import base paddle.enable_static() +paddle.device.set_device("cpu") + + +def new_program(): + # TODO(gouzil): Optimize program code + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + place = base.CPUPlace() + exe = base.Executor(place) + return ( + main_program, + exe, + paddle.static.program_guard( + main_program=main_program, startup_program=startup_program + ), + ) class TestMathOpPatchesPir(unittest.TestCase): + def test_pow(self): + # Calculate results in dynamic graphs + paddle.disable_static() + x_np = np.random.random([10, 1024]).astype('float32') + y_np = np.random.random([10, 1024]).astype('float32') + res_np_b = x_np**y_np + res_np_c = paddle.pow(paddle.to_tensor(x_np), 2) + res_np_d = x_np.__pow__(2) + res_np_e = x_np.__rpow__(2) + paddle.enable_static() + # Calculate results under pir + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='float32' + ) + b = x**y + c = x.pow(2) + d = x.__pow__(2) + e = x.__rpow__(2) + # TODO(gouzil): Why not use `paddle.static.default_main_program()`? + # Because different case do not isolate parameters (This is a known problem) + (b_np, c_np, d_np, e_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d, e], + ) + np.testing.assert_allclose(res_np_b, b_np, rtol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, rtol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, rtol=1e-05) + np.testing.assert_allclose(res_np_e, e_np, rtol=1e-05) + + def test_mod(self): + paddle.disable_static() + x_np = np.random.randint(1, 100, size=[10, 1024], dtype=np.int64) + y_np = np.random.randint(1, 100, size=[10, 1024], dtype=np.int64) + res_np_b = x_np % y_np + res_np_c = paddle.mod(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + res_np_d = x_np.__mod__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='int64' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='int64' + ) + b = x % y + c = x.mod(y) + d = x.__mod__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_matmul(self): + paddle.disable_static() + x_np = np.random.uniform(-1, 1, [2, 3]).astype('float32') + y_np = np.random.uniform(-1, 1, [3, 5]).astype('float32') + res_np_b = x_np @ y_np # __matmul__ + res_np_c = paddle.matmul(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + res_np_d = x_np.__matmul__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name='x', shape=[2, 3], dtype='float32') + y = paddle.static.data(name='y', shape=[3, 5], dtype='float32') + b = x @ y + c = x.matmul(y) + d = x.__matmul__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_floordiv(self): + paddle.disable_static() + x_np = np.full([10, 1024], 10, np.int64) + y_np = np.full([10, 1024], 2, np.int64) + res_np_b = x_np // y_np + res_np_c = paddle.floor_divide( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__floordiv__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='int64' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='int64' + ) + b = x // y + c = x.floor_divide(y) + d = x.__floordiv__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_bitwise_not(self): + paddle.disable_static() + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + res_np_b = ~x_np + res_np_c = paddle.bitwise_not(paddle.to_tensor(x_np)) + res_np_d = x_np.__invert__() + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name='x', shape=[2, 3, 5], dtype='int32') + b = ~x + c = x.bitwise_not() + d = x.__invert__() + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_array_equal(res_np_b, b_np) + np.testing.assert_array_equal(res_np_c, c_np) + np.testing.assert_array_equal(res_np_d, d_np) + + def test_bitwise_xor(self): + paddle.disable_static() + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + res_np_b = x_np ^ y_np + res_np_c = paddle.bitwise_xor( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__xor__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32") + b = x ^ y + c = x.bitwise_xor(y) + d = x.__xor__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_array_equal(res_np_b, b_np) + np.testing.assert_array_equal(res_np_c, c_np) + np.testing.assert_array_equal(res_np_d, d_np) + + def test_bitwise_or(self): + paddle.disable_static() + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + res_np_b = x_np | y_np + res_np_c = paddle.bitwise_or( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__or__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32") + b = x | y + c = x.bitwise_or(y) + d = x.__or__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_array_equal(res_np_b, b_np) + np.testing.assert_array_equal(res_np_c, c_np) + np.testing.assert_array_equal(res_np_d, d_np) + + def test_bitwise_and(self): + paddle.disable_static() + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + res_np_b = x_np & y_np + res_np_c = paddle.bitwise_and( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__and__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32") + b = x & y + c = x.bitwise_and(y) + d = x.__and__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_array_equal(res_np_b, b_np) + np.testing.assert_array_equal(res_np_c, c_np) + np.testing.assert_array_equal(res_np_d, d_np) + + # for logical compare + def test_equal_and_nequal(self): + paddle.disable_static() + x_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32') + y_np = np.array([3, 4, 11, 15, 8, 18]).astype('float32') + # TODO(gouzil): Open after deleting c++ logic + # res_np_b = x_np == y_np + # res_np_c = paddle.equal(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + # res_np_d = x_np.__eq__(y_np) + res_np_e = x_np != y_np + res_np_f = paddle.not_equal( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_g = x_np.__ne__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[-1, 1], dtype='float32') + y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32') + # b = x == y + # c = x.equal(y) + # d = x.__eq__(y) + e = x != y + f = x.not_equal(y) + g = x.__ne__(y) + (e_np, f_np, g_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[e, f, g], + ) + # np.testing.assert_array_equal(res_np_b, b_np) + # np.testing.assert_array_equal(res_np_c, c_np) + # np.testing.assert_array_equal(res_np_d, d_np) + np.testing.assert_array_equal(res_np_e, e_np) + np.testing.assert_array_equal(res_np_f, f_np) + np.testing.assert_array_equal(res_np_g, g_np) + + def test_less(self): + paddle.disable_static() + x_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32') + y_np = np.array([3, 4, 11, 15, 8, 18]).astype('float32') + z_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32') + res_np_b = x_np < y_np + res_np_c = paddle.less_than( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__lt__(y_np) + res_np_e = x_np <= y_np + res_np_f = paddle.less_equal( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_g = x_np.__le__(y_np) + res_np_h = x_np <= z_np + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[-1, 1], dtype='float32') + y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32') + z = paddle.static.data(name="z", shape=[-1, 1], dtype='float32') + b = x < y + c = x.less_than(y) + d = x.__lt__(y) + e = x <= y + f = x.less_equal(y) + g = x.__le__(y) + h = x <= z + (b_np, c_np, d_np, e_np, f_np, g_np, h_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np, "z": z_np}, + fetch_list=[b, c, d, e, f, g, h], + ) + np.testing.assert_array_equal(res_np_b, b_np) + np.testing.assert_array_equal(res_np_c, c_np) + np.testing.assert_array_equal(res_np_d, d_np) + np.testing.assert_array_equal(res_np_e, e_np) + np.testing.assert_array_equal(res_np_f, f_np) + np.testing.assert_array_equal(res_np_g, g_np) + np.testing.assert_array_equal(res_np_h, h_np) + + def test_greater(self): + paddle.disable_static() + x_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32') + y_np = np.array([3, 4, 11, 15, 8, 18]).astype('float32') + z_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32') + res_np_b = x_np > y_np + res_np_c = paddle.greater_than( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__gt__(y_np) + res_np_e = x_np >= y_np + res_np_f = paddle.greater_equal( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_g = x_np.__ge__(y_np) + res_np_h = x_np >= z_np + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[-1, 1], dtype='float32') + y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32') + z = paddle.static.data(name="z", shape=[-1, 1], dtype='float32') + b = x > y + c = x.greater_than(y) + d = x.__gt__(y) + e = x >= y + f = x.greater_equal(y) + g = x.__ge__(y) + h = x >= z + (b_np, c_np, d_np, e_np, f_np, g_np, h_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np, "z": z_np}, + fetch_list=[b, c, d, e, f, g, h], + ) + np.testing.assert_array_equal(res_np_b, b_np) + np.testing.assert_array_equal(res_np_c, c_np) + np.testing.assert_array_equal(res_np_d, d_np) + np.testing.assert_array_equal(res_np_e, e_np) + np.testing.assert_array_equal(res_np_f, f_np) + np.testing.assert_array_equal(res_np_g, g_np) + np.testing.assert_array_equal(res_np_h, h_np) + + def test_item(self): + with paddle.pir_utils.IrGuard(): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + y = paddle.static.data( + name='y', + shape=[ + 3, + ], + ) + self.assertTrue(y.item() == y) + with self.assertRaises(TypeError): + x.item() + + def test_place(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with paddle.pir_utils.IrGuard(): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + x.place() + self.assertTrue(len(w) == 1) + self.assertTrue("place" in str(w[-1].message)) + + def test_some_dim(self): + with paddle.pir_utils.IrGuard(): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + self.assertEqual(x.dim(), 3) + self.assertEqual(x.ndimension(), 3) + self.assertEqual(x.ndim, 3) + def test_math_exists(self): with paddle.pir_utils.IrGuard(): a = paddle.static.data(name='a', shape=[1], dtype='float32') diff --git a/test/legacy_test/test_matmul_v2_op.py b/test/legacy_test/test_matmul_v2_op.py index eb893971e026b..acdb453dc14ff 100644 --- a/test/legacy_test/test_matmul_v2_op.py +++ b/test/legacy_test/test_matmul_v2_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): @@ -508,7 +509,9 @@ def setUp(self): def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input_x = paddle.static.data( name="input_x", shape=[4, 3], dtype="float32" ) @@ -523,12 +526,13 @@ def check_static_result(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input_x": x_np, "input_y": y_np}, fetch_list=[result], ) paddle.disable_static() + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) diff --git a/test/legacy_test/test_max_min_amax_amin_op.py b/test/legacy_test/test_max_min_amax_amin_op.py index b5184bd3acd20..4c07869f6f988 100644 --- a/test/legacy_test/test_max_min_amax_amin_op.py +++ b/test/legacy_test/test_max_min_amax_amin_op.py @@ -19,6 +19,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -90,6 +91,7 @@ def _choose_paddle_func(self, func, x): return out # We check the output between paddle API and numpy in static graph. + @test_with_pir_api def test_static_graph(self): def _test_static_graph(func): startup_program = base.Program() @@ -103,7 +105,6 @@ def _test_static_graph(func): exe = base.Executor(self.place) res = exe.run( - base.default_main_program(), feed={'input': self.x_np}, fetch_list=[out], ) diff --git a/test/legacy_test/test_meshgrid_op.py b/test/legacy_test/test_meshgrid_op.py index d8324612e78e4..215424b9c9236 100644 --- a/test/legacy_test/test_meshgrid_op.py +++ b/test/legacy_test/test_meshgrid_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def meshgrid_wrapper(x): @@ -41,10 +42,12 @@ def init_data_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): - self.check_grad(['x0'], ['out0', 'out1'], check_prim=True) + self.check_grad( + ['x0'], ['out0', 'out1'], check_prim=True, check_pir=True + ) def init_inputs_and_outputs(self): self.shape = self.get_x_shape() @@ -122,19 +125,21 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['x0'], ['out0', 'out1'], check_prim=True + paddle.CUDAPlace(0), + ['x0'], + ['out0', 'out1'], + check_prim=True, + check_pir=True, ) class TestMeshgridOp3(unittest.TestCase): + @test_with_pir_api def test_api(self): - x = paddle.static.data(shape=[100], dtype='int32', name='x') - y = paddle.static.data(shape=[200], dtype='int32', name='y') - input_1 = np.random.randint( 0, 100, @@ -155,22 +160,24 @@ def test_api(self): out_2 = np.reshape(input_2, [1, 200]) out_2 = np.broadcast_to(out_2, [100, 200]) - exe = base.Executor(place=base.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid(x, y) - res_1, res_2 = exe.run( - base.default_main_program(), - feed={'x': input_1, 'y': input_2}, - fetch_list=[grid_x, grid_y], - ) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[100], dtype='int32', name='x') + y = paddle.static.data(shape=[200], dtype='int32', name='y') + + exe = base.Executor(place=base.CPUPlace()) + grid_x, grid_y = paddle.tensor.meshgrid(x, y) + res_1, res_2 = exe.run( + paddle.static.default_main_program(), + feed={'x': input_1, 'y': input_2}, + fetch_list=[grid_x, grid_y], + ) np.testing.assert_array_equal(res_1, out_1) np.testing.assert_array_equal(res_2, out_2) class TestMeshgridOp4(unittest.TestCase): + @test_with_pir_api def test_list_input(self): - x = paddle.static.data(shape=[100], dtype='int32', name='x') - y = paddle.static.data(shape=[200], dtype='int32', name='y') - input_1 = np.random.randint( 0, 100, @@ -191,23 +198,24 @@ def test_list_input(self): out_2 = np.reshape(input_2, [1, 200]) out_2 = np.broadcast_to(out_2, [100, 200]) - exe = base.Executor(place=base.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid([x, y]) - res_1, res_2 = exe.run( - base.default_main_program(), - feed={'x': input_1, 'y': input_2}, - fetch_list=[grid_x, grid_y], - ) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[100], dtype='int32', name='x') + y = paddle.static.data(shape=[200], dtype='int32', name='y') + exe = base.Executor(place=base.CPUPlace()) + grid_x, grid_y = paddle.tensor.meshgrid([x, y]) + res_1, res_2 = exe.run( + paddle.static.default_main_program(), + feed={'x': input_1, 'y': input_2}, + fetch_list=[grid_x, grid_y], + ) np.testing.assert_array_equal(res_1, out_1) np.testing.assert_array_equal(res_2, out_2) class TestMeshgridOp5(unittest.TestCase): + @test_with_pir_api def test_tuple_input(self): - x = paddle.static.data(shape=[100], dtype='int32', name='x') - y = paddle.static.data(shape=[200], dtype='int32', name='y') - input_1 = np.random.randint( 0, 100, @@ -228,14 +236,17 @@ def test_tuple_input(self): out_2 = np.reshape(input_2, [1, 200]) out_2 = np.broadcast_to(out_2, [100, 200]) - exe = base.Executor(place=base.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid((x, y)) - res_1, res_2 = exe.run( - base.default_main_program(), - feed={'x': input_1, 'y': input_2}, - fetch_list=[grid_x, grid_y], - ) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[100], dtype='int32', name='x') + y = paddle.static.data(shape=[200], dtype='int32', name='y') + exe = base.Executor(place=base.CPUPlace()) + grid_x, grid_y = paddle.tensor.meshgrid((x, y)) + res_1, res_2 = exe.run( + paddle.static.default_main_program(), + feed={'x': input_1, 'y': input_2}, + fetch_list=[grid_x, grid_y], + ) np.testing.assert_array_equal(res_1, out_1) np.testing.assert_array_equal(res_2, out_2) diff --git a/test/legacy_test/test_min_op.py b/test/legacy_test/test_min_op.py index e24471b20dca8..78601c77ecf06 100644 --- a/test/legacy_test/test_min_op.py +++ b/test/legacy_test/test_min_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class ApiMinTest(unittest.TestCase): @@ -30,6 +31,7 @@ def setUp(self): else: self.place = core.CPUPlace() + @test_with_pir_api def test_api(self): paddle.enable_static() with paddle.static.program_guard( diff --git a/test/legacy_test/test_mse_loss.py b/test/legacy_test/test_mse_loss.py index 688895240a374..ab2e9deaef488 100644 --- a/test/legacy_test/test_mse_loss.py +++ b/test/legacy_test/test_mse_loss.py @@ -20,9 +20,11 @@ from paddle import base from paddle.base import core from paddle.base.executor import Executor +from paddle.pir_utils import test_with_pir_api class TestMseLoss(unittest.TestCase): + @test_with_pir_api def test_mse_loss(self): input_val = np.random.uniform(0.1, 0.5, (2, 3)).astype("float32") label_val = np.random.uniform(0.1, 0.5, (2, 3)).astype("float32") @@ -30,29 +32,35 @@ def test_mse_loss(self): sub = input_val - label_val np_result = np.mean(sub * sub) - input_var = paddle.static.data( - name="input", shape=[-1, 3], dtype="float32" - ) - label_var = paddle.static.data( - name="label", shape=[-1, 3], dtype="float32" - ) - - output = paddle.nn.functional.mse_loss(input=input_var, label=label_var) - for use_cuda in ( - [False, True] if core.is_compiled_with_cuda() else [False] - ): - place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() - exe = Executor(place) - (result,) = exe.run( - base.default_main_program(), - feed={"input": input_val, "label": label_val}, - fetch_list=[output], + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + input_var = paddle.static.data( + name="input", shape=[-1, 3], dtype="float32" ) + label_var = paddle.static.data( + name="label", shape=[-1, 3], dtype="float32" + ) + + output = paddle.nn.functional.mse_loss( + input=input_var, label=label_var + ) + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + exe = Executor(place) + (result,) = exe.run( + main, + feed={"input": input_val, "label": label_val}, + fetch_list=[output], + ) - np.testing.assert_allclose(np_result, result, rtol=1e-05) + np.testing.assert_allclose(np_result, result, rtol=1e-05) class TestMseInvalidInput(unittest.TestCase): + @test_with_pir_api def test_error(self): def test_invalid_input(): input = [256, 3] @@ -74,6 +82,7 @@ def test_invalid_label(): class TestNNMseLoss(unittest.TestCase): + @test_with_pir_api def test_NNMseLoss_mean(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -88,13 +97,11 @@ def test_NNMseLoss_mean(self): ) with base.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=[-1] + dim, dtype='float32' + name='input', shape=dim, dtype='float32' ) - input.desc.set_need_check_feed(False) label = paddle.static.data( - name='label', shape=[-1] + dim, dtype='float32' + name='label', shape=dim, dtype='float32' ) - label.desc.set_need_check_feed(False) mse_loss = paddle.nn.loss.MSELoss() ret = mse_loss(input, label) @@ -120,6 +127,7 @@ def test_NNMseLoss_mean(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNMseLoss_sum(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -134,13 +142,11 @@ def test_NNMseLoss_sum(self): ) with base.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=[-1] + dim, dtype='float32' + name='input', shape=dim, dtype='float32' ) - input.desc.set_need_check_feed(False) label = paddle.static.data( - name='label', shape=[-1] + dim, dtype='float32' + name='label', shape=dim, dtype='float32' ) - label.desc.set_need_check_feed(False) mse_loss = paddle.nn.loss.MSELoss(reduction='sum') ret = mse_loss(input, label) @@ -166,6 +172,7 @@ def test_NNMseLoss_sum(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNMseLoss_none(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -180,13 +187,11 @@ def test_NNMseLoss_none(self): ) with base.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=[-1] + dim, dtype='float32' + name='input', shape=dim, dtype='float32' ) - input.desc.set_need_check_feed(False) label = paddle.static.data( - name='label', shape=[-1] + dim, dtype='float32' + name='label', shape=dim, dtype='float32' ) - label.desc.set_need_check_feed(False) mse_loss = paddle.nn.loss.MSELoss(reduction='none') ret = mse_loss(input, label) @@ -214,6 +219,7 @@ def test_NNMseLoss_none(self): class TestNNFunctionalMseLoss(unittest.TestCase): + @test_with_pir_api def test_NNFunctionalMseLoss_mean(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -256,6 +262,7 @@ def test_NNFunctionalMseLoss_mean(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNFunctionalMseLoss_sum(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -298,6 +305,7 @@ def test_NNFunctionalMseLoss_sum(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNFunctionalMseLoss_none(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") diff --git a/test/legacy_test/test_multinomial_op.py b/test/legacy_test/test_multinomial_op.py index dbb0961d15936..e886876b27583 100644 --- a/test/legacy_test/test_multinomial_op.py +++ b/test/legacy_test/test_multinomial_op.py @@ -178,6 +178,7 @@ class TestMultinomialBF16OP(OpTest): def setUp(self): paddle.enable_static() self.op_type = "multinomial" + self.python_api = paddle.multinomial self.dtype = np.uint16 self.init_data() self.inputs = {"X": convert_float_to_uint16(self.input_np)} @@ -190,7 +191,9 @@ def init_data(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place_customized(self.verify_output, place) + self.check_output_with_place_customized( + self.verify_output, place, check_pir=True + ) def sample_output(self, out): return sample_output_one_dimension(out, 4) diff --git a/test/legacy_test/test_nearest_interp_v2_op.py b/test/legacy_test/test_nearest_interp_v2_op.py index 3fe807de35de6..0bba6aa668504 100755 --- a/test/legacy_test/test_nearest_interp_v2_op.py +++ b/test/legacy_test/test_nearest_interp_v2_op.py @@ -383,10 +383,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -445,10 +445,10 @@ def init_test_case(self): class TestNearestInterpOpFP16(TestNearestInterpOp): def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -609,10 +609,10 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output_np)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -740,7 +740,9 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) + self.check_output_with_place( + place=core.CPUPlace(), atol=1, check_pir=True + ) def init_test_case(self): self.interp_method = 'nearest' @@ -895,10 +897,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): self.interp_method = 'nearest' diff --git a/test/legacy_test/test_nonzero_api.py b/test/legacy_test/test_nonzero_api.py index a57e1d9803c22..a14c72a22a149 100644 --- a/test/legacy_test/test_nonzero_api.py +++ b/test/legacy_test/test_nonzero_api.py @@ -29,6 +29,7 @@ def call_nonzero(x): class TestNonZeroAPI(unittest.TestCase): def test_nonzero_api_as_tuple(self): + paddle.enable_static() data = np.array([[True, False], [False, True]]) with program_guard(Program(), Program()): x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32') @@ -61,6 +62,7 @@ def test_nonzero_api_as_tuple(self): np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) def test_nonzero_api(self): + paddle.enable_static() data = np.array([[True, False], [False, True]]) with program_guard(Program(), Program()): x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32') @@ -108,7 +110,7 @@ def setUp(self): self.outputs = self.return_outputs() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def init_shape(self): self.shape = [8, 8] @@ -156,7 +158,7 @@ def setUp(self): self.outputs = self.return_outputs() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def init_shape(self): self.shape = [12, 9] diff --git a/test/legacy_test/test_norm_all.py b/test/legacy_test/test_norm_all.py index 58be677975742..86eea3a4c8eb0 100644 --- a/test/legacy_test/test_norm_all.py +++ b/test/legacy_test/test_norm_all.py @@ -102,10 +102,10 @@ def setUp(self): self.outputs = {'Out': norm} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) def init_test_case(self): self.shape = [2, 3, 4, 5] @@ -126,7 +126,7 @@ def init_dtype(self): self.dtype = "float32" def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestPnormOp(OpTest): diff --git a/test/legacy_test/test_numel_op.py b/test/legacy_test/test_numel_op.py index 32f043dab1b9b..7e0f75c865077 100644 --- a/test/legacy_test/test_numel_op.py +++ b/test/legacy_test/test_numel_op.py @@ -18,8 +18,8 @@ from op_test import OpTest, convert_float_to_uint16 import paddle -from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestNumelOp(OpTest): @@ -148,10 +148,11 @@ def init(self): class TestNumelAPI(unittest.TestCase): + @test_with_pir_api def test_numel_static(self): - main_program = base.Program() - startup_program = base.Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): shape1 = [2, 1, 4, 5] shape2 = [1, 4, 5] x_1 = paddle.static.data(shape=shape1, dtype='int32', name='x_1') @@ -188,9 +189,9 @@ def test_numel_imperative(self): paddle.enable_static() def test_error(self): - main_program = base.Program() - startup_program = base.Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): def test_x_type(): shape = [1, 4, 5] @@ -199,6 +200,16 @@ def test_x_type(): self.assertRaises(TypeError, test_x_type) + def test_pir_error(self): + with paddle.pir_utils.IrGuard(): + + def test_x_type(): + shape = [1, 4, 5] + input_1 = np.random.random(shape).astype("int32") + out_1 = paddle.numel(input_1) + + self.assertRaises(ValueError, test_x_type) + if __name__ == '__main__': paddle.enable_static() diff --git a/test/legacy_test/test_op_name_conflict.py b/test/legacy_test/test_op_name_conflict.py index fbb717eda9609..2f6444d7c3ec2 100644 --- a/test/legacy_test/test_op_name_conflict.py +++ b/test/legacy_test/test_op_name_conflict.py @@ -18,9 +18,11 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api class TestOpNameConflict(unittest.TestCase): + @test_with_pir_api def test_conflict(self): paddle.enable_static() main = base.Program() diff --git a/test/legacy_test/test_pairwise_distance.py b/test/legacy_test/test_pairwise_distance.py index 2f9199f48c04a..099b5ed085065 100644 --- a/test/legacy_test/test_pairwise_distance.py +++ b/test/legacy_test/test_pairwise_distance.py @@ -18,6 +18,7 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api def np_pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False): @@ -77,6 +78,7 @@ def test_static( def test_dygraph( place, x_np, y_np, p=2.0, epsilon=1e-6, keepdim=False, functional=False ): + paddle.disable_static() x = paddle.to_tensor(x_np) y = paddle.to_tensor(y_np) if functional: @@ -88,6 +90,7 @@ def test_dygraph( x=x, y=y, p=p, epsilon=epsilon, keepdim=keepdim ) dygraph_ret = dy_distance.numpy() + paddle.enable_static() return dygraph_ret @@ -109,14 +112,6 @@ def test_pairwise_distance(self): x_np = np.random.random(shape).astype(dtype) y_np = np.random.random(shape).astype(dtype) - static_ret = test_static( - place, - x_np, - y_np, - p, - epsilon=epsilon, - keepdim=keepdim, - ) dygraph_ret = test_dygraph( place, x_np, @@ -129,27 +124,14 @@ def test_pairwise_distance(self): x_np, y_np, p, epsilon=epsilon, keepdim=keepdim ) - self.assertEqual( - static_ret.shape, excepted_value.shape - ) self.assertEqual( dygraph_ret.shape, excepted_value.shape ) - np.testing.assert_allclose( - static_ret, excepted_value, rtol=1e-05 - ) np.testing.assert_allclose( dygraph_ret, excepted_value, rtol=1e-05 ) - static_functional_ret = test_static( - place, - x_np, - y_np, - p, - epsilon=epsilon, - keepdim=keepdim, - ) + dygraph_functional_ret = test_dygraph( place, x_np, @@ -159,26 +141,58 @@ def test_pairwise_distance(self): keepdim=keepdim, ) - self.assertEqual( - static_functional_ret.shape, - excepted_value.shape, - ) self.assertEqual( dygraph_functional_ret.shape, excepted_value.shape, ) - np.testing.assert_allclose( - static_functional_ret, - excepted_value, - rtol=1e-05, - ) np.testing.assert_allclose( dygraph_functional_ret, excepted_value, rtol=1e-05, ) + @test_with_pir_api + def dynamic_and_pir_mode_test(): + static_ret = test_static( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim, + ) + + self.assertEqual( + static_ret.shape, excepted_value.shape + ) + + np.testing.assert_allclose( + static_ret, excepted_value, rtol=1e-05 + ) + + static_functional_ret = test_static( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim, + ) + + self.assertEqual( + static_functional_ret.shape, + excepted_value.shape, + ) + + np.testing.assert_allclose( + static_functional_ret, + excepted_value, + rtol=1e-05, + ) + + dynamic_and_pir_mode_test() + def test_pairwise_distance_broadcast_1(self): shape_x = [100, 100] shape_y = [100, 1] @@ -187,9 +201,7 @@ def test_pairwise_distance_broadcast_1(self): place = paddle.CPUPlace() x_np = np.random.random(shape_x).astype('float32') y_np = np.random.random(shape_y).astype('float32') - static_ret = test_static( - place=place, x_np=x_np, y_np=y_np, epsilon=epsilon, keepdim=keepdim - ) + dygraph_ret = test_dygraph( place=place, x_np=x_np, y_np=y_np, epsilon=epsilon, keepdim=keepdim ) @@ -197,20 +209,10 @@ def test_pairwise_distance_broadcast_1(self): x_np, y_np, epsilon=epsilon, keepdim=keepdim ) - self.assertEqual(static_ret.shape, excepted_value.shape) self.assertEqual(dygraph_ret.shape, excepted_value.shape) - np.testing.assert_allclose(static_ret, excepted_value, rtol=1e-05) np.testing.assert_allclose(dygraph_ret, excepted_value, rtol=1e-05) - static_functional_ret = test_static( - place=place, - x_np=x_np, - y_np=y_np, - epsilon=epsilon, - keepdim=keepdim, - functional=True, - ) dygraph_functional_ret = test_dygraph( place=place, x_np=x_np, @@ -220,16 +222,41 @@ def test_pairwise_distance_broadcast_1(self): functional=True, ) - self.assertEqual(static_functional_ret.shape, excepted_value.shape) self.assertEqual(dygraph_functional_ret.shape, excepted_value.shape) - np.testing.assert_allclose( - static_functional_ret, excepted_value, rtol=1e-05 - ) np.testing.assert_allclose( dygraph_functional_ret, excepted_value, rtol=1e-05 ) + @test_with_pir_api + def dynamic_and_pir_mode_test(): + static_ret = test_static( + place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + ) + + self.assertEqual(static_ret.shape, excepted_value.shape) + + np.testing.assert_allclose(static_ret, excepted_value, rtol=1e-05) + static_functional_ret = test_static( + place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True, + ) + + self.assertEqual(static_functional_ret.shape, excepted_value.shape) + np.testing.assert_allclose( + static_functional_ret, excepted_value, rtol=1e-05 + ) + + dynamic_and_pir_mode_test() + def test_pairwise_distance_broadcast_2(self): shape_x = [100, 100] shape_y = [100] @@ -238,9 +265,7 @@ def test_pairwise_distance_broadcast_2(self): place = paddle.CPUPlace() x_np = np.random.random(shape_x).astype('float32') y_np = np.random.random(shape_y).astype('float32') - static_ret = test_static( - place=place, x_np=x_np, y_np=y_np, epsilon=epsilon, keepdim=keepdim - ) + dygraph_ret = test_dygraph( place=place, x_np=x_np, y_np=y_np, epsilon=epsilon, keepdim=keepdim ) @@ -249,20 +274,10 @@ def test_pairwise_distance_broadcast_2(self): x_np, y_np, epsilon=epsilon, keepdim=keepdim ) - self.assertEqual(static_ret.shape, excepted_value.shape) self.assertEqual(dygraph_ret.shape, excepted_value.shape) - np.testing.assert_allclose(static_ret, excepted_value, rtol=1e-05) np.testing.assert_allclose(dygraph_ret, excepted_value, rtol=1e-05) - static_functional_ret = test_static( - place=place, - x_np=x_np, - y_np=y_np, - epsilon=epsilon, - keepdim=keepdim, - functional=True, - ) dygraph_functional_ret = test_dygraph( place=place, x_np=x_np, @@ -272,16 +287,44 @@ def test_pairwise_distance_broadcast_2(self): functional=True, ) - self.assertEqual(static_functional_ret.shape, excepted_value.shape) self.assertEqual(dygraph_functional_ret.shape, excepted_value.shape) - np.testing.assert_allclose( - static_functional_ret, excepted_value, rtol=1e-05 - ) np.testing.assert_allclose( dygraph_functional_ret, excepted_value, rtol=1e-05 ) + @test_with_pir_api + def dynamic_and_pir_mode_test(): + static_ret = test_static( + place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + ) + + self.assertEqual(static_ret.shape, excepted_value.shape) + + np.testing.assert_allclose(static_ret, excepted_value, rtol=1e-05) + + static_functional_ret = test_static( + place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True, + ) + + self.assertEqual(static_functional_ret.shape, excepted_value.shape) + + np.testing.assert_allclose( + static_functional_ret, excepted_value, rtol=1e-05 + ) + + dynamic_and_pir_mode_test() + + @test_with_pir_api def test_pairwise_distance_fp16(self): shape = [100, 100] if not paddle.device.is_compiled_with_cuda(): diff --git a/test/legacy_test/test_pass_builder.py b/test/legacy_test/test_pass_builder.py index b976d29ca0db3..34aef7fde8ecf 100644 --- a/test/legacy_test/test_pass_builder.py +++ b/test/legacy_test/test_pass_builder.py @@ -117,7 +117,7 @@ def test_parallel_testing_with_new_strategy(self): ) try: os.stat(graph_viz_path) - except os.error: + except OSError: self.assertFalse(True) diff --git a/test/legacy_test/test_pca_lowrank.py b/test/legacy_test/test_pca_lowrank.py index 107c76b442af9..68f0005b36823 100644 --- a/test/legacy_test/test_pca_lowrank.py +++ b/test/legacy_test/test_pca_lowrank.py @@ -62,9 +62,7 @@ def run_subtest( self.assertEqual(v.shape[-1], guess_rank) self.assertEqual(v.shape[-2], columns) - A1 = u.matmul(paddle.nn.functional.diag_embed(s)).matmul( - self.transpose(v) - ) + A1 = u.matmul(paddle.diag_embed(s)).matmul(self.transpose(v)) ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype) c = a.sum(axis=-2) / rows c = c.reshape(batches + (1, columns)) diff --git a/test/legacy_test/test_poisson_op.py b/test/legacy_test/test_poisson_op.py index fdb7962860a9c..b2b889645ddfc 100644 --- a/test/legacy_test/test_poisson_op.py +++ b/test/legacy_test/test_poisson_op.py @@ -73,6 +73,7 @@ def test_check_grad_normal(self): user_defined_grad_outputs=[ np.random.rand(2048, 1024).astype(self.dtype) ], + check_pir=True, ) diff --git a/test/legacy_test/test_polygamma_op.py b/test/legacy_test/test_polygamma_op.py index 9c9b0416ba4f2..5edd092f58610 100644 --- a/test/legacy_test/test_polygamma_op.py +++ b/test/legacy_test/test_polygamma_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(100) paddle.seed(100) @@ -64,6 +65,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) + @test_with_pir_api def test_api_static(self): def run(place): paddle.enable_static() @@ -197,7 +199,7 @@ def init_config(self): self.target = ref_polygamma(self.inputs['x'], self.order) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -206,6 +208,7 @@ def test_check_grad(self): user_defined_grads=[ ref_polygamma_grad(self.case, 1 / self.case.size, self.order) ], + check_pir=True, ) diff --git a/test/legacy_test/test_pool2d_api.py b/test/legacy_test/test_pool2d_api.py index fcca5381fa4f0..84615340fe051 100644 --- a/test/legacy_test/test_pool2d_api.py +++ b/test/legacy_test/test_pool2d_api.py @@ -25,6 +25,7 @@ from paddle import base from paddle.base import core from paddle.nn.functional import avg_pool2d, max_pool2d +from paddle.pir_utils import test_with_pir_api class TestPool2D_API(unittest.TestCase): @@ -52,7 +53,7 @@ def check_avg_static_results(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input": input_np}, fetch_list=[result], ) @@ -144,7 +145,7 @@ def check_max_static_results(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input": input_np}, fetch_list=[result], ) @@ -360,8 +361,6 @@ def test_pool2d(self): for place in self.places: self.check_max_dygraph_results(place) self.check_avg_dygraph_results(place) - self.check_max_static_results(place) - self.check_avg_static_results(place) self.check_max_dygraph_stride_is_none(place) self.check_avg_dygraph_stride_is_none(place) self.check_max_dygraph_padding(place) @@ -370,6 +369,14 @@ def test_pool2d(self): self.check_max_dygraph_ceilmode_results(place) self.check_max_dygraph_nhwc_results(place) + @test_with_pir_api + def test_pool2d_static(self): + paddle.enable_static() + for place in self.places: + self.check_max_static_results(place) + self.check_avg_static_results(place) + paddle.disable_static() + class TestPool2DError_API(unittest.TestCase): def test_error_api(self): diff --git a/test/legacy_test/test_prelu_op.py b/test/legacy_test/test_prelu_op.py index c287daec3b959..bc90c119636fa 100644 --- a/test/legacy_test/test_prelu_op.py +++ b/test/legacy_test/test_prelu_op.py @@ -21,6 +21,7 @@ import paddle.nn.functional as F from paddle import base from paddle.base import Program, core +from paddle.pir_utils import test_with_pir_api def ref_prelu(x, weight): @@ -48,6 +49,7 @@ def setUp(self): self.weight_np_0 = np.random.randn(1).astype('float32') self.weight_np_1 = np.random.randn(self.x_np.shape[1]).astype('float32') + @test_with_pir_api def static_check(self, weight_np): with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data('X', self.x_np.shape, 'float32') @@ -69,6 +71,7 @@ def dygraph_check(self, weight_np): np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) paddle.enable_static() + @test_with_pir_api def test_static_api(self): self.static_check(self.weight_np_0) self.static_check(self.weight_np_1) @@ -105,6 +108,7 @@ def setUp(self): ) self.x_np = np.ones([1, 2, 3, 4]).astype('float32') + @test_with_pir_api def test_static_api(self): startup_program = paddle.static.Program() train_program = paddle.static.Program() @@ -226,10 +230,10 @@ def init_attr(self): self.attrs = {'mode': "channel", "data_format": "NCHW"} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Alpha'], 'Out') + self.check_grad(['X', 'Alpha'], 'Out', check_pir=True) @skip_check_grad_ci( @@ -392,13 +396,17 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=atol) + self.check_output_with_place( + place, atol=atol, check_pir=True + ) def test_check_grad(self): place = core.CUDAPlace(0) if core.is_float16_supported(place) and check_grad: # Use the default max_relative_error, not use max_relative_error - self.check_grad_with_place(place, ['X', 'Alpha'], 'Out') + self.check_grad_with_place( + place, ['X', 'Alpha'], 'Out', check_pir=True + ) cls_name = "{}_{}".format(parent.__name__, "Fp16Op") TestPReluFp16Case.__name__ = cls_name @@ -426,13 +434,15 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=atol) + self.check_output_with_place(place, atol=atol, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) if check_grad: # Use the default max_relative_error, not use max_relative_error - self.check_grad_with_place(place, ['X', 'Alpha'], 'Out') + self.check_grad_with_place( + place, ['X', 'Alpha'], 'Out', check_pir=True + ) cls_name = "{}_{}".format(parent.__name__, "BF16Op") TestPReluBF16Op.__name__ = cls_name diff --git a/test/legacy_test/test_prod_op.py b/test/legacy_test/test_prod_op.py index 2a0b06d76f849..7a69a840c393d 100644 --- a/test/legacy_test/test_prod_op.py +++ b/test/legacy_test/test_prod_op.py @@ -18,6 +18,7 @@ from test_sum_op import TestReduceOPTensorAxisBase import paddle +from paddle.pir_utils import test_with_pir_api class TestProdOp(unittest.TestCase): @@ -70,33 +71,35 @@ def run_imperative(self): dy_result.numpy(), expected_result, rtol=1e-05 ) + @test_with_pir_api def run_static(self, use_gpu=False): - input = paddle.static.data( - name='input', shape=[10, 10, 5], dtype='float32' - ) - result0 = paddle.prod(input) - result1 = paddle.prod(input, axis=1) - result2 = paddle.prod(input, axis=-1) - result3 = paddle.prod(input, axis=[0, 1]) - result4 = paddle.prod(input, axis=1, keepdim=True) - result5 = paddle.prod(input, axis=1, dtype='int64') - result6 = paddle.prod(input, axis=1, keepdim=True, dtype='int64') - - place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) - static_result = exe.run( - feed={"input": self.input}, - fetch_list=[ - result0, - result1, - result2, - result3, - result4, - result5, - result6, - ], - ) + with paddle.static.program_guard(paddle.static.Program()): + input = paddle.static.data( + name='input', shape=[10, 10, 5], dtype='float32' + ) + result0 = paddle.prod(input) + result1 = paddle.prod(input, axis=1) + result2 = paddle.prod(input, axis=-1) + result3 = paddle.prod(input, axis=[0, 1]) + result4 = paddle.prod(input, axis=1, keepdim=True) + result5 = paddle.prod(input, axis=1, dtype='int64') + result6 = paddle.prod(input, axis=1, keepdim=True, dtype='int64') + + place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + static_result = exe.run( + feed={"input": self.input}, + fetch_list=[ + result0, + result1, + result2, + result3, + result4, + result5, + result6, + ], + ) expected_result = np.prod(self.input) np.testing.assert_allclose( @@ -134,8 +137,7 @@ def test_cpu(self): self.run_imperative() paddle.enable_static() - with paddle.static.program_guard(paddle.static.Program()): - self.run_static() + self.run_static() def test_gpu(self): if not paddle.base.core.is_compiled_with_cuda(): @@ -145,8 +147,7 @@ def test_gpu(self): self.run_imperative() paddle.enable_static() - with paddle.static.program_guard(paddle.static.Program()): - self.run_static(use_gpu=True) + self.run_static(use_gpu=True) class TestProdOpError(unittest.TestCase): diff --git a/test/legacy_test/test_quant_linear_op.py b/test/legacy_test/test_quant_linear_op.py new file mode 100644 index 0000000000000..ed5a4b30fe7a5 --- /dev/null +++ b/test/legacy_test/test_quant_linear_op.py @@ -0,0 +1,764 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest, paddle_static_guard + +import paddle +from paddle import base +from paddle.base import Program, core, program_guard +from paddle.base.data_feeder import check_dtype +from paddle.base.framework import Variable, static_only +from paddle.common_ops_import import LayerHelper, check_type + +SEED = 2020 + + +@static_only +def quant_linear( + x, + w, + size, + scale_in, + scale_weight, + num_flatten_dims=1, + bias_attr=None, + activation=None, + quant_round_type=1, + quant_max_bound=127.0, + quant_min_bound=-127.0, + name=None, +): + r""" + + Quant linear layer can take a tensor as its input and a tensor as the weight tensor. + The quant linear layer multiplies the input tensor with the weight to produce + an output tensor with shape :math:`[batch\_size, *, size]` , where :math:`*` + means any number of additional dimensions. If :attr:`bias_attr` is not False, a 1-D bias tensor will + be created and added to the output. If :attr:`activation` is not None, + it will be applied to the output as well. Besides, the input tensor will be quantize to + the tensor with int8 type, the parameter w must be a tensor with int8 type and the computation will also + be with the int8 type. + + For a single input tensor :math:`X` , the equation is: + + .. math:: + + Out = Act({XW + b}) + + where: + + * :math:`X`: The input tensor. + * :math:`W`: The weight matrix. + * :math:`b`: The bias created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output tensor. + + Args: + x (Tensor): A tensor. The number of dimensions + of the tensor is at least 2. The data type should be float16, bfloat16, float32 or float64. + w (Tensor): A tensor. The data type should be int8. + size (int): The number of the output unit in this layer, which also means the feature + size of output tensor. + scale_in (float): The quantization scale for input. + scale_weight (list[float]): The quantization scale for weights. + num_flatten_dims (int, optional): The quant linear layer can accept an input tensor with more than + two dimensions. If this happens, the multi-dimensional tensor will first be flattened + into a 2-D matrix. The parameter :attr:`num_flatten_dims` determines how the input + tensor is flattened: the first :math:`num\_flatten\_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest :math:`rank(x) - num\_flatten\_dims` dimensions are + flattened to form the second dimension of the final matrix (width of the matrix). + For example, assuming that :attr:`x` is a 5-dimensional tensor with a shape + :math:`[2, 3, 4, 5, 6]` , and :attr:`num_flatten_dims` = 3. + Then, the flattened matrix will have a shape :math:`[2 * 3 * 4, 5 * 6] = [24, 30]` . + Default: 1. + bias_attr (ParamAttr|bool, optional): The attribute of the learnable bias. + If it is set to False, no bias will be added to the output. + If it is set to None or one kind of ParamAttr, a bias parameter will + be created according to ParamAttr. For detailed information, please refer + to :attr:`paddle.ParamAttr`. The default value is None and the bias will be + initialized to zero. + activation (str, optional): Activation to be applied to the output of + this layer. Only "relu" is supported. For more information, + please refer to :ref:`api_guide_activations_en` . Default: None. + quant_round_type (int, optional): The round type of float to int. 0 means rounding to nearest ties to even and 1 means rounding to nearest ties away from zero. Default: 1. + quant_max_bound (float, optional): The max bound of float type to int type. Defualt: 127.0. + quant_min_bound (float, optional): The min bound of float type to int type. Defualt: -127.0. + name (str, optional): The default value is None. Normally there is no need for user to set + it. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Tensor, its shape is :math:`[batch\_size, *, size]` , and the data type is same with input. + + """ + + def quant_linear_base( + input, + weight, + size, + scale_in, + scale_weight, + num_flatten_dims=1, + bias_attr=None, + act=None, + quant_round_type=1, + quant_max_bound=127.0, + quant_min_bound=-127.0, + name=None, + ): + helper = LayerHelper("quant_linear", **locals()) + check_type(input, 'input', Variable, 'quant_linear') + dtype = helper.input_dtype() + check_dtype( + dtype, + 'input', + ['float16', 'float32', 'float64'], + 'quant_linear', + ) + + input_shape = input.shape + if num_flatten_dims == -1: + num_flatten_dims = len(input_shape) - 1 + + check_type(weight, "weight", Variable, 'quant_linear') + check_dtype( + weight.dtype, + 'weight', + ['int8'], + 'quant_linear', + ) + check_type(scale_weight, "scale_weight", list, 'quant_linear') + if len(scale_weight) != size: + raise AttributeError( + "The length of scale_weight must be the same with the param size." + ) + + inputs_of_quant_linear = {"x": input, "w": weight} + if bias_attr is not False: + bias_shape = [size] + bias = helper.create_parameter( + attr=bias_attr, shape=bias_shape, dtype=dtype, is_bias=True + ) + inputs_of_quant_linear["bias"] = bias + + out = helper.create_variable_for_type_inference(dtype) + attrs_of_quant_linear = { + "in_num_col_dims": num_flatten_dims, + "activation_type": act, + "scale_in": scale_in, + "scale_weights": scale_weight, + "quant_round_type": quant_round_type, + "quant_max_bound": quant_max_bound, + "quant_min_bound": quant_min_bound, + } + + helper.append_op( + type="quant_linear", + inputs=inputs_of_quant_linear, + outputs={"out": out}, + attrs=attrs_of_quant_linear, + ) + return out + + return quant_linear_base( + input=x, + weight=w, + size=size, + scale_in=scale_in, + scale_weight=scale_weight, + num_flatten_dims=num_flatten_dims, + bias_attr=bias_attr, + act=activation, + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + name=name, + ) + + +def round_array(x): + x[x > 0] = np.ceil(x[x > 0]) + x[x <= 0] = np.floor(x[x <= 0]) + + +def round_array_with_ties_to_even(x): + xLower = np.floor(x) + xUpper = np.ceil(x) + dLower = x - xLower + dUpper = xUpper - x + x[(dLower == dUpper) & (xLower % 2 == 0)] = xLower[ + (dLower == dUpper) & (xLower % 2 == 0) + ] + x[(dLower == dUpper) & (xLower % 2 != 0)] = xUpper[ + (dLower == dUpper) & (xLower % 2 != 0) + ] + x[dLower < dUpper] = xLower[dLower < dUpper] + x[dLower > dUpper] = xUpper[dLower > dUpper] + + +def quant_linear_refer( + matrix, + with_bias, + scale_in, + scale_weights, + quant_round_type=1, + quant_max_bound=127, + quant_min_bound=-127, + with_relu=False, +): + in_n, in_c, in_h, in_w = matrix.input.shape + w_i, w_o = matrix.weights.shape + + x_data = np.reshape(matrix.input, [in_n, in_c * in_h * in_w]) + quant_x_data = x_data.astype('float32') + quant_x_data = quant_max_bound * scale_in * quant_x_data + if quant_round_type == 0: + round_array_with_ties_to_even(quant_x_data) + else: + round_array(quant_x_data) + quant_x_data[quant_x_data > quant_max_bound] = quant_max_bound + quant_x_data[quant_x_data < quant_min_bound] = quant_min_bound + quant_x_data = quant_x_data.astype('int8') + + w_data = np.reshape(matrix.weights, [w_i, w_o]) + b_data = np.reshape(matrix.bias, [1, w_o]) + result = None + quant_result = np.dot(quant_x_data.astype('int32'), w_data.astype('int32')) + scale_out = scale_weights * scale_in + result = quant_result / (quant_max_bound * quant_max_bound * scale_out) + result = result.astype(x_data.dtype) + + if with_bias: + result = result + b_data + + if with_relu: + return np.maximum(result, 0) + else: + return result + + +class MatrixGenerate: + def __init__(self, mb, ic, oc, h, w, bias_dims=2): + self.input = np.random.random((mb, ic, h, w)).astype("float32") + self.weights = np.random.random((ic * h * w, oc)).astype("float32") + if bias_dims == 2: + self.bias = np.random.random((1, oc)).astype("float32") + else: + self.bias = np.random.random(oc).astype("float32") + + +def get_scale_in(input): + max_v = np.max(np.abs(input)) + return 1 / max_v + + +def get_scale_weights(weights): + max_v = np.max(np.abs(weights), axis=0) + return 1 / max_v + + +def quant_weights( + weights, scale_weights, quant_round_type, quant_max_bound, quant_min_bound +): + quant_weights = weights.astype('float32') + quant_weights = quant_max_bound * scale_weights * quant_weights + if quant_round_type == 0: + round_array_with_ties_to_even(quant_weights) + else: + round_array(quant_weights) + quant_weights[quant_weights > quant_max_bound] = quant_max_bound + quant_weights[quant_weights < quant_min_bound] = quant_min_bound + quant_weights = quant_weights.astype('int8') + return quant_weights + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOp(OpTest): + def config(self): + self.with_bias = False + self.with_relu = False + self.quant_round_type = 0 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(2, 1, 10, 1, 1, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + def setUp(self): + self.op_type = "quant_linear" + self.config() + + if self.with_bias: + self.inputs = { + 'x': self.matrix.input, + 'w': self.matrix.weights, + 'bias': self.matrix.bias, + } + else: + self.inputs = {'x': self.matrix.input, 'w': self.matrix.weights} + + if self.with_relu: + activation_type = "relu" + else: + activation_type = "" + self.attrs = { + 'activation_type': activation_type, + 'quant_round_type': self.quant_round_type, + 'quant_max_bound': self.quant_max_bound, + 'quant_min_bound': self.quant_min_bound, + 'scale_in': self.scale_in, + 'scale_weights': self.scale_weights, + } + + self.outputs = { + 'out': quant_linear_refer( + self.matrix, + self.with_bias, + self.scale_in, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + self.with_relu, + ) + } + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_dygraph=False) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpNoBias1(TestQuantLinearOp): + def config(self): + self.with_bias = False + self.with_relu = False + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(16, 10, 16, 4, 4, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpNoBias2(TestQuantLinearOp): + def config(self): + self.with_bias = False + self.with_relu = False + self.quant_round_type = 0 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(2, 8, 10, 1, 1, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpNoBias3(TestQuantLinearOp): + def config(self): + self.with_bias = False + self.with_relu = False + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(2, 6, 10, 1, 1, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpNoBias4(TestQuantLinearOp): + def config(self): + self.with_bias = False + self.with_relu = False + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(2, 14, 10, 1, 1, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpWithBias1(TestQuantLinearOp): + def config(self): + self.with_bias = True + self.with_relu = True + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpWithBias2(TestQuantLinearOp): + def config(self): + self.with_bias = True + self.with_relu = True + self.quant_round_type = 0 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(3, 8, 10, 2, 1, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpWithPadding1(TestQuantLinearOp): + def config(self): + self.with_bias = True + self.with_relu = True + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(1, 4, 4, 128, 128, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpWithPadding2(TestQuantLinearOp): + def config(self): + self.with_bias = True + self.with_relu = True + self.quant_round_type = 0 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.matrix = MatrixGenerate(1, 4, 3, 128, 128, 2) + self.scale_in = get_scale_in(self.matrix.input) + self.scale_weights = get_scale_weights(self.matrix.weights) + self.matrix.weights = quant_weights( + self.matrix.weights, + self.scale_weights, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOp_NumFlattenDims_NegOne(unittest.TestCase): + def test_api(self): + def run_program(num_flatten_dims): + paddle.seed(SEED) + np.random.seed(SEED) + startup_program = Program() + main_program = Program() + + with paddle_static_guard(): + with program_guard(main_program, startup_program): + quant_round_type = 0 + quant_max_bound = 127.0 + quant_min_bound = -127.0 + + input = np.random.random([2, 2, 25]).astype("float32") + scale_in = get_scale_in(input) + x = paddle.static.data( + name="x", + shape=[2, 2, 25], + dtype="float32", + ) + + weight = np.random.random([25, 1]).astype("float32") + scale_weight = get_scale_weights(weight) + weight = quant_weights( + weight, + scale_weight, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + w = paddle.static.data( + name="w", + shape=[25, 1], + dtype="int8", + ) + + out = quant_linear( + x=x, + size=1, + num_flatten_dims=num_flatten_dims, + w=w, + scale_in=scale_in, + scale_weight=scale_weight.tolist(), + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + place = base.CUDAPlace(0) + exe = base.Executor(place=place) + exe.run(startup_program) + out = exe.run( + main_program, + feed={"x": input, "w": weight}, + fetch_list=[out], + ) + return out + + res_1 = run_program(-1) + res_2 = run_program(2) + np.testing.assert_array_equal(res_1, res_2) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "QuantLinear only supports cuda kernel.", +) +class TestQuantLinearOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + quant_round_type = 0 + quant_max_bound = 127.0 + quant_min_bound = -127.0 + + input_data = np.random.random((2, 4)).astype("float32") + scale_in = get_scale_in(input_data) + + weight = np.random.random([25, 1]).astype("float32") + scale_weight = get_scale_weights(weight) + weight = quant_weights( + weight, + scale_weight, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + + def test_Variable(): + with paddle_static_guard(): + w2 = paddle.static.data( + name='w2', shape=[25, 1], dtype='int8' + ) + quant_linear( + x=input_data, + size=1, + num_flatten_dims=1, + w=w2, + scale_in=scale_in, + scale_weight=scale_weight.tolist(), + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + self.assertRaises(TypeError, test_Variable) + + def test_type(): + with paddle_static_guard(): + x2 = paddle.static.data( + name='x2', shape=[-1, 4], dtype='int32' + ) + w2 = paddle.static.data( + name='w2', shape=[25, 1], dtype='int8' + ) + paddle.static.nn.fc( + x=x2, + size=1, + num_flatten_dims=1, + w=w2, + scale_in=scale_in, + scale_weight=scale_weight.tolist(), + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + self.assertRaises(TypeError, test_type) + + def test_Variable(): + with paddle_static_guard(): + x3 = paddle.static.data( + name='x3', shape=[-1, 4], dtype='float32' + ) + quant_linear( + x=x3, + size=1, + num_flatten_dims=1, + w=weight, + scale_in=scale_in, + scale_weight=scale_weight.tolist(), + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + self.assertRaises(TypeError, test_Variable) + + def test_type(): + with paddle_static_guard(): + x3 = paddle.static.data( + name='x3', shape=[-1, 4], dtype='float32' + ) + w3 = paddle.static.data( + name='w3', shape=[25, 1], dtype='int32' + ) + paddle.static.nn.fc( + x=x3, + size=1, + num_flatten_dims=1, + w=w3, + scale_in=scale_in, + scale_weight=scale_weight.tolist(), + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + self.assertRaises(TypeError, test_type) + + scale_weight = 1.0 + + def test_type(): + with paddle_static_guard(): + x4 = paddle.static.data( + name='x4', shape=[-1, 4], dtype='float32' + ) + w4 = paddle.static.data( + name='w4', shape=[25, 1], dtype='int8' + ) + paddle.static.nn.fc( + x=x4, + size=1, + num_flatten_dims=1, + w=w4, + scale_in=scale_in, + scale_weight=scale_weight, + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + self.assertRaises(TypeError, test_type) + + scale_weight = [] + + def test_param_length(): + with paddle_static_guard(): + x4 = paddle.static.data( + name='x4', shape=[-1, 4], dtype='float32' + ) + w4 = paddle.static.data( + name='w4', shape=[25, 1], dtype='int8' + ) + paddle.static.nn.fc( + x=x4, + size=1, + num_flatten_dims=1, + w=w4, + scale_in=scale_in, + scal=scale_weight, + quant_round_type=quant_round_type, + quant_max_bound=quant_max_bound, + quant_min_bound=quant_min_bound, + ) + + self.assertRaises(TypeError, test_param_length) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_randperm_op.py b/test/legacy_test/test_randperm_op.py index 9e10256cd25fb..9cb270801fece 100644 --- a/test/legacy_test/test_randperm_op.py +++ b/test/legacy_test/test_randperm_op.py @@ -144,7 +144,9 @@ def init_attrs(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place_customized(self.verify_output, self.place) + self.check_output_with_place_customized( + self.verify_output, self.place, check_pir=True + ) def verify_output(self, outs): out_np = convert_uint16_to_float(np.array(outs[0])) diff --git a/test/legacy_test/test_real_imag_op.py b/test/legacy_test/test_real_imag_op.py index f714cef69e6d4..71ee93262f267 100644 --- a/test/legacy_test/test_real_imag_op.py +++ b/test/legacy_test/test_real_imag_op.py @@ -19,6 +19,7 @@ import paddle from paddle import base, static +from paddle.pir_utils import test_with_pir_api numpy_apis = { "real": np.real, @@ -57,7 +58,7 @@ def init_grad_input_output(self): ) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -65,6 +66,7 @@ def test_check_grad(self): 'Out', user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], + check_pir=True, ) @@ -99,6 +101,7 @@ def setUp(self): self.places.append(paddle.CUDAPlace(0)) self._shape = [2, 20, 2, 3] + @test_with_pir_api def test_in_static_mode(self): def init_input_output(dtype): input = np.random.random(self._shape).astype( @@ -114,7 +117,7 @@ def init_input_output(dtype): out = paddle_apis[self.api](x) exe = static.Executor(place) - out_value = exe.run(feed=input_dict, fetch_list=[out.name]) + out_value = exe.run(feed=input_dict, fetch_list=[out]) np.testing.assert_array_equal(np_res, out_value[0]) def test_in_dynamic_mode(self): diff --git a/test/legacy_test/test_reduce_op.py b/test/legacy_test/test_reduce_op.py index 62d5a63ce2c15..a88f2650a005d 100644 --- a/test/legacy_test/test_reduce_op.py +++ b/test/legacy_test/test_reduce_op.py @@ -965,7 +965,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyFloatOp(OpTest): @@ -977,7 +977,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyIntOp(OpTest): @@ -989,7 +989,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOp_ZeroDim(OpTest): @@ -1001,7 +1001,7 @@ def setUp(self): self.attrs = {'dim': []} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAny8DOp(OpTest): @@ -1017,7 +1017,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOpWithDim(OpTest): @@ -1029,7 +1029,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=1)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAny8DOpWithDim(OpTest): @@ -1045,7 +1045,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOpWithKeepDim(OpTest): @@ -1061,7 +1061,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAny8DOpWithKeepDim(OpTest): @@ -1081,7 +1081,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOpError(unittest.TestCase): @@ -1713,21 +1713,25 @@ def setUp(self): self.places.append(base.CUDAPlace(0)) def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data(name="input", shape=[4, 4], dtype="bool") result = paddle.all(x=input) input_np = np.random.randint(0, 2, [4, 4]).astype("bool") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.all(input_np)).all()) def check_static_float_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data( name="input", shape=[4, 4], dtype="float" ) @@ -1736,26 +1740,29 @@ def check_static_float_result(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.all(input_np)).all()) def check_static_int_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data(name="input", shape=[4, 4], dtype="int") result = paddle.all(x=input) input_np = np.random.randint(0, 2, [4, 4]).astype("int") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.all(input_np)).all()) + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) @@ -1814,21 +1821,25 @@ def setUp(self): self.places.append(base.CUDAPlace(0)) def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data(name="input", shape=[4, 4], dtype="bool") result = paddle.any(x=input) input_np = np.random.randint(0, 2, [4, 4]).astype("bool") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.any(input_np)).all()) def check_static_float_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data( name="input", shape=[4, 4], dtype="float" ) @@ -1837,26 +1848,29 @@ def check_static_float_result(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.any(input_np)).all()) def check_static_int_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data(name="input", shape=[4, 4], dtype="int") result = paddle.any(x=input) input_np = np.random.randint(0, 2, [4, 4]).astype("int") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.any(input_np)).all()) + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) diff --git a/test/legacy_test/test_repeat_interleave_op.py b/test/legacy_test/test_repeat_interleave_op.py index ec6649039dd45..8b2e6e20333f2 100644 --- a/test/legacy_test/test_repeat_interleave_op.py +++ b/test/legacy_test/test_repeat_interleave_op.py @@ -104,14 +104,15 @@ def test_check_grad_normal(self): class TestIndexSelectAPI(unittest.TestCase): def input_data(self): - self.data_zero_dim_x = np.array(0.5) + self.data_zero_dim_x = np.array(0.5).astype('float32') self.data_x = np.array( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ] - ) + ).astype('float32') + self.data_zero_dim_index = np.array(2) self.data_index = np.array([0, 1, 2, 1]).astype('int32') def test_repeat_interleave_api(self): @@ -267,6 +268,17 @@ def test_dygraph_api(self): expect_out = np.repeat(self.data_zero_dim_x, index, axis=None) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case 4 zero_dim_index + with base.dygraph.guard(): + x = base.dygraph.to_variable(self.data_zero_dim_x) + index = base.dygraph.to_variable(self.data_zero_dim_index) + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat( + self.data_zero_dim_x, self.data_zero_dim_index, axis=None + ) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_reshape_op.py b/test/legacy_test/test_reshape_op.py index dd1f7e0044734..69f056741ba2b 100755 --- a/test/legacy_test/test_reshape_op.py +++ b/test/legacy_test/test_reshape_op.py @@ -19,6 +19,7 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api from paddle.static import Program, program_guard @@ -416,12 +417,13 @@ def _set_paddle_api(self): def _executed_api(self): self.reshape = paddle.reshape + @test_with_pir_api def _test_api(self): paddle.enable_static() input = np.random.random([2, 25]).astype("float32") shape = [2, 5, 5] - main_prog = Program() - with program_guard(main_prog, Program()): + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): positive_five = self.fill_constant([1], "int32", 5) x = self.data(name="x", shape=[2, 25], dtype="float32") @@ -517,13 +519,6 @@ def test_x_type(): self.assertRaises(TypeError, test_x_type) - # The x dtype of reshape_op must be float16, float32, float64, int32 or int64. - def test_x_dtype(): - x2 = self.data(name="x2", shape=[2, 25], dtype="int8") - self.reshape(x2, shape=[2, 5, 5]) - - self.assertRaises(TypeError, test_x_dtype) - def test_x_dtype_float16(): x_float16 = self.data( name="x_float16", shape=[2, 25], dtype="float16" @@ -559,6 +554,7 @@ def test_shape_3(): self.assertRaises(AssertionError, test_shape_3) paddle.disable_static() + @test_with_pir_api def test_paddle_api_error(self): self._set_paddle_api() self._test_errors() @@ -649,25 +645,30 @@ def test_dygraph(self): paddle.enable_static() + @test_with_pir_api def test_static(self): main_prog = base.Program() with base.program_guard(main_prog, base.Program()): x = paddle.rand([]) x.stop_gradient = False out = paddle.reshape(x, [-1]) - base.backward.append_backward(out) - - prog = paddle.static.default_main_program() - block = prog.global_block() - - x_grad = block.var(base.framework.grad_var_name(x.name)) - out_grad = block.var(base.framework.grad_var_name(out.name)) + if paddle.framework.in_pir_mode(): + grads = paddle.autograd.ir_backward.grad(out, x) + x_grad = grads[0] + out_grad = x_grad.get_defining_op().operand_source(1) + else: + base.backward.append_backward(out) + prog = paddle.static.default_main_program() + block = prog.global_block() + + x_grad = block.var(base.framework.grad_var_name(x.name)) + out_grad = block.var(base.framework.grad_var_name(out.name)) # Test compile shape - self.assertEqual(x.shape, ()) - self.assertEqual(out.shape, (1,)) - self.assertEqual(x_grad.shape, ()) - self.assertEqual(out_grad.shape, (1,)) + self.assertEqual(tuple(x.shape), ()) + self.assertEqual(tuple(out.shape), (1,)) + self.assertEqual(tuple(x_grad.shape), ()) + self.assertEqual(tuple(out_grad.shape), (1,)) exe = base.Executor() result = exe.run(main_prog, fetch_list=[x, out, x_grad, out_grad]) diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py index 79e20e906d92c..dc9061ad95924 100644 --- a/test/legacy_test/test_rms_norm_op.py +++ b/test/legacy_test/test_rms_norm_op.py @@ -18,6 +18,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def quant_helper( @@ -342,49 +343,6 @@ def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): ) return out_s[0], paddle_naive_rmsnorm_out - def test_rmsnorm_pir(self): - paddle.disable_static() - x = paddle.to_tensor(self.x_np.astype("float32")) - gamma = paddle.to_tensor(self.norm_weight_np.astype("float32")) - beta = paddle.to_tensor(self.norm_bias_np.astype("float32")) - - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - - with paddle.pir_utils.IrGuard(): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype="float32" - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype="float32" - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype="float32" - ) - out, _ = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - ) - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": self.x_np.astype("float32"), - "gamma_static": self.norm_weight_np.astype("float32"), - "beta_static": self.norm_bias_np.astype("float32"), - }, - fetch_list=[out], - ) - - np.testing.assert_allclose( - out_s[0], - paddle_naive_rmsnorm_out.numpy(), - rtol=1e-3, - atol=1e-3, - ) - def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype): paddle.disable_static() x = paddle.to_tensor(x_np.astype(dtype)) @@ -491,6 +449,7 @@ def check_residual_bias_rmsnorm( ) return out_s[0], paddle_naive_rmsnorm_out + @test_with_pir_api def test_rmsnorm_fp16(self): if not paddle.is_compiled_with_cuda(): return @@ -505,6 +464,7 @@ def test_rmsnorm_fp16(self): atol=1e-3, ) + @test_with_pir_api def test_residual_bias_add_rmsnorm_fp16(self): if not paddle.is_compiled_with_cuda(): return @@ -524,6 +484,7 @@ def test_residual_bias_add_rmsnorm_fp16(self): atol=1e-3, ) + @test_with_pir_api def test_rmsnorm_int8(self): if not paddle.is_compiled_with_cuda(): return diff --git a/test/legacy_test/test_run_program_op.py b/test/legacy_test/test_run_program_op.py index 2d223e9474703..729f3e6ba114d 100644 --- a/test/legacy_test/test_run_program_op.py +++ b/test/legacy_test/test_run_program_op.py @@ -208,7 +208,6 @@ def create_var_base(is_input, name): outputs['OutScope'] = [core.Scope()] - outputs['DOut'] = [create_var_base(False, "Fake_var")] return outputs def calc_dygraph_output(self, place): @@ -256,7 +255,6 @@ def calc_dygraph_output(self, place): inputs['Params'], outputs['Out'], outputs['OutScope'], - outputs['DOut'], None, *self.attrs ) @@ -309,7 +307,6 @@ def calc_dygraph_grad(self, place): inputs['Params'], outputs['Out'], outputs['OutScope'], - outputs['DOut'], None, *self.attrs ) diff --git a/test/legacy_test/test_scale_op.py b/test/legacy_test/test_scale_op.py index 5f33de74b3b61..e743ee49ae796 100644 --- a/test/legacy_test/test_scale_op.py +++ b/test/legacy_test/test_scale_op.py @@ -23,7 +23,7 @@ import paddle from paddle import base from paddle.base import core -from paddle.static import Program, program_guard +from paddle.pir_utils import test_with_pir_api class TestScaleOp(OpTest): @@ -200,11 +200,12 @@ class TestScaleApiStatic(unittest.TestCase): def _executed_api(self, x, scale=1.0, bias=0.0): return paddle.scale(x, scale, bias) + @test_with_pir_api def test_api(self): paddle.enable_static() input = np.random.random([2, 25]).astype("float32") - main_prog = Program() - with program_guard(main_prog, Program()): + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): x = paddle.static.data(name="x", shape=[2, 25], dtype="float32") out = self._executed_api(x, scale=2.0, bias=3.0) @@ -300,7 +301,6 @@ def test_grad(self): class TestScaleOpZeroNumelVariable(unittest.TestCase): def test_check_zero_numel_cpu(self): - paddle.enable_static() paddle.set_device('cpu') data = paddle.ones([0, 1]) out = paddle.scale(data, 2) diff --git a/test/legacy_test/test_scatter_nd_op.py b/test/legacy_test/test_scatter_nd_op.py index a92432e4a0b6d..e9e541e09af67 100644 --- a/test/legacy_test/test_scatter_nd_op.py +++ b/test/legacy_test/test_scatter_nd_op.py @@ -21,7 +21,7 @@ import paddle from paddle import base from paddle.base import core -from paddle.base.dygraph.base import switch_to_static_graph +from paddle.pir_utils import test_with_pir_api def numpy_scatter_nd(ref, index, updates, fun): @@ -94,10 +94,12 @@ def _set_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_prim=True) + self.check_grad( + ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ) class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp): @@ -125,13 +127,13 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_prim=True + place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True ) @@ -170,10 +172,12 @@ def _set_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_prim=True) + self.check_grad( + ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ) class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex): @@ -201,13 +205,13 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_prim=True + place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True ) @@ -249,10 +253,12 @@ def _set_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_prim=True) + self.check_grad( + ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ) class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame): @@ -280,13 +286,13 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_prim=True + place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True ) @@ -312,10 +318,12 @@ def setUp(self): self.outputs = {'Out': expect_np} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_prim=True) + self.check_grad( + ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ) # Test Python API @@ -422,7 +430,7 @@ def testcase5(self): np.testing.assert_array_equal(gpu_value.numpy(), cpu_value.numpy()) paddle.set_device(device) - @switch_to_static_graph + @test_with_pir_api def test_static_graph(): with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() @@ -434,15 +442,26 @@ def test_static_graph(): val_t = paddle.static.data( name="val", dtype=val.dtype, shape=val.shape ) - out_t = paddle.scatter_nd_add(x_t, index_t, val_t) - feed = {x_t.name: x, index_t.name: index, val_t.name: val} - fetch = [out_t] - gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) - gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] cpu_exe = paddle.static.Executor(paddle.CPUPlace()) - cpu_value = cpu_exe.run(feed=feed, fetch_list=fetch)[0] - np.testing.assert_array_equal(gpu_value, cpu_value) + out_t = paddle.scatter_nd_add(x_t, index_t, val_t) + gpu_value = gpu_exe.run( + feed={ + 'x': x, + 'index': index, + 'val': val, + }, + fetch_list=[out_t], + ) + cpu_value = cpu_exe.run( + feed={ + 'x': x, + 'index': index, + 'val': val, + }, + fetch_list=[out_t], + ) + np.testing.assert_array_equal(gpu_value, cpu_value) test_static_graph() diff --git a/test/legacy_test/test_scatter_op.py b/test/legacy_test/test_scatter_op.py index d0cd04903956a..d44982c6321d0 100644 --- a/test/legacy_test/test_scatter_op.py +++ b/test/legacy_test/test_scatter_op.py @@ -22,6 +22,7 @@ from paddle import base from paddle.base import core from paddle.base.dygraph.base import switch_to_static_graph +from paddle.pir_utils import test_with_pir_api class TestScatterOp(OpTest): @@ -52,10 +53,12 @@ def _set_dtype(self): self.dtype = np.float32 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_prim=True) + self.check_grad( + ["X", "Updates"], "Out", check_prim=True, check_pir=True + ) class TestScatterFP16Op(TestScatterOp): @@ -78,7 +81,7 @@ def if_enable_cinn(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -88,6 +91,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -120,10 +124,12 @@ def _set_dtype(self): self.dtype = np.float32 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_prim=True) + self.check_grad( + ["X", "Updates"], "Out", check_prim=True, check_pir=True + ) class TestScatterFP16Op0(TestScatterOp0): @@ -146,7 +152,7 @@ def if_enable_cinn(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -156,6 +162,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -191,10 +198,12 @@ def _set_dtype(self): self.dtype = np.float32 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_prim=True) + self.check_grad( + ["X", "Updates"], "Out", check_prim=True, check_pir=True + ) class TestScatterFP16Op1(TestScatterOp1): @@ -217,7 +226,7 @@ def if_enable_cinn(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -227,6 +236,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -263,7 +273,7 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=1e-3, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -273,6 +283,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -334,7 +345,7 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=1e-3, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -344,6 +355,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -396,10 +408,12 @@ def _set_dtype(self): self.dtype = np.float32 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_prim=True) + self.check_grad( + ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ) class TestScatterFP16Op4(TestScatterOp4): @@ -422,7 +436,7 @@ def if_enable_cinn(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -432,6 +446,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -468,7 +483,7 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=1e-3, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -478,6 +493,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -530,10 +546,12 @@ def _set_dtype(self): self.dtype = np.float32 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_prim=True) + self.check_grad( + ["X", "Updates"], "Out", check_prim=True, check_pir=True + ) class TestScatterFP16Op6(TestScatterOp6): @@ -556,7 +574,7 @@ def _set_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -566,6 +584,7 @@ def test_check_grad(self): ['X', 'Updates'], 'Out', check_prim=True, + check_pir=True, ) @@ -580,7 +599,9 @@ def executed_api(self): self.scatter = paddle.scatter def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input = paddle.static.data( name="input", shape=[3, 2], dtype="float64" ) @@ -596,9 +617,9 @@ def check_static_result(self, place): np.float64 ) - exe = base.Executor(place) + exe = paddle.static.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={ "input": input_data, "index": index_data, @@ -613,6 +634,7 @@ def check_static_result(self, place): True, ) + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) @@ -675,7 +697,6 @@ def test_static_graph(): updates_t.name: updates, } fetch = [out_t] - gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] return gpu_value diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index 81c6b5845a03f..a78db61b6e399 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def compute_segment_sum(x, segment_ids): @@ -123,7 +124,7 @@ def setUp(self): self.convert_bf16() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad(["X"], "Out") @@ -165,7 +166,9 @@ def prepare(self): self.attrs = {'pooltype': "MAX"} def test_check_grad(self): - self.check_grad(["X"], "Out", user_defined_grads=[self.gradient]) + self.check_grad( + ["X"], "Out", user_defined_grads=[self.gradient], check_pir=True + ) class TestSegmentMax2(TestSegmentMax): @@ -220,11 +223,11 @@ def setUp(self): def test_check_output(self): if core.is_compiled_with_cuda(): - self.check_output_with_place(core.CUDAPlace(0)) + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) # due to CPU kernel not implement calculate 'SummedIds' # so cannot check 'SummedIds' del self.outputs['SummedIds'] - self.check_output_with_place(core.CPUPlace()) + self.check_output_with_place(core.CPUPlace(), check_pir=True) class TestSegmentMean2(TestSegmentMean): @@ -271,7 +274,7 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place(self.place, ["X"], "Out") @@ -289,11 +292,14 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place( - self.place, ["X"], "Out", user_defined_grads=[self.gradient] + self.place, + ["X"], + "Out", + user_defined_grads=[self.gradient], ) @@ -309,11 +315,14 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place( - self.place, ["X"], "Out", user_defined_grads=[self.gradient] + self.place, + ["X"], + "Out", + user_defined_grads=[self.gradient], ) @@ -329,13 +338,14 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place(self.place, ["X"], "Out") class API_SegmentOpsTest(unittest.TestCase): + @test_with_pir_api def test_static(self): with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") @@ -389,6 +399,7 @@ def test_dygraph(self): class API_GeometricSegmentOpsTest(unittest.TestCase): + @test_with_pir_api def test_static(self): with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") diff --git a/test/legacy_test/test_sgd_op.py b/test/legacy_test/test_sgd_op.py index a69039baa2634..d71b297185892 100644 --- a/test/legacy_test/test_sgd_op.py +++ b/test/legacy_test/test_sgd_op.py @@ -17,6 +17,7 @@ import numpy as np from op import Operator from op_test import OpTest +from utils import dygraph_guard import paddle from paddle import base @@ -51,7 +52,7 @@ def conf(self): self.w = 105 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestSGDOpCase8X(TestSGDOp): @@ -427,5 +428,65 @@ def test_main(self): ) +class TestSGDSimple(unittest.TestCase): + def setUp(self) -> None: + self.data = np.random.random(size=(2, 2)).astype('float32') + + def run_static(self): + with paddle.pir_utils.IrGuard(): + paddle.seed(10) + np.random.seed(10) + + exe = paddle.static.Executor('gpu') + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(train_program, startup_program): + input = paddle.static.data( + shape=[2, 2], name='input', dtype='float32' + ) + model = paddle.nn.Linear(2, 2) + output = model(input) + loss = paddle.mean(output) + + optimizer = paddle.optimizer.SGD() + optimizer.minimize(loss) + + exe.run(startup_program) + + out = [] + for _ in range(5): + (loss_data,) = exe.run( + train_program, feed={"input": self.data}, fetch_list=[loss] + ) + out.append(loss_data) + return out + + def run_dygraph(self): + with dygraph_guard(): + paddle.seed(10) + np.random.seed(10) + + out = [] + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.SGD(parameters=model.parameters()) + for _ in range(5): + output = model(paddle.to_tensor(self.data)) + loss = paddle.mean(output) + out.append(loss.numpy()) + loss.backward() + optimizer.step() + optimizer.clear_grad() + + return out + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + out1 = self.run_dygraph() + out2 = self.run_static() + np.testing.assert_allclose(out1, out2) + + if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_sign_op.py b/test/legacy_test/test_sign_op.py index 80dcc6909bfb7..fc8a0ed27547c 100644 --- a/test/legacy_test/test_sign_op.py +++ b/test/legacy_test/test_sign_op.py @@ -34,10 +34,10 @@ def setUp(self): self.outputs = {'Out': np.sign(self.inputs['X'])} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestSignFP16Op(TestSignOp): @@ -70,34 +70,18 @@ def setUp(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): - self.check_grad_with_place(self.place, ['X'], 'Out') - - -class TestSignOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - # The input type of sign_op must be Variable or numpy.ndarray. - input1 = 12 - self.assertRaises(TypeError, paddle.sign, input1) - # The input dtype of sign_op must be float16, float32, float64. - input2 = paddle.static.data( - name='input2', shape=[-1, 12, 10], dtype="int32" - ) - input3 = paddle.static.data( - name='input3', shape=[-1, 12, 10], dtype="int64" - ) - self.assertRaises(TypeError, paddle.sign, input2) - self.assertRaises(TypeError, paddle.sign, input3) - input4 = paddle.static.data( - name='input4', shape=[-1, 4], dtype="float16" - ) - paddle.sign(input4) + self.check_grad_with_place(self.place, ['X'], 'Out', check_pir=True) class TestSignAPI(unittest.TestCase): + def setUp(self): + self.place = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(base.CUDAPlace(0)) + def test_dygraph(self): with base.dygraph.guard(): np_x = np.array([-1.0, 0.0, -0.0, 1.2, 1.5], dtype='float64') @@ -108,23 +92,51 @@ def test_dygraph(self): self.assertEqual((np_z == z_expected).all(), True) def test_static(self): - with program_guard(Program(), Program()): - # The input type of sign_op must be Variable or numpy.ndarray. - input1 = 12 - self.assertRaises(TypeError, paddle.tensor.math.sign, input1) - # The input dtype of sign_op must be float16, float32, float64. - input2 = paddle.static.data( - name='input2', shape=[-1, 12, 10], dtype="int32" - ) - input3 = paddle.static.data( - name='input3', shape=[-1, 12, 10], dtype="int64" - ) - self.assertRaises(TypeError, paddle.tensor.math.sign, input2) - self.assertRaises(TypeError, paddle.tensor.math.sign, input3) - input4 = paddle.static.data( - name='input4', shape=[-1, 4], dtype="float16" - ) - paddle.sign(input4) + np_input2 = np.random.uniform(-10, 10, (12, 10)).astype("int16") + np_input3 = np.random.uniform(-10, 10, (12, 10)).astype("int32") + np_input4 = np.random.uniform(-10, 10, (12, 10)).astype("int64") + np_out2 = np.sign(np_input2) + np_out3 = np.sign(np_input3) + np_out4 = np.sign(np_input4) + + def run(place): + with program_guard(Program(), Program()): + # The input type of sign_op must be Variable or numpy.ndarray. + input1 = 12 + self.assertRaises(TypeError, paddle.tensor.math.sign, input1) + # The result of sign_op must correct. + input2 = paddle.static.data( + name='input2', shape=[12, 10], dtype="int16" + ) + input3 = paddle.static.data( + name='input3', shape=[12, 10], dtype="int32" + ) + input4 = paddle.static.data( + name='input4', shape=[12, 10], dtype="int64" + ) + out2 = paddle.sign(input2) + out3 = paddle.sign(input3) + out4 = paddle.sign(input4) + exe = paddle.static.Executor(place) + res2, res3, res4 = exe.run( + paddle.static.default_main_program(), + feed={ + "input2": np_input2, + "input3": np_input3, + "input4": np_input4, + }, + fetch_list=[out2, out3, out4], + ) + self.assertEqual((res2 == np_out2).all(), True) + self.assertEqual((res3 == np_out3).all(), True) + self.assertEqual((res4 == np_out4).all(), True) + input5 = paddle.static.data( + name='input5', shape=[-1, 4], dtype="float16" + ) + paddle.sign(input5) + + for place in self.place: + run(place) class TestSignDoubleGradCheck(unittest.TestCase): diff --git a/test/legacy_test/test_slice_op.py b/test/legacy_test/test_slice_op.py index 8791bf94c16dc..e409287c90b68 100644 --- a/test/legacy_test/test_slice_op.py +++ b/test/legacy_test/test_slice_op.py @@ -657,7 +657,7 @@ def test_1(self): exe = base.Executor(place=base.CPUPlace()) res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={ "x": input, 'starts': np.array([-3, 0, 2]).astype("int32"), @@ -674,6 +674,65 @@ def test_1(self): np.testing.assert_array_equal(res_6, input[-3:3, 0:100, :, 2:-1]) np.testing.assert_array_equal(res_7, input[-1, 0:100, :, 2:-1]) + def test_pir(self): + with paddle.pir_utils.IrGuard(), paddle.static.program_guard( + paddle.static.Program() + ): + input = np.random.random([3, 4, 5, 6]).astype("float64") + minus_1 = paddle.tensor.fill_constant([], "int32", -1) + minus_3 = paddle.tensor.fill_constant([], "int64", -3) + starts = paddle.static.data(name='starts', shape=[3], dtype="int32") + ends = paddle.static.data(name='ends', shape=[3], dtype="int32") + x = paddle.static.data( + name="x", + shape=[3, 4, 5, 6], + dtype="float64", + ) + + # value_int64 is greater than 2147483647 which is the max of int32 + value_int64 = paddle.tensor.fill_constant([1], "int64", 2147483648) + + out_1 = paddle.slice( + x, + axes=[0, 1, 2], + starts=[-3, 0, 2], + ends=[value_int64, 100, -1], + ) + out_2 = paddle.slice( + x, axes=[0, 1, 3], starts=[minus_3, 0, 2], ends=[3, 100, -1] + ) + out_3 = paddle.slice( + x, + axes=[0, 1, 3], + starts=[minus_3, 0, 2], + ends=[3, 100, minus_1], + ) + out_4 = paddle.slice(x, axes=[0, 1, 2], starts=starts, ends=ends) + + out_5 = x[-3:3, 0:100, 2:-1] + out_6 = x[minus_3:3, 0:100, :, 2:-1] + # open it after supporting control flow + # out_7 = x[minus_1, 0:100, :, 2:minus_1] + + exe = base.Executor(place=base.CPUPlace()) + res_1, res_2, res_3, res_4, res_5, res_6 = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + 'starts': np.array([-3, 0, 2]).astype("int32"), + 'ends': np.array([3, 100, -1]).astype("int32"), + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6], + ) + + np.testing.assert_array_equal(res_1, input[-3:3, 0:100, 2:-1, :]) + np.testing.assert_array_equal(res_2, input[-3:3, 0:100, :, 2:-1]) + np.testing.assert_array_equal(res_3, input[-3:3, 0:100, :, 2:-1]) + np.testing.assert_array_equal(res_4, input[-3:3, 0:100, 2:-1, :]) + np.testing.assert_array_equal(res_5, input[-3:3, 0:100, 2:-1, :]) + np.testing.assert_array_equal(res_6, input[-3:3, 0:100, :, 2:-1]) + # np.testing.assert_array_equal(res_7, input[-1, 0:100, :, 2:-1]) + class TestSliceApiWithTensor(unittest.TestCase): def test_starts_ends_is_tensor(self): @@ -754,7 +813,7 @@ def setUp(self): def set_program_and_run(self, main_program, case_num): with paddle_static_guard(): - with base.program_guard(main_program): + with paddle.static.program_guard(main_program): x = [ paddle.static.data( name='x0', shape=self.shape, dtype="float32" @@ -810,7 +869,7 @@ def set_program_and_run(self, main_program, case_num): ) def test_case_1(self): - main_program = base.Program() + main_program = paddle.static.Program() self.set_program_and_run(main_program, 1) self.assertTrue(self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR) @@ -822,7 +881,7 @@ def test_case_1(self): def test_case_2(self): with paddle_static_guard(): - main_program = base.Program() + main_program = paddle.static.Program() self.set_program_and_run(main_program, 2) self.assertTrue( @@ -838,7 +897,7 @@ def test_case_2(self): def test_case_3(self): with paddle_static_guard(): - main_program = base.Program() + main_program = paddle.static.Program() self.set_program_and_run(main_program, 3) self.assertTrue( @@ -893,6 +952,13 @@ def test(self): out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3]) self.assertEqual(out0.shape, (3, -1, 5)) + def test_pir(self): + with paddle.pir_utils.IrGuard(): + x = paddle.static.data('x', shape=[3, -1, 5]) + + out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3]) + self.assertEqual(out0.shape, [3, -1, 5]) + def test_axis_less_than_zero(self): # Using paddle.disable_static will make other unittests fail. with base.dygraph.guard(): diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py index ae98b43476619..74b685333d925 100644 --- a/test/legacy_test/test_softmax_op.py +++ b/test/legacy_test/test_softmax_op.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(10) @@ -512,6 +513,7 @@ def setUp(self): def executed_api(self): self.softmax = F.softmax + @test_with_pir_api def test_static_check(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -590,6 +592,7 @@ def test_dygraph(self): paddle.enable_static() + @test_with_pir_api def test_static(self): with static_guard(): main_prog = base.Program() @@ -597,18 +600,17 @@ def test_static(self): x = paddle.rand([]) x.stop_gradient = False out = paddle.nn.functional.softmax(x) - base.backward.append_backward(out) # Test compile shape - self.assertEqual(x.shape, ()) - self.assertEqual(out.shape, ()) + self.assertEqual(tuple(x.shape), ()) + self.assertEqual(tuple(out.shape), ()) exe = base.Executor() result = exe.run(main_prog, fetch_list=[x, out]) # Test runtime shape - self.assertEqual(result[0].shape, ()) - self.assertEqual(result[1].shape, ()) + self.assertEqual(tuple(result[0].shape), ()) + self.assertEqual(tuple(result[1].shape), ()) class TestSoftmaxInplaceAPI(TestSoftmaxAPI): diff --git a/test/legacy_test/test_solve_op.py b/test/legacy_test/test_solve_op.py index 1d15da4019e65..040cf1a80fa06 100644 --- a/test/legacy_test/test_solve_op.py +++ b/test/legacy_test/test_solve_op.py @@ -50,10 +50,10 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) # x broadcast + 3D batch case @@ -71,10 +71,12 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-1) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=1e-1, check_pir=True + ) # 3D batch + y vector case @@ -92,10 +94,12 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=0.04, check_pir=True + ) # 3D batch + y broadcast case @@ -113,10 +117,12 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=0.02, check_pir=True + ) # x broadcast + 3D batch case @@ -134,10 +140,12 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=0.02, check_pir=True + ) # 3D normal batch case @@ -155,10 +163,10 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) # 4D normal batch case @@ -176,10 +184,10 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) # 4D batch + y broadcast case @@ -197,10 +205,10 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_pir=True) # 5D normal batch case @@ -218,10 +226,12 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=0.04, check_pir=True + ) # 5D batch + y broadcast case @@ -239,10 +249,12 @@ def setUp(self): self.outputs = {'Out': result} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=0.04, check_pir=True + ) class TestSolveOpError(unittest.TestCase): diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index a192078899dd7..0a674009651d0 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api class TestSplitOp(OpTest): @@ -321,39 +322,43 @@ def test_check_grad(self): class TestSplitAPI(unittest.TestCase): + @test_with_pir_api def test_api(self): - input_1 = np.random.random([4, 5, 6]).astype("int32") - positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1) - positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) - positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) - x_1 = paddle.static.data(shape=[4, 5, 6], dtype='int32', name='x_1') - x_2 = paddle.static.data(shape=[4, 5, None], dtype='int32', name='x_2') - - out_0, out_1, out_2 = paddle.split( - x=x_1, - num_or_sections=[positive_2_int64, positive_1_int32, -1], - axis=positive_1_int64, - ) + with paddle.static.program_guard(paddle.static.Program()): + input_1 = np.random.random([4, 5, 6]).astype("int32") + positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1) + positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) + positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) + x_1 = paddle.static.data(shape=[4, 5, 6], dtype='int32', name='x_1') + x_2 = paddle.static.data( + shape=[4, 5, None], dtype='int32', name='x_2' + ) - out_3, out_4, out_5 = paddle.split( - x=x_1, num_or_sections=[2, 1, 2], axis=positive_1_int32 - ) - paddle.split(x=x_2, num_or_sections=2, axis=2) + out_0, out_1, out_2 = paddle.split( + x=x_1, + num_or_sections=[positive_2_int64, positive_1_int32, -1], + axis=positive_1_int64, + ) - exe = base.Executor(place=base.CPUPlace()) - [res_0, res_1, res_2, res_3, res_4, res_5] = exe.run( - base.default_main_program(), - feed={"x_1": input_1, "x_2": input_1}, - fetch_list=[out_0, out_1, out_2, out_3, out_4, out_5], - ) + out_3, out_4, out_5 = paddle.split( + x=x_1, num_or_sections=[2, 1, 2], axis=positive_1_int32 + ) + paddle.split(x=x_2, num_or_sections=2, axis=2) + + exe = base.Executor(place=base.CPUPlace()) + [res_0, res_1, res_2, res_3, res_4, res_5] = exe.run( + paddle.static.default_main_program(), + feed={"x_1": input_1, "x_2": input_1}, + fetch_list=[out_0, out_1, out_2, out_3, out_4, out_5], + ) - out = np.split(input_1, [2, 3], 1) - np.testing.assert_array_equal(res_0, out[0]) - np.testing.assert_array_equal(res_1, out[1]) - np.testing.assert_array_equal(res_2, out[2]) - np.testing.assert_array_equal(res_3, out[0]) - np.testing.assert_array_equal(res_4, out[1]) - np.testing.assert_array_equal(res_5, out[2]) + out = np.split(input_1, [2, 3], 1) + np.testing.assert_array_equal(res_0, out[0]) + np.testing.assert_array_equal(res_1, out[1]) + np.testing.assert_array_equal(res_2, out[2]) + np.testing.assert_array_equal(res_3, out[0]) + np.testing.assert_array_equal(res_4, out[1]) + np.testing.assert_array_equal(res_5, out[2]) class TestSplitOpError(unittest.TestCase): @@ -417,14 +422,13 @@ def test_0_num_tensor(): class API_TestSplit(unittest.TestCase): + @test_with_pir_api def test_out(self): with base.program_guard(base.Program(), base.Program()): data1 = paddle.static.data( - 'data1', shape=[-1, 4, 6, 6], dtype='float64' + 'data1', shape=[4, 6, 6], dtype='float64' ) - data1.desc.set_need_check_feed(False) - data2 = paddle.static.data('data2', shape=[-1, 1], dtype='int32') - data2.desc.set_need_check_feed(False) + data2 = paddle.static.data('data2', shape=[1], dtype='int32') x0, x1, x2 = paddle.split(data1, num_or_sections=3, axis=data2) place = base.CPUPlace() exe = base.Executor(place) @@ -444,12 +448,12 @@ def test_out(self): class API_TestSplit2(unittest.TestCase): + @test_with_pir_api def test_out(self): with base.program_guard(base.Program(), base.Program()): data1 = paddle.static.data( - 'data1', shape=[-1, 4, 6, 6], dtype='float64' + 'data1', shape=[4, 6, 6], dtype='float64' ) - data1.desc.set_need_check_feed(False) x0, x1, x2 = paddle.split(data1, num_or_sections=3, axis=2) place = base.CPUPlace() exe = base.Executor(place) @@ -466,6 +470,7 @@ def test_out(self): class API_TestSplit3(unittest.TestCase): + @test_with_pir_api def test_out(self): with base.program_guard(base.Program(), base.Program()): data = paddle.static.data('data', shape=[-1, 10], dtype='float64') @@ -480,6 +485,7 @@ def test_out(self): class API_TestSplit4(unittest.TestCase): + @test_with_pir_api def test_out(self): with base.program_guard(base.Program(), base.Program()): data = paddle.static.data('data', shape=[-1, 10], dtype='float64') @@ -498,6 +504,7 @@ def test_out(self): class API_TestSplit5(unittest.TestCase): + @test_with_pir_api def test_out(self): for use_cuda in ( [False, True] if core.is_compiled_with_cuda() else [False] @@ -518,6 +525,7 @@ def test_out(self): class API_TestSplit6(unittest.TestCase): + @test_with_pir_api def test_out(self): with base.program_guard(base.Program(), base.Program()): data = paddle.static.data('data', shape=[-1, 10], dtype='float64') diff --git a/test/legacy_test/test_splits_api.py b/test/legacy_test/test_splits_api.py index 2b562179b8752..4e319e6cb4b91 100644 --- a/test/legacy_test/test_splits_api.py +++ b/test/legacy_test/test_splits_api.py @@ -18,6 +18,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def func_ref(func, x, num_or_sections): @@ -52,6 +53,7 @@ def set_input(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): paddle.enable_static() for func, func_type in test_list: @@ -166,6 +168,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_error(self): paddle.enable_static() for func, _ in test_list: diff --git a/test/legacy_test/test_squeeze2_op.py b/test/legacy_test/test_squeeze2_op.py index f1a689024de8a..f7470e1b0ef01 100755 --- a/test/legacy_test/test_squeeze2_op.py +++ b/test/legacy_test/test_squeeze2_op.py @@ -291,6 +291,16 @@ def test_axes_type(): self.assertRaises(TypeError, test_axes_type) + def test_pir_error(self): + def test_axes_type(): + with paddle.pir_utils.IrGuard(): + x2 = paddle.static.data( + name="x2", shape=[2, 1, 25], dtype="int32" + ) + self.squeeze(x2, axis=2.1) + + self.assertRaises(ValueError, test_axes_type) + class TestSqueezeInplaceAPI(TestSqueezeAPI): def executed_api(self): diff --git a/test/legacy_test/test_stack_op.py b/test/legacy_test/test_stack_op.py index fb8eda704db6a..472777b9cfd72 100644 --- a/test/legacy_test/test_stack_op.py +++ b/test/legacy_test/test_stack_op.py @@ -19,7 +19,7 @@ import paddle from paddle import base -from paddle.base.framework import Program, program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -220,11 +220,10 @@ def setUp(self): if base.is_compiled_with_cuda() else base.CPUPlace() ) - self.set_program() - def set_program(self): - self.program = base.Program() - with base.program_guard(self.program): + def test_case(self): + self.program = paddle.static.Program() + with paddle.static.program_guard(self.program): input = paddle.assign(self.x) tensor_array = paddle.tensor.create_array(dtype='float32') zero = paddle.tensor.fill_constant( @@ -235,8 +234,6 @@ def set_program(self): paddle.tensor.array_write(input, zero + i, tensor_array) self.out_var = paddle.stack(tensor_array, axis=self.axis) - - def test_case(self): self.assertTrue(self.out_var.shape[self.axis] == -1) exe = base.Executor(self.place) res = exe.run(self.program, fetch_list=self.out_var) @@ -260,11 +257,10 @@ def setUp(self): if base.is_compiled_with_cuda() else base.CPUPlace() ) - self.set_program() - def set_program(self): - self.program = base.Program() - with base.program_guard(self.program): + def test_case(self): + self.program = paddle.static.Program() + with paddle.static.program_guard(self.program): input = paddle.assign(self.x) tensor_array = paddle.tensor.create_array(dtype='float32') zero = paddle.tensor.fill_constant( @@ -275,8 +271,6 @@ def set_program(self): paddle.tensor.array_write(input, zero + i, tensor_array) self.out_var = paddle.stack(tensor_array, axis=self.axis) - - def test_case(self): self.assertTrue(self.out_var.shape[self.axis] == -1) exe = base.Executor(self.place) res = exe.run(self.program, fetch_list=self.out_var) @@ -286,8 +280,11 @@ def test_case(self): class API_test(unittest.TestCase): + @test_with_pir_api def test_out(self): - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): data1 = paddle.static.data('data1', shape=[1, 2], dtype='float64') data2 = paddle.static.data('data2', shape=[1, 2], dtype='float64') data3 = paddle.static.data('data3', shape=[1, 2], dtype='float64') @@ -309,6 +306,11 @@ def test_single_tensor_error(self): x = paddle.rand([2, 3]) self.assertRaises(TypeError, paddle.stack, x) + def test_pir_single_tensor_error(self): + with paddle.pir_utils.IrGuard(): + x = paddle.rand([2, 3]) + self.assertRaises(ValueError, paddle.stack, x) + class API_DygraphTest(unittest.TestCase): def test_out(self): @@ -338,9 +340,10 @@ def test_single_tensor_error(self): class TestStackOpWithNegativeShape(unittest.TestCase): + @test_with_pir_api def test_out(self): - main_prg, startup_prg = Program(), Program() - with program_guard(main_prg, startup_prg): + main_prg, startup_prg = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(main_prg, startup_prg): b = paddle.static.data(name='b', shape=[-1], dtype='int64') e = paddle.static.data(name='e', shape=[3], dtype='int64') k = paddle.stack([b, e], axis=0) diff --git a/test/legacy_test/test_std_layer.py b/test/legacy_test/test_std_layer.py index 22ef341259142..aed3e750402e5 100644 --- a/test/legacy_test/test_std_layer.py +++ b/test/legacy_test/test_std_layer.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.pir_utils import test_with_pir_api def ref_std(x, axis=None, unbiased=True, keepdim=False): @@ -61,6 +62,7 @@ def dygraph(self): paddle.enable_static() return out.numpy() + @test_with_pir_api def test_api(self): out_ref = ref_std(self.x, self.axis, self.unbiased, self.keepdim) out_dygraph = self.dygraph() @@ -120,6 +122,7 @@ def test_error(self): class Testfp16Std(unittest.TestCase): + @test_with_pir_api def test_fp16_with_gpu(self): paddle.enable_static() if paddle.base.core.is_compiled_with_cuda(): diff --git a/test/legacy_test/test_stride.py b/test/legacy_test/test_stride.py index a80451e36fdc4..ffeeade304ce5 100644 --- a/test/legacy_test/test_stride.py +++ b/test/legacy_test/test_stride.py @@ -640,7 +640,7 @@ def test_stride_gpu(self): class TestToStaticCheck(unittest.TestCase): def test_error(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def func(): x_np = np.random.random(size=[2, 3, 4]).astype('float32') x = paddle.to_tensor(x_np) @@ -650,7 +650,7 @@ def func(): self.assertRaises(ValueError, func) def test_no_error(self): - @paddle.jit.to_static + @paddle.jit.to_static(full_graph=True) def func(): x_np = np.random.random(size=[2, 3, 4]).astype('float32') x = paddle.to_tensor(x_np) diff --git a/test/legacy_test/test_sum_op.py b/test/legacy_test/test_sum_op.py index d8536bc771955..ecd60da843b80 100644 --- a/test/legacy_test/test_sum_op.py +++ b/test/legacy_test/test_sum_op.py @@ -28,6 +28,7 @@ from paddle import base, enable_static from paddle.base import core from paddle.base.layer_helper import LayerHelper +from paddle.pir_utils import test_with_pir_api def sum_wrapper(X, use_mkldnn=False): @@ -393,6 +394,7 @@ def test_check_grad(self): class API_Test_Add_n(unittest.TestCase): + @test_with_pir_api def test_api(self): with base.program_guard(base.Program(), base.Program()): input0 = paddle.tensor.fill_constant( diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index b86bb0222ec7e..54aa0388c7541 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -19,6 +19,7 @@ import paddle from paddle.framework import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -43,10 +44,12 @@ def setUp(self): self.outputs = {'Result': self.target} def test_check_output(self): - self.check_output(check_cinn=self.check_cinn) + self.check_output(check_cinn=self.check_cinn, check_pir=True) def test_check_grad(self): - self.check_grad(['Input'], 'Result', check_cinn=self.check_cinn) + self.check_grad( + ['Input'], 'Result', check_cinn=self.check_cinn, check_pir=True + ) def init_data(self): self.x_type = "float64" @@ -101,11 +104,17 @@ def setUp(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place, check_cinn=self.check_cinn) + self.check_output_with_place( + self.place, check_cinn=self.check_cinn, check_pir=True + ) def test_check_grad(self): self.check_grad_with_place( - self.place, ['Input'], 'Result', check_cinn=self.check_cinn + self.place, + ['Input'], + 'Result', + check_cinn=self.check_cinn, + check_pir=True, ) def init_data(self): @@ -142,6 +151,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) + @test_with_pir_api def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_tensor_fill_diagonal_tensor.py b/test/legacy_test/test_tensor_fill_diagonal_tensor.py index 7409cdae1f007..cf3493d603976 100644 --- a/test/legacy_test/test_tensor_fill_diagonal_tensor.py +++ b/test/legacy_test/test_tensor_fill_diagonal_tensor.py @@ -17,7 +17,6 @@ import numpy as np import paddle -import paddle.nn.functional as F from paddle import base @@ -202,9 +201,9 @@ def test_largedim(self): loss.backward() expected_pred = v - 2 - expected_pred = F.diag_embed(expected_pred) + 2 + expected_pred = paddle.diag_embed(expected_pred) + 2 expected_grad = paddle.ones(v.shape, dtype=dtype) - 2 - expected_grad = F.diag_embed(expected_grad) + 1 + expected_grad = paddle.diag_embed(expected_grad) + 1 self.assertEqual((ny == expected_pred).all(), True) self.assertEqual((y.grad == expected_grad).all(), True) diff --git a/test/legacy_test/test_tensor_fill_diagonal_tensor_.py b/test/legacy_test/test_tensor_fill_diagonal_tensor_.py index 482f3e542f6fc..7966470e4e8fb 100644 --- a/test/legacy_test/test_tensor_fill_diagonal_tensor_.py +++ b/test/legacy_test/test_tensor_fill_diagonal_tensor_.py @@ -17,7 +17,6 @@ import numpy as np import paddle -import paddle.nn.functional as F from paddle import base @@ -203,9 +202,9 @@ def test_largedim(self): loss.backward() expected_pred = v - 2 - expected_pred = F.diag_embed(expected_pred) + 2 + expected_pred = paddle.diag_embed(expected_pred) + 2 expected_grad = paddle.ones(v.shape, dtype=dtype) - 2 - expected_grad = F.diag_embed(expected_grad) + 1 + expected_grad = paddle.diag_embed(expected_grad) + 1 self.assertEqual((y == expected_pred).all(), True) self.assertEqual((y.grad == expected_grad).all(), True) diff --git a/test/legacy_test/test_tensordot.py b/test/legacy_test/test_tensordot.py index 16d2015573d10..0e41772abd6cb 100644 --- a/test/legacy_test/test_tensordot.py +++ b/test/legacy_test/test_tensordot.py @@ -342,9 +342,21 @@ def test_error(self): paddle.disable_static() x = paddle.to_tensor(self.x) y = paddle.to_tensor(self.y) - for axes in self.all_axes: - with self.assertRaises(BaseException): - paddle.tensordot(x, y, axes) + + with self.assertRaises(TypeError): + paddle.tensordot(x, y, axes=self.all_axes[0]) + with self.assertRaises(TypeError): + paddle.tensordot(x, y, axes=self.all_axes[1]) + with self.assertRaises(AssertionError): + paddle.tensordot(x, y, axes=self.all_axes[2]) + with self.assertRaises(IndexError): + paddle.tensordot(x, y, axes=self.all_axes[3]) + with self.assertRaises(ValueError): + paddle.tensordot(x, y, axes=self.all_axes[4]) + with self.assertRaises(AssertionError): + paddle.tensordot(x, y, axes=self.all_axes[5]) + with self.assertRaises(AssertionError): + paddle.tensordot(x, y, axes=self.all_axes[6]) class TestTensordotAPIAxesTypeFloat64(TestTensordotAPIAxesType): diff --git a/test/legacy_test/test_top_k_v2_op.py b/test/legacy_test/test_top_k_v2_op.py index 9ff5d03473afc..41d021c9085ad 100644 --- a/test/legacy_test/test_top_k_v2_op.py +++ b/test/legacy_test/test_top_k_v2_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def numpy_topk(x, k=1, axis=-1, largest=True): @@ -63,10 +64,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestTopkOp_ZeroDim(TestTopkOp): @@ -270,11 +271,13 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', check_prim=True) + self.check_grad_with_place( + place, ['X'], 'Out', check_prim=True, check_pir=True + ) class TestTopKAPI(unittest.TestCase): @@ -377,8 +380,8 @@ def run_static(self, place): result1 = paddle.topk(input_tensor, k=2) result2 = paddle.topk(input_tensor, k=2, axis=-1) result3 = paddle.topk(input_tensor, k=k_tensor, axis=1) - self.assertEqual(result3[0].shape, (6, -1, 8)) - self.assertEqual(result3[1].shape, (6, -1, 8)) + self.assertEqual(tuple(result3[0].shape), (6, -1, 8)) + self.assertEqual(tuple(result3[1].shape), (6, -1, 8)) result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False) result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False) result6 = paddle.topk(large_input_tensor, k=1, axis=-1) @@ -461,21 +464,28 @@ def run_static(self, place): sort_paddle[0], numpy_result[0], rtol=1e-05 ) - def test_cases(self): + def test_dygraph_cases(self): places = [core.CPUPlace()] if core.is_compiled_with_cuda(): places.append(core.CUDAPlace(0)) for place in places: self.run_dygraph(place) + + @test_with_pir_api + def test_static_cases(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: self.run_static(place) def test_errors(self): with paddle.base.dygraph.guard(): x = paddle.to_tensor([1, 2, 3]) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.topk(x, k=-1) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.topk(x, k=0) diff --git a/test/legacy_test/test_trace_op.py b/test/legacy_test/test_trace_op.py index 1d53c1180b836..a62c9e7f9aa8a 100644 --- a/test/legacy_test/test_trace_op.py +++ b/test/legacy_test/test_trace_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base, tensor from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestTraceOp(OpTest): @@ -30,10 +31,10 @@ def setUp(self): self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['Input'], 'Out') + self.check_grad(['Input'], 'Out', check_pir=True) def init_config(self): self.case = np.random.randn(20, 6).astype('float64') @@ -108,11 +109,15 @@ def setUp(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place( - self.place, ['Input'], 'Out', numeric_grad_delta=0.02 + self.place, + ['Input'], + 'Out', + numeric_grad_delta=0.02, + check_pir=True, ) def init_config(self): @@ -145,22 +150,24 @@ def init_config(self): class TestTraceAPICase(unittest.TestCase): + @test_with_pir_api def test_case1(self): - case = np.random.randn(2, 20, 2, 3).astype('float32') - data1 = paddle.static.data( - name='data1', shape=[2, 20, 2, 3], dtype='float32' - ) - out1 = tensor.trace(data1) - out2 = tensor.trace(data1, offset=-5, axis1=1, axis2=-1) - - place = core.CPUPlace() - exe = base.Executor(place) - results = exe.run( - base.default_main_program(), - feed={"data1": case}, - fetch_list=[out1, out2], - return_numpy=True, - ) + with paddle.static.program_guard(paddle.static.Program()): + case = np.random.randn(2, 20, 2, 3).astype('float32') + data1 = paddle.static.data( + name='data1', shape=[2, 20, 2, 3], dtype='float32' + ) + out1 = tensor.trace(data1) + out2 = tensor.trace(data1, offset=-5, axis1=1, axis2=-1) + + place = core.CPUPlace() + exe = base.Executor(place) + results = exe.run( + paddle.static.default_main_program(), + feed={"data1": case}, + fetch_list=[out1, out2], + return_numpy=True, + ) target1 = np.trace(case) target2 = np.trace(case, offset=-5, axis1=1, axis2=-1) np.testing.assert_allclose(results[0], target1, rtol=1e-05) diff --git a/test/legacy_test/test_transpose_op.py b/test/legacy_test/test_transpose_op.py index 32f071eafb472..4752c8c26bd33 100644 --- a/test/legacy_test/test_transpose_op.py +++ b/test/legacy_test/test_transpose_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -541,6 +542,7 @@ def test_each_elem_value_check(): class TestTransposeApi(unittest.TestCase): + @test_with_pir_api def test_static_out(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -578,7 +580,8 @@ def test_dygraph_out(self): class TestTAPI(unittest.TestCase): - def test_out(self): + @test_with_pir_api + def test_static_out(self): with base.program_guard(base.Program()): data = paddle.static.data(shape=[10], dtype="float64", name="data") data_t = paddle.t(data) @@ -613,6 +616,7 @@ def test_out(self): expected_result = np.transpose(data_np) self.assertEqual((result == expected_result).all(), True) + def test_dygraph_out(self): with base.dygraph.guard(): np_x = np.random.random([10]).astype("float64") data = base.dygraph.to_variable(np_x) @@ -637,6 +641,7 @@ def test_out(self): z_expected = np.array(np.transpose(np_x)) self.assertEqual((np_z == z_expected).all(), True) + @test_with_pir_api def test_errors(self): with base.program_guard(base.Program()): x = paddle.static.data(name='x', shape=[10, 5, 3], dtype='float64') @@ -648,7 +653,8 @@ def test_x_dimension_check(): class TestMoveAxis(unittest.TestCase): - def test_moveaxis1(self): + @test_with_pir_api + def test_static_moveaxis1(self): x_np = np.random.randn(2, 3, 4, 5, 7) expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0]) paddle.enable_static() @@ -661,6 +667,9 @@ def test_moveaxis1(self): np.testing.assert_array_equal(out_np, expected) + def test_dygraph_moveaxis1(self): + x_np = np.random.randn(2, 3, 4, 5, 7) + expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0]) paddle.disable_static() x = paddle.to_tensor(x_np) out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0]) @@ -668,7 +677,8 @@ def test_moveaxis1(self): np.testing.assert_array_equal(out.numpy(), expected) paddle.enable_static() - def test_moveaxis2(self): + @test_with_pir_api + def test_static_moveaxis2(self): x_np = np.random.randn(2, 3, 5) expected = np.moveaxis(x_np, -2, -1) paddle.enable_static() @@ -681,6 +691,9 @@ def test_moveaxis2(self): np.testing.assert_array_equal(out_np, expected) + def test_dygraph_moveaxis2(self): + x_np = np.random.randn(2, 3, 5) + expected = np.moveaxis(x_np, -2, -1) paddle.disable_static() x = paddle.to_tensor(x_np) out = x.moveaxis(-2, -1) @@ -697,6 +710,7 @@ def test_moveaxis3(self): self.assertEqual(out.shape, [2, 3]) paddle.enable_static() + @test_with_pir_api def test_error(self): x = paddle.randn([2, 3, 4, 5]) # src must have the same number with dst diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py index abdc9d6fe1bd6..f3624b5332817 100644 --- a/test/legacy_test/test_triangular_solve_op.py +++ b/test/legacy_test/test_triangular_solve_op.py @@ -64,10 +64,10 @@ def setUp(self): self.outputs = {'Out': self.output} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_cinn=True) + self.check_grad(['X', 'Y'], 'Out', check_cinn=True, check_pir=True) # 2D(broadcast) + 3D, test 'transpose' diff --git a/test/legacy_test/test_tril_indices_op.py b/test/legacy_test/test_tril_indices_op.py index 9cef30a02519f..19336b11ec994 100644 --- a/test/legacy_test/test_tril_indices_op.py +++ b/test/legacy_test/test_tril_indices_op.py @@ -19,6 +19,7 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api class TestTrilIndicesOp(OpTest): @@ -31,7 +32,7 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_pir=True) def init_config(self): self.attrs = {'rows': 4, 'cols': 4, 'offset': -1} @@ -58,6 +59,7 @@ def init_config(self): class TestTrilIndicesAPICaseStatic(unittest.TestCase): + @test_with_pir_api def test_static(self): places = ( [paddle.CPUPlace(), paddle.base.CUDAPlace(0)] @@ -109,6 +111,7 @@ def test_num_offset_type_check(): class TestTrilIndicesAPICaseDefault(unittest.TestCase): + @test_with_pir_api def test_default_CPU(self): paddle.enable_static() with paddle.static.program_guard( diff --git a/test/legacy_test/test_trilinear_interp_v2_op.py b/test/legacy_test/test_trilinear_interp_v2_op.py index 71d249d49908a..45511da5754b0 100755 --- a/test/legacy_test/test_trilinear_interp_v2_op.py +++ b/test/legacy_test/test_trilinear_interp_v2_op.py @@ -362,10 +362,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -454,10 +454,10 @@ def init_test_case(self): class TestTrilinearInterpOpFP16(TestTrilinearInterpOp): def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -591,10 +591,10 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output_np)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): create_test_case0(self) @@ -724,7 +724,9 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) + self.check_output_with_place( + place=core.CPUPlace(), atol=1, check_pir=True + ) def init_test_case(self): self.interp_method = 'trilinear' @@ -902,10 +904,10 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad(['X'], 'Out', in_place=True, check_pir=True) def init_test_case(self): self.interp_method = 'trilinear' diff --git a/test/legacy_test/test_trunc_op.py b/test/legacy_test/test_trunc_op.py index e67c0d94b78bc..3f157fe879b05 100644 --- a/test/legacy_test/test_trunc_op.py +++ b/test/legacy_test/test_trunc_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -36,10 +37,10 @@ def init_dtype_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) + self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True) class TestFloatTruncOp(TestTruncOp): @@ -66,6 +67,7 @@ def setUp(self): self.x = np.random.random((20, 20)).astype(np.float32) self.place = paddle.CPUPlace() + @test_with_pir_api def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -114,11 +116,13 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=1e-5) + self.check_grad_with_place( + place, ['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True + ) if __name__ == "__main__": diff --git a/test/legacy_test/test_uniform_random_bf16_op.py b/test/legacy_test/test_uniform_random_bf16_op.py index 1c32b30bf6899..9fbc54b961579 100644 --- a/test/legacy_test/test_uniform_random_bf16_op.py +++ b/test/legacy_test/test_uniform_random_bf16_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api from paddle.tensor import random @@ -160,6 +161,7 @@ def check_with_place(self, place): class TestUniformRandomOpAPISeed(unittest.TestCase): + @test_with_pir_api def test_attr_tensor_API(self): _seed = 10 gen = paddle.seed(_seed) diff --git a/test/legacy_test/test_uniform_random_op.py b/test/legacy_test/test_uniform_random_op.py index 1e301f53d7fc2..c8daff881a27a 100644 --- a/test/legacy_test/test_uniform_random_op.py +++ b/test/legacy_test/test_uniform_random_op.py @@ -24,6 +24,7 @@ from paddle import base from paddle.base import Program, core, program_guard from paddle.base.framework import convert_np_dtype_to_dtype_ +from paddle.pir_utils import test_with_pir_api from paddle.tensor import random @@ -223,8 +224,13 @@ def test_Variable2(): self.assertRaises(TypeError, test_Variable2) def test_out_dtype(): - out = paddle.uniform(shape=[3, 4], dtype='float64') - self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64) + out = paddle.tensor.random.uniform( + shape=[3, 4], dtype='float64' + ) + if paddle.framework.in_pir_mode(): + self.assertEqual(out.dtype, base.core.DataType.FLOAT64) + else: + self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64) test_out_dtype() paddle.disable_static() @@ -331,6 +337,7 @@ def test_api(self): class TestUniformRandomOp_attr_tensor_API(unittest.TestCase): + @test_with_pir_api def test_attr_tensor_API(self): paddle.enable_static() startup_program = base.Program() @@ -348,6 +355,7 @@ def test_attr_tensor_API(self): outs = exe.run(train_program, fetch_list=[ret]) paddle.disable_static() + @test_with_pir_api def test_attr_tensorlist_int32_API(self): paddle.enable_static() startup_program = base.Program() @@ -389,6 +397,7 @@ def test_attr_tensor_int32_API(self): class TestUniformRandomOp_API_seed(unittest.TestCase): + @test_with_pir_api def test_attr_tensor_API(self): paddle.enable_static() _seed = 10 @@ -490,6 +499,7 @@ def test_check_output(self): class TestUniformRandomBatchSizeLikeOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() main_prog = Program() @@ -523,6 +533,7 @@ def test_dtype(): class TestUniformAlias(unittest.TestCase): + @test_with_pir_api def test_alias(self): paddle.uniform([2, 3], min=-5.0, max=5.0) paddle.tensor.uniform([2, 3], min=-5.0, max=5.0) @@ -535,6 +546,7 @@ def test_uniform_random(): class TestUniformOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() main_prog = Program() @@ -567,7 +579,10 @@ def test_out_dtype(): out = paddle.tensor.random.uniform( shape=[3, 4], dtype='float64' ) - self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64) + if paddle.framework.in_pir_mode(): + self.assertEqual(out.dtype, base.core.DataType.FLOAT64) + else: + self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64) test_out_dtype() paddle.disable_static() diff --git a/test/legacy_test/test_unique.py b/test/legacy_test/test_unique.py index 8fe9dfa9af635..808cd8227bb7d 100644 --- a/test/legacy_test/test_unique.py +++ b/test/legacy_test/test_unique.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestUniqueOp(OpTest): @@ -413,6 +414,7 @@ def test_dygraph_attr_dtype(self): self.assertTrue((inverse.numpy() == np_inverse).all(), True) self.assertTrue((counts.numpy() == np_counts).all(), True) + @test_with_pir_api def test_static_graph(self): with paddle_static_guard(): with paddle.static.program_guard( diff --git a/test/legacy_test/test_unique_consecutive_op.py b/test/legacy_test/test_unique_consecutive_op.py index 36fd33490d18c..72ef3aa79b4a8 100644 --- a/test/legacy_test/test_unique_consecutive_op.py +++ b/test/legacy_test/test_unique_consecutive_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def reference_unique_consecutive( @@ -203,6 +204,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): with base.program_guard(base.Program(), base.Program()): paddle.enable_static() @@ -217,7 +219,6 @@ def check_static_result(self, place): x_np = np.random.randint(20, size=100).astype("float32") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), feed={"input_x": x_np}, fetch_list=[result], ) @@ -240,6 +241,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): with base.program_guard(base.Program(), base.Program()): paddle.enable_static() @@ -256,7 +258,6 @@ def check_static_result(self, place): x_np = np.random.randint(20, size=100).astype("float32") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), feed={"input_x": x_np}, fetch_list=[result], ) @@ -281,6 +282,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): with base.program_guard(base.Program(), base.Program()): paddle.enable_static() @@ -297,7 +299,6 @@ def check_static_result(self, place): x_np = np.random.randint(20, size=100).astype("float32") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), feed={"input_x": x_np}, fetch_list=[result], ) @@ -347,7 +348,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) if __name__ == "__main__": diff --git a/test/legacy_test/test_unsqueeze2_op.py b/test/legacy_test/test_unsqueeze2_op.py index cb1a6c868671e..10246419fef5b 100755 --- a/test/legacy_test/test_unsqueeze2_op.py +++ b/test/legacy_test/test_unsqueeze2_op.py @@ -18,6 +18,7 @@ from op_test import OpTest import paddle +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -257,34 +258,38 @@ def setUp(self): def executed_api(self): self.unsqueeze = paddle.unsqueeze + @test_with_pir_api def test_api(self): - input = np.random.random([3, 2, 5]).astype("float64") - x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64") - positive_3_int32 = paddle.tensor.fill_constant([1], "int32", 3) - positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) - axes_tensor_int32 = paddle.static.data( - name='axes_tensor_int32', shape=[3], dtype="int32" - ) - axes_tensor_int64 = paddle.static.data( - name='axes_tensor_int64', shape=[3], dtype="int64" - ) + with paddle.static.program_guard(paddle.static.Program()): + input = np.random.random([3, 2, 5]).astype("float64") + x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64") + positive_3_int32 = paddle.tensor.fill_constant([1], "int32", 3) + positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1) + axes_tensor_int32 = paddle.static.data( + name='axes_tensor_int32', shape=[3], dtype="int32" + ) + axes_tensor_int64 = paddle.static.data( + name='axes_tensor_int64', shape=[3], dtype="int64" + ) - out_1 = self.unsqueeze(x, axis=[3, 1, 1]) - out_2 = self.unsqueeze(x, axis=[positive_3_int32, positive_1_int64, 1]) - out_3 = self.unsqueeze(x, axis=axes_tensor_int32) - out_4 = self.unsqueeze(x, axis=3) - out_5 = self.unsqueeze(x, axis=axes_tensor_int64) - - exe = paddle.static.Executor(place=paddle.CPUPlace()) - res_1, res_2, res_3, res_4, res_5 = exe.run( - paddle.static.default_main_program(), - feed={ - "x": input, - "axes_tensor_int32": np.array([3, 1, 1]).astype("int32"), - "axes_tensor_int64": np.array([3, 1, 1]).astype("int64"), - }, - fetch_list=[out_1, out_2, out_3, out_4, out_5], - ) + out_1 = self.unsqueeze(x, axis=[3, 1, 1]) + out_2 = self.unsqueeze( + x, axis=[positive_3_int32, positive_1_int64, 1] + ) + out_3 = self.unsqueeze(x, axis=axes_tensor_int32) + out_4 = self.unsqueeze(x, axis=3) + out_5 = self.unsqueeze(x, axis=axes_tensor_int64) + + exe = paddle.static.Executor(place=paddle.CPUPlace()) + res_1, res_2, res_3, res_4, res_5 = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + "axes_tensor_int32": np.array([3, 1, 1]).astype("int32"), + "axes_tensor_int64": np.array([3, 1, 1]).astype("int64"), + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5], + ) np.testing.assert_array_equal(res_1, input.reshape([3, 1, 1, 2, 5, 1])) np.testing.assert_array_equal(res_2, input.reshape([3, 1, 1, 2, 5, 1])) @@ -299,6 +304,13 @@ def test_axes_type(): self.assertRaises(TypeError, test_axes_type) + def test_pir_axes_type(): + with paddle.pir_utils.IrGuard(): + x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32") + self.unsqueeze(x2, axis=2.1) + + self.assertRaises(ValueError, test_pir_axes_type) + class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI): def executed_api(self): diff --git a/test/legacy_test/test_unstack_op.py b/test/legacy_test/test_unstack_op.py index 4bc980025fe33..1175edceaf7db 100755 --- a/test/legacy_test/test_unstack_op.py +++ b/test/legacy_test/test_unstack_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestUnStackOpBase(OpTest): @@ -60,10 +61,10 @@ def setUp(self): self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], self.get_y_names()) + self.check_grad(['X'], self.get_y_names(), check_pir=True) class TestUnStackFP16Op(TestUnStackOpBase): @@ -164,7 +165,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): with base.dygraph.guard(): @@ -181,6 +182,7 @@ def test_check_grad(self): class TestUnstackZeroInputOp(unittest.TestCase): + @test_with_pir_api def unstack_zero_input_static(self): paddle.enable_static() diff --git a/test/legacy_test/test_var_base.py b/test/legacy_test/test_var_base.py index 748ac4ca608ab..6b388e2e7e4b1 100644 --- a/test/legacy_test/test_var_base.py +++ b/test/legacy_test/test_var_base.py @@ -87,6 +87,10 @@ def check_with_place(place): self.assertEqual(y.place.__repr__(), "Place(gpu:0)") y = x.cuda(blocking=True) self.assertEqual(y.place.__repr__(), "Place(gpu:0)") + y = x.cuda(device_id=0, blocking=True) + self.assertEqual(y.place.__repr__(), "Place(gpu:0)") + y = x.cuda(device_id=0, blocking=False) + self.assertEqual(y.place.__repr__(), "Place(gpu:0)") with self.assertRaises(ValueError): y = x.cuda("test") diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index ba9a5fbc3f0e1..89328610e9272 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -19,8 +19,10 @@ import paddle from paddle import base +from paddle.autograd.ir_backward import grad from paddle.base import Program, core, program_guard from paddle.base.backward import append_backward +from paddle.pir_utils import test_with_pir_api class TestWhereOp(OpTest): @@ -132,7 +134,9 @@ def ref_y_backward(self, dout): def test_api(self, use_cuda=False): for x_stop_gradient in [False, True]: for y_stop_gradient in [False, True]: - with base.program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): cond = paddle.static.data( name='cond', shape=[-1] + self.shape, dtype='bool' ) @@ -165,7 +169,7 @@ def test_api(self, use_cuda=False): if y_stop_gradient is False: fetch_list.append(y.grad_name) out = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={'cond': self.cond, 'x': self.x, 'y': self.y}, fetch_list=fetch_list, ) @@ -183,13 +187,66 @@ def test_api(self, use_cuda=False): out[2], self.ref_y_backward(out[1]) ) + def test_pir_api(self, use_cuda=False): + for x_stop_gradient in [False, True]: + for y_stop_gradient in [False, True]: + with paddle.pir_utils.IrGuard(), paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + cond = paddle.static.data( + name='cond', shape=self.shape, dtype='bool' + ) + x = paddle.static.data( + name='x', shape=self.shape, dtype='float32' + ) + y = paddle.static.data( + name='y', shape=self.shape, dtype='float32' + ) + x.stop_gradient = x_stop_gradient + y.stop_gradient = y_stop_gradient + result = paddle.where(cond, x, y) + result.stop_gradient = False + loss = paddle.mean(result) + [x_grad, y_grad] = grad(loss, (x, y)) + default_main_program = paddle.static.default_main_program() + fetch_list = [result] + if x_stop_gradient is False: + fetch_list.append(x_grad) + if y_stop_gradient is False: + fetch_list.append(y_grad) + for use_cuda in [False, True]: + if use_cuda and (not base.core.is_compiled_with_cuda()): + break + place = ( + base.CUDAPlace(0) if use_cuda else base.CPUPlace() + ) + exe = base.Executor(place) + + out = exe.run( + default_main_program, + feed={'cond': self.cond, 'x': self.x, 'y': self.y}, + fetch_list=fetch_list, + ) + np.testing.assert_array_equal(out[0], self.out) + if x_stop_gradient is False: + np.testing.assert_array_equal( + out[1], self.ref_x_backward(out[1]) + ) + if y.stop_gradient is False: + np.testing.assert_array_equal( + out[2], self.ref_y_backward(out[2]) + ) + elif y.stop_gradient is False: + np.testing.assert_array_equal( + out[1], self.ref_y_backward(out[1]) + ) + + @test_with_pir_api def test_api_broadcast(self, use_cuda=False): - main_program = Program() - with base.program_guard(main_program): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32') - x.desc.set_need_check_feed(False) y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32') - y.desc.set_need_check_feed(False) x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32') y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype( 'float32' @@ -201,7 +258,7 @@ def test_api_broadcast(self, use_cuda=False): place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() exe = base.Executor(place) out = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={'x': x_i, 'y': y_i}, fetch_list=[result], ) @@ -209,15 +266,14 @@ def test_api_broadcast(self, use_cuda=False): out[0], np.where((x_i > 1), x_i, y_i) ) + @test_with_pir_api def test_scalar(self): - paddle.enable_static() - main_program = Program() - with base.program_guard(main_program): - cond_shape = [2, 4] + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + cond_shape = [4] cond = paddle.static.data( - name='cond', shape=[-1] + cond_shape, dtype='bool' + name='cond', shape=cond_shape, dtype='bool' ) - cond.desc.set_need_check_feed(False) x_data = 1.0 y_data = 2.0 cond_data = np.array([False, False, True, True]).astype('bool') @@ -228,7 +284,7 @@ def test_scalar(self): place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() exe = base.Executor(place) out = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={'cond': cond_data}, fetch_list=[result], ) @@ -237,20 +293,13 @@ def test_scalar(self): def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): paddle.enable_static() - main_program = Program() - with base.program_guard(main_program): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): cond = paddle.static.data( - name='cond', shape=[-1] + cond_shape, dtype='bool' - ) - x = paddle.static.data( - name='x', shape=[-1] + x_shape, dtype='float32' + name='cond', shape=cond_shape, dtype='bool' ) - y = paddle.static.data( - name='y', shape=[-1] + y_shape, dtype='float32' - ) - x.desc.set_need_check_feed(False) - y.desc.set_need_check_feed(False) - cond.desc.set_need_check_feed(False) + x = paddle.static.data(name='x', shape=x_shape, dtype='float32') + y = paddle.static.data(name='y', shape=y_shape, dtype='float32') cond_data_tmp = np.random.random(size=cond_shape).astype('float32') cond_data = cond_data_tmp < 0.3 x_data = np.random.random(size=x_shape).astype('float32') @@ -262,55 +311,63 @@ def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() exe = base.Executor(place) out = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={'cond': cond_data, 'x': x_data, 'y': y_data}, fetch_list=[result], ) expect = np.where(cond_data, x_data, y_data) np.testing.assert_array_equal(out[0], expect) + @test_with_pir_api def test_static_api_broadcast_1(self): cond_shape = [2, 4] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_2(self): cond_shape = [2, 1] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_3(self): cond_shape = [2, 2, 1] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_4(self): cond_shape = [2, 1, 4] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_5(self): cond_shape = [3, 2, 2, 4] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_6(self): cond_shape = [2, 2, 4] a_shape = [2, 2, 1] b_shape = [2, 2, 1] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_7(self): cond_shape = [2, 2, 4] a_shape = [2, 1, 4] b_shape = [2, 1, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + @test_with_pir_api def test_static_api_broadcast_8(self): cond_shape = [3, 2, 2, 4] a_shape = [2, 2, 1] @@ -433,7 +490,9 @@ def test_where_condition(self): class TestWhereOpError(unittest.TestCase): def test_errors(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64') y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64') cond_i = np.array([False, False, True, True]).astype('bool') @@ -443,6 +502,12 @@ def test_Variable(): self.assertRaises(TypeError, test_Variable) + def test_OpResult(): + with paddle.pir_utils.IrGuard(): + paddle.where(cond_i, x_i, y_i) + + self.assertRaises(ValueError, test_OpResult) + def test_type(): x = paddle.static.data(name='x', shape=[-1, 4], dtype='bool') x.desc.set_need_check_feed(False) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index c05b62b7f7ac4..231fb0bed32f9 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import numpy as np @@ -23,10 +24,14 @@ from paddle.base.backward import append_backward from paddle.base.framework import Program, program_guard +sys.path.append("../dygraph_to_static") +from dygraph_to_static_utils_new import compare_legacy_with_pir + paddle.enable_static() class TestApiWhileLoop(unittest.TestCase): + @compare_legacy_with_pir def test_var_tuple(self): def cond(i): return paddle.less_than(i, ten) @@ -55,6 +60,7 @@ def body(i): np.asarray(res[0]), np.full(1, 10, np.int64), rtol=1e-05 ) + # @compare_legacy_with_pir def test_var_list(self): def cond(i, mem): return paddle.less_than(i, ten) @@ -91,6 +97,7 @@ def body(i, mem): data = np.add(data, data_one) np.testing.assert_allclose(np.asarray(res[1]), data, rtol=1e-05) + @compare_legacy_with_pir def test_var_dict(self): def cond(i, ten, test_dict, test_list, test_list_dict): return paddle.less_than(i, ten) @@ -175,6 +182,7 @@ def body(i, ten, test_dict, test_list, test_list_dict): class TestApiWhileLoop_Nested(unittest.TestCase): + # @compare_legacy_with_pir def test_nested_net(self): def external_cond(i, j, init, sums): return paddle.less_than(i, loop_len1) @@ -245,6 +253,7 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): + # TODO(zhangbo): Support while grad exe for pir def test_while_loop_backward(self): def cond(i, x): return paddle.less_than(i, eleven) @@ -292,6 +301,7 @@ def body(i, x): np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) + # TODO(zhangbo): Support while grad exe for pir def test_while_loop_backward2(self): def cond(i, x): return i < 3 @@ -337,6 +347,7 @@ def body(i, x): class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): + # TODO(zhangbo): Support while grad exe for pir def test_nested_net_with_backward_and_lodtensor(self): def external_cond(i, j, x, mem_array): return paddle.less_than(i, array_len) @@ -425,6 +436,7 @@ def internal_body(j, x, mem_array): class TestApiWhileLoopWithSwitchCase(unittest.TestCase): + # @compare_legacy_with_pir def test_with_switch_case(self): def cond(i): return paddle.less_than(i, ten) @@ -474,6 +486,7 @@ def fn_add_one(): class TestApiWhileLoop_Error(unittest.TestCase): + @compare_legacy_with_pir def test_error(self): def cond_returns_constant(i): return 1 @@ -642,6 +655,7 @@ def value_error_body_returns_with_mutable_list(): class TestApiWhileLoopSliceInBody(unittest.TestCase): + # @compare_legacy_with_pir def test_var_slice(self): def cond(z, i): return i + 1 <= x_shape[0] diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 3f12fa397a3a8..766c23dbdceb0 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import numpy @@ -23,6 +24,9 @@ from paddle.base.executor import Executor from paddle.incubate.layers.nn import shuffle_batch +sys.path.append("../dygraph_to_static") +from dygraph_to_static_utils_new import compare_legacy_with_pir + paddle.enable_static() @@ -63,7 +67,6 @@ def simple_net(self): i = paddle.increment(x=i) paddle.tensor.array_write(result, i=i, array=mem_array) - paddle.assign(paddle.less_than(x=i, y=array_len), cond) with while_op2.block(): d2 = paddle.tensor.array_read(array=data_array, i=j) @@ -73,10 +76,13 @@ def simple_net(self): j = paddle.increment(x=j) paddle.tensor.array_write(result2, i=j, array=mem_array) paddle.assign(paddle.less_than(x=j, y=array_len2), cond2) + + paddle.assign(paddle.less_than(x=i, y=array_len), cond) sum_result = paddle.tensor.array_read(array=mem_array, i=j) loss = paddle.mean(sum_result) return loss, sum_result + # TODO(zhangbo): Support pir test(support write_to_array and read_from_array, support while_grad). def test_simple_net(self): main_program = base.Program() startup_program = base.Program() @@ -98,13 +104,13 @@ def test_simple_net(self): ) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) + # TODO(zhangbo): Support pir test(support write_to_array and read_from_array) def test_simple_net_forward(self): main_program = base.Program() startup_program = base.Program() with base.program_guard(main_program, startup_program): self.simple_net() binary = base.compiler.CompiledProgram(main_program) - cpu = core.CPUPlace() exe = Executor(cpu) d = [] @@ -115,6 +121,7 @@ def test_simple_net_forward(self): for _ in range(2): exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) + @compare_legacy_with_pir def test_exceptions(self): i = paddle.zeros(shape=[2], dtype='int64') array_len = paddle.tensor.fill_constant( @@ -129,6 +136,7 @@ def test_exceptions(self): class BadInputTest(unittest.TestCase): + @compare_legacy_with_pir def test_error(self): with base.program_guard(base.Program()): @@ -184,6 +192,7 @@ def body_func(i, ten, batch_info, origin_seq): class TestOutputsMustExistsInputs(unittest.TestCase): + @compare_legacy_with_pir def test_outputs_exists_inputs(self): """ We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx. diff --git a/test/mkldnn/test_activation_mkldnn_op.py b/test/mkldnn/test_activation_mkldnn_op.py index e6ef8388f771d..d37cea47450c7 100644 --- a/test/mkldnn/test_activation_mkldnn_op.py +++ b/test/mkldnn/test_activation_mkldnn_op.py @@ -482,6 +482,14 @@ def setUp(self): self.outputs = {'Out': out} self.attrs = {"use_mkldnn": True} + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out', check_pir=True) + class TestMKLDNNRound_ZeroDim(TestActivation_ZeroDim): def setUp(self): @@ -494,6 +502,14 @@ def setUp(self): self.outputs = {'Out': out} self.attrs = {"use_mkldnn": True} + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out', check_pir=True) + class TestMKLDNNSigmoidDim4(TestSigmoid): def setUp(self): diff --git a/test/mkldnn/test_conv2d_transpose_mkldnn_op.py b/test/mkldnn/test_conv2d_transpose_mkldnn_op.py index 55fdbefe16c0a..f5b8a40714d4b 100644 --- a/test/mkldnn/test_conv2d_transpose_mkldnn_op.py +++ b/test/mkldnn/test_conv2d_transpose_mkldnn_op.py @@ -19,6 +19,7 @@ from test_conv2d_transpose_op import TestConv2DTransposeOp from paddle import enable_static +from paddle.base import core def conv2d_bias_naive(out, bias): @@ -39,6 +40,18 @@ def test_check_grad_no_input(self): def test_check_grad_no_filter(self): return + def test_check_output(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + if self.use_cudnn: + place = core.CUDAPlace(0) + self.check_output_with_place( + place, + atol=1e-5, + check_dygraph=(not self.use_mkldnn), + ) + else: + self.check_output(check_dygraph=(not self.use_mkldnn)) + def init_op_type(self): self.data_format = "NCHW" self.op_type = "conv2d_transpose" diff --git a/test/prim/model/bert.py b/test/prim/model/bert.py index f7cf05f7ca243..fe54de520f88f 100644 --- a/test/prim/model/bert.py +++ b/test/prim/model/bert.py @@ -251,7 +251,7 @@ def __init__(self, config: BertConfig, to_static, enable_cinn): if enable_cinn: build_strategy.build_cinn_pass = True self.encoder = paddle.jit.to_static( - self.encoder, None, build_strategy + self.encoder, None, build_strategy, full_graph=True ) self.pooler = BertPooler(config) # self.apply(self.init_weights) diff --git a/test/prim/model/test_prim_simplenet_cinn.py b/test/prim/model/test_prim_simplenet_cinn.py index 6482e849560e0..06b5085ae7729 100644 --- a/test/prim/model/test_prim_simplenet_cinn.py +++ b/test/prim/model/test_prim_simplenet_cinn.py @@ -26,7 +26,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/prim/model/test_resnet_cinn.py b/test/prim/model/test_resnet_cinn.py index ef932603f8a58..7734f9da60909 100644 --- a/test/prim/model/test_resnet_cinn.py +++ b/test/prim/model/test_resnet_cinn.py @@ -185,7 +185,9 @@ def train(to_static, enable_prim, enable_cinn): build_strategy = paddle.static.BuildStrategy() if enable_cinn: build_strategy.build_cinn_pass = True - resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy) + resnet = paddle.jit.to_static( + resnet, build_strategy=build_strategy, full_graph=True + ) optimizer = optimizer_setting(parameter_list=resnet.parameters()) train_losses = run(resnet, data_loader, optimizer, 'train') diff --git a/test/prim/model/test_resnet_prim.py b/test/prim/model/test_resnet_prim.py index de81f2b78b650..e3e2d859fa4b6 100644 --- a/test/prim/model/test_resnet_prim.py +++ b/test/prim/model/test_resnet_prim.py @@ -186,7 +186,9 @@ def train(to_static, enable_prim, enable_cinn): build_strategy = paddle.static.BuildStrategy() if enable_cinn: build_strategy.build_cinn_pass = True - resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy) + resnet = paddle.jit.to_static( + resnet, build_strategy=build_strategy, full_graph=True + ) optimizer = optimizer_setting(parameter_list=resnet.parameters()) train_losses = run(resnet, data_loader, optimizer, 'train') diff --git a/test/prim/model/test_resnet_prim_cinn.py b/test/prim/model/test_resnet_prim_cinn.py index 933da8fcf105c..5ebf0684259cc 100644 --- a/test/prim/model/test_resnet_prim_cinn.py +++ b/test/prim/model/test_resnet_prim_cinn.py @@ -186,7 +186,9 @@ def train(to_static, enable_prim, enable_cinn): build_strategy = paddle.static.BuildStrategy() if enable_cinn: build_strategy.build_cinn_pass = True - resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy) + resnet = paddle.jit.to_static( + resnet, build_strategy=build_strategy, full_graph=True + ) optimizer = optimizer_setting(parameter_list=resnet.parameters()) train_losses = run(resnet, data_loader, optimizer, 'train') diff --git a/test/prim/pir_prim/CMakeLists.txt b/test/prim/pir_prim/CMakeLists.txt index c31e7254ff60c..cb8a1269b808e 100644 --- a/test/prim/pir_prim/CMakeLists.txt +++ b/test/prim/pir_prim/CMakeLists.txt @@ -1,22 +1,22 @@ -set(TEST_PRIM_PURE_NEW_IR_CASES +set(TEST_PRIM_PURE_PIR_CASES test_prim_program test_prim_simpnet test_prim_custom_vjp test_prim_jit - test_pir_prim_flags) + test_pir_prim_flags test_sink_decomp) -foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES}) +foreach(target ${TEST_PRIM_PURE_PIR_CASES}) py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 FLAGS_enable_pir_api=true) endforeach() file( - GLOB TEST_PRIM_TRANS_NEW_IR_CASES + GLOB TEST_PRIM_TRANS_PIR_CASES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") -string(REPLACE ".py" "" TEST_PRIM_TRANS_NEW_IR_CASES - "${TEST_PRIM_TRANS_NEW_IR_CASES}") +string(REPLACE ".py" "" TEST_PRIM_TRANS_PIR_CASES + "${TEST_PRIM_TRANS_PIR_CASES}") -list(REMOVE_ITEM TEST_PRIM_TRANS_NEW_IR_CASES ${TEST_PRIM_PURE_NEW_IR_CASES}) +list(REMOVE_ITEM TEST_PRIM_TRANS_PIR_CASES ${TEST_PRIM_PURE_PIR_CASES}) -foreach(target ${TEST_PRIM_TRANS_NEW_IR_CASES}) +foreach(target ${TEST_PRIM_TRANS_PIR_CASES}) py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 - FLAGS_enable_new_ir_in_executor=true) + FLAGS_enable_pir_in_executor=true) endforeach() diff --git a/test/prim/pir_prim/test_custom_vjp_trait.py b/test/prim/pir_prim/test_custom_vjp_trait.py index 273bd02a2ba76..f5faa9f6e46b5 100644 --- a/test/prim/pir_prim/test_custom_vjp_trait.py +++ b/test/prim/pir_prim/test_custom_vjp_trait.py @@ -21,7 +21,7 @@ paddle.enable_static() -def get_gelu_program_new_ir(): +def get_gelu_program_pir(): main_program, start_program = ( paddle.static.Program(), paddle.static.Program(), @@ -30,11 +30,11 @@ def get_gelu_program_new_ir(): x = paddle.static.data('x', [2, 3, 3], dtype='float32') net = nn.GELU() out = net(x) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program -def get_multiply_program_new_ir(): +def get_multiply_program_pir(): main_program, start_program = ( paddle.static.Program(), paddle.static.Program(), @@ -43,20 +43,20 @@ def get_multiply_program_new_ir(): x = paddle.static.data('x', [2, 3, 3], dtype='float32') y = paddle.static.data('y', [2, 3, 3], dtype='float32') out = paddle.multiply(x, y) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestCustomVjpTrait(unittest.TestCase): def test_gelu_op_custom_vjp_trait(self): - newir_program = get_gelu_program_new_ir() - op = newir_program.global_block().ops[-1] + pir_program = get_gelu_program_pir() + op = pir_program.global_block().ops[-1] self.assertEqual(op.name(), "pd_op.gelu") self.assertEqual(has_custom_vjp(op), True) def test_multiply_op_custom_vjp_trait(self): - newir_program = get_multiply_program_new_ir() - op = newir_program.global_block().ops[-1] + pir_program = get_multiply_program_pir() + op = pir_program.global_block().ops[-1] self.assertEqual(op.name(), "pd_op.multiply") self.assertEqual(has_custom_vjp(op), False) diff --git a/test/prim/pir_prim/test_decomp_op.py b/test/prim/pir_prim/test_decomp_op.py index 3a70ea3389272..949a9acb9e629 100644 --- a/test/prim/pir_prim/test_decomp_op.py +++ b/test/prim/pir_prim/test_decomp_op.py @@ -36,26 +36,24 @@ def get_ir_program(): y_s = paddle.add(x_s, y_s) y_s = paddle.mean(y_s) y_s = paddle.tanh(y_s) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestBuildOp(unittest.TestCase): def test_build_op(self): - newir_program = get_ir_program() - y = newir_program.global_block().ops[-2].results() + pir_program = get_ir_program() + y = pir_program.global_block().ops[-2].results() orig_shape = y[0].shape with paddle.pir_utils.IrGuard(): core._set_prim_forward_enabled(True) - y_new = decompose(newir_program, y) + y_new = decompose(pir_program, y) core._set_prim_forward_enabled(False) new_shape = y_new[0].shape assert ( orig_shape == new_shape ), f"Original shape {orig_shape} is not equal to new shape {new_shape}" - op_name_list = [ - op.name() for op in newir_program.global_block().ops - ] + op_name_list = [op.name() for op in pir_program.global_block().ops] self.assertEqual( op_name_list, [ diff --git a/test/prim/pir_prim/test_decompose_op.py b/test/prim/pir_prim/test_decompose_op.py new file mode 100644 index 0000000000000..791f3fdeed945 --- /dev/null +++ b/test/prim/pir_prim/test_decompose_op.py @@ -0,0 +1,265 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +import paddle +from paddle import pir +from paddle.base import core +from paddle.decomposition import decomp + +paddle.enable_static() + + +def check_param_mappings(param_mappings): + for VarDesc, Values in param_mappings.items(): + if len(Values) < 0 or len(Values) > 1: + raise ValueError("currently only support one-to-one param_mappings") + + +def get_pir_grad_var_to_var_map(param_mappings, old_ir_grad_var_to_var_map): + pir_grad_var_to_var_map = {} + for grad_var, var in old_ir_grad_var_to_var_map.items(): + if grad_var in param_mappings.keys(): + new_grad_var = param_mappings[grad_var][0] + new_var = param_mappings[var][0] + pir_grad_var_to_var_map[new_grad_var] = new_var + return pir_grad_var_to_var_map + + +def get_fwd_op(bwd_op, grad_var_to_var_map): + bwd_op_input_names = bwd_op.get_input_names() + for idx, input_name in enumerate(bwd_op_input_names): + if input_name == "out_grad": + out_grad = bwd_op.operand(idx).source() + out = grad_var_to_var_map[out_grad] + fwd_op = out.get_defining_op() + return fwd_op + + return None + + +def get_pir_program_and_param_map(): + shape = [2, 3] + mp = paddle.static.Program() + with paddle.static.program_guard(mp): + # construct graph + x = paddle.static.data('x', shape, dtype='float32') + x.stop_gradient = False + y = paddle.static.data('y', shape, dtype='float32') + y.stop_gradient = False + z = paddle.static.data('z', shape, dtype='float32') + z.stop_gradient = False + tmp1 = paddle.add(x, y) + tmp2 = paddle.multiply(tmp1, z) + tmp3 = paddle.mean(tmp2, axis=-1, keepdim=True) + tmp4 = paddle.rsqrt(tmp3) + scale = paddle.tensor.fill_constant( + shape=tmp4.shape[1:], + dtype=tmp4.dtype, + value=1.0, + ) + scale.stop_gradient = True + tmp5 = paddle.nn.functional.layer_norm( + tmp4, tmp4.shape[1:], scale, None, 1e-5 + ) + tmp6 = paddle.nn.functional.dropout(tmp5, p=0.5) + out = paddle.add(x, tmp6) + # construct backward graph + gradients = paddle.static.gradients(out, [x, y, z]) + + pir_program, param_mappings = pir.translate_to_pir_with_param_map(mp.desc) + check_param_mappings(param_mappings) + + return pir_program, param_mappings + + +class TestDecomposeOp(unittest.TestCase): + def setUp(self): + np.random.seed(2023) + self.shape_x = [2, 3] + self.x = np.random.random(self.shape_x).astype("float32") + self.shape_y = [2, 3] + self.y = np.random.random(self.shape_y).astype("float32") + self.shape_z = [2, 3] + self.z = np.random.random(self.shape_z).astype("float32") + + def net(self, flag=None): + ( + pir_program, + param_mappings, + ) = get_pir_program_and_param_map() + + pir_ops = pir_program.global_block().ops + global_outputs = [pir_ops[9].result(0)] + global_grads = [ + pir_ops[-1].result(0), + pir_ops[-3].result(1), + pir_ops[-4].result(1), + ] + + with paddle.pir_utils.IrGuard(), paddle.pir.core.program_guard( + pir_program + ): + if flag == "decompose": + core._set_prim_forward_enabled(True) + core._set_prim_backward_enabled(True) + + # get the old_ir_grad_var_to_var map + old_ir_grad_var_to_var_map = { + 'dropout_1.tmp_0@GRAD': 'dropout_1.tmp_0', + 'elementwise_add_2@GRAD': 'elementwise_add_2', + 'elementwise_add_3@GRAD': 'elementwise_add_3', + 'elementwise_mul_1@GRAD': 'elementwise_mul_1', + 'layer_norm_1.tmp_2@GRAD': 'layer_norm_1.tmp_2', + 'rsqrt_1.tmp_0@GRAD': 'rsqrt_1.tmp_0', + 'mean_1.tmp_0@GRAD': 'mean_1.tmp_0', + 'x@GRAD': 'x', + 'x@GRAD@RENAME@block0@0': 'x', + 'x@GRAD@RENAME@block0@1': 'x', + 'y@GRAD': 'y', + 'z@GRAD': 'z', + } + grad_var_to_var_map = get_pir_grad_var_to_var_map( + param_mappings, old_ir_grad_var_to_var_map + ) + # get global outputs and grads info, when decomposing an op that corresponds to global outputs and grads, then update the global outputs and grads + ( + fwd_leaf_ops, + fwd_leaf_ops_output_indexes, + ) = decomp.get_leaf_ops( + pir_program.global_block(), global_outputs + ) # without update during execution + ( + bwd_leaf_ops, + bwd_leaf_ops_output_indexes, + ) = decomp.get_leaf_ops( + pir_program.global_block(), global_grads + ) + + bwd_ops_to_be_decomposed = [ + "pd_op.layer_norm_grad", + "pd_op.dropout_grad", + "pd_op.mean_grad", + "pd_op.add_grad", + "pd_op.multiply_grad", + "pd_op.rsqrt_grad", + ] + for bwd_op in pir_ops: + if ( + flag == "decompose" + and bwd_op.name() in bwd_ops_to_be_decomposed + ): + fwd_op = get_fwd_op(bwd_op, grad_var_to_var_map) + assert fwd_op is not None, "fwd_op is None" + + bwd_leaf_op_index = ( + bwd_leaf_ops.index(bwd_op) + if bwd_op in bwd_leaf_ops + else None + ) + ( + new_grads, + bwd_has_decomposed, + ) = decomp.decompose_bwd_op_directly( + pir_program.global_block(), + fwd_op, + bwd_op, + grad_var_to_var_map, + ) + if bwd_has_decomposed: + if bwd_leaf_op_index is not None: + decomp.replace_graph_outputs( + global_grads, + new_grads, + bwd_leaf_op_index, + bwd_leaf_ops_output_indexes, + ) + + else: + fwd_leaf_op_index = ( + fwd_leaf_ops.index(fwd_op) + if fwd_op in fwd_leaf_ops + else None + ) + fwd_inputs = [x.source() for x in fwd_op.operands()] + ( + new_fwd_outputs, + fwd_has_decomposed, + ) = decomp.decompose_fwd_op( + pir_program.global_block(), + fwd_op, + grad_var_to_var_map, + ) + if fwd_has_decomposed: + if fwd_leaf_op_index is not None: + decomp.replace_graph_outputs( + global_outputs, + new_fwd_outputs, + fwd_leaf_op_index, + fwd_leaf_ops_output_indexes, + ) + + bwd_leaf_op_index = ( + bwd_leaf_ops.index(bwd_op) + if bwd_op in bwd_leaf_ops + else None + ) + new_grads = ( + decomp.decompose_bwd_op_after_fwd_op( + pir_program.global_block(), + fwd_op, + bwd_op, + grad_var_to_var_map, + fwd_inputs, + new_fwd_outputs, + ) + ) + if bwd_leaf_op_index is not None: + decomp.replace_graph_outputs( + global_grads, + new_grads, + bwd_leaf_op_index, + bwd_leaf_ops_output_indexes, + ) + + # execution + exe = paddle.static.Executor() + outs = exe.run( + pir_program, + feed={'x': self.x, 'y': self.y, 'z': self.z}, + fetch_list=[ + global_outputs[0], + global_grads[0], + global_grads[1], + global_grads[2], + ], + ) + core._set_prim_backward_enabled(False) + core._set_prim_forward_enabled(False) + + return outs + + def test_decompose_layer_norm_op(self): + res_ref = self.net() + res = self.net("decompose") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prim/pir_prim/test_pir_prim_flags.py b/test/prim/pir_prim/test_pir_prim_flags.py index 4bee4da74a4d1..0305274011b50 100644 --- a/test/prim/pir_prim/test_pir_prim_flags.py +++ b/test/prim/pir_prim/test_pir_prim_flags.py @@ -102,16 +102,19 @@ def train(self): x = paddle.randn([2, 4]) x.stop_gradient = False net = PrimeNet() - net = paddle.jit.to_static(net) + net.forward = paddle.jit.to_static(full_graph=True)(net.forward) out = net(x) loss = paddle.mean(out) loss.backward() self.check_prim(net) def check_prim(self, net): - block = net.forward.program_cache.last()[-1][ - -1 - ].train_program.global_block() + program = net.forward.program_cache.last()[-1][-1].train_program + if isinstance( + program, paddle.jit.dy2static.pir_partial_program.RunableProgram + ): + program = program.program + block = program.global_block() ops = [op.name() for op in block.ops] self.assertTrue('pd_op.tanh_grad' in ops) self.assertTrue('pd_op.exp_grad' in ops) diff --git a/test/prim/pir_prim/test_prim_jit.py b/test/prim/pir_prim/test_prim_jit.py index 72958eff9a1d7..61a12e05d7de4 100644 --- a/test/prim/pir_prim/test_prim_jit.py +++ b/test/prim/pir_prim/test_prim_jit.py @@ -20,7 +20,7 @@ from paddle.framework import core -class TestDy2staticNewIR(unittest.TestCase): +class TestDy2staticPir(unittest.TestCase): def test_basic_network_backward(self): core._set_prim_all_enabled(True) @@ -30,7 +30,7 @@ def func(x): return out # ==== dygraph computation ==== - static_func = paddle.jit.to_static(func) + static_func = paddle.jit.to_static(func, full_graph=True) x = paddle.randn((8, 16, 64)) x.stop_gradient = False ref_out = func(x) * 2 diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py new file mode 100644 index 0000000000000..e9154eba60976 --- /dev/null +++ b/test/prim/pir_prim/test_sink_decomp.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.autograd.ir_backward import grad +from paddle.base import core +from paddle.decomposition import decompose + +paddle.enable_static() + + +class TestPrimMode(unittest.TestCase): + def setUp(self): + np.random.seed(2023) + self.shape_x = [8, 16, 32, 64] + self.shape_y = [8, 16, 32, 64] + self.x = np.random.random(self.shape_x).astype("float32") + self.y = np.random.random(self.shape_y).astype("float32") + self.prog = None + + def base_net(self, flag=None): + if flag == "forward": + core._set_prim_forward_enabled(True) + elif flag == "backward": + core._set_prim_backward_enabled(True) + elif flag == "all": + core._set_prim_all_enabled(True) + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data('x', self.shape_x, dtype='float32') + y = paddle.static.data('y', self.shape_y, dtype='float32') + x.stop_gradient = False + y.stop_gradient = False + divide_out = paddle.divide(x, y) + sum_out = paddle.mean(divide_out, axis=0) + [new_out] = decompose(main_program, [sum_out]) + gradients = grad(new_out, (x, y)) + + exe = paddle.static.Executor() + [fwd, dx, dy] = exe.run( + feed={'x': self.x, 'y': self.y}, fetch_list=[new_out, gradients] + ) + + whole_ops = [op.name() for op in main_program.global_block().ops] + self.prog = main_program + if flag == "forward": + core._set_prim_forward_enabled(False) + assert ( + 'pd_op.mean' not in whole_ops + and 'pd_op.divide_grad' in whole_ops + ) + elif flag == "backward": + core._set_prim_backward_enabled(False) + assert ( + 'pd_op.mean' in whole_ops + and 'pd_op.divide_grad' not in whole_ops + ) + elif flag == "all": + core._set_prim_all_enabled(False) + assert ( + 'pd_op.mean' not in whole_ops + and 'pd_op.divide_grad' not in whole_ops + ) + else: + assert ( + 'pd_op.mean' in whole_ops and 'pd_op.divide_grad' in whole_ops + ) + return fwd, dx, dy + + def test_prim_forward(self): + res_ref = self.base_net() + res = self.base_net("forward") + for ref, actual in zip(res_ref, res): + np.testing.assert_equal(ref, actual) + + def test_prim_backward(self): + res_ref = self.base_net() + res = self.base_net("backward") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, rtol=1e-6) + + def test_prim_all(self): + res_ref = self.base_net() + res = self.base_net("all") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, rtol=1e-6) + + def test_has_decomp(self): + _ = self.base_net() + for op in self.prog.global_block().ops: + if op.name() == "pd_op.divide": + self.assertEqual(core.has_decomp(op), False) + if op.name() == "pd_op.mean": + self.assertEqual(core.has_decomp(op), True) + + +class TestReluSink(unittest.TestCase): + def setUp(self): + np.random.seed(2023) + self.shape_x = [8, 16, 32, 64] + self.x = np.random.random(self.shape_x).astype("float32") + self.prog = None + + def base_net(self, flag=None): + if flag == "forward": + core._set_prim_forward_enabled(True) + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data('x', self.shape_x, dtype='float32') + x.stop_gradient = False + sum_out = F.relu(x) + [new_out] = decompose(main_program, [sum_out]) + gradients = grad(new_out, x) + + exe = paddle.static.Executor() + [fwd, dx] = exe.run( + feed={'x': self.x}, fetch_list=[new_out, gradients] + ) + + whole_ops = [op.name() for op in main_program.global_block().ops] + self.prog = main_program + if flag == "forward": + core._set_prim_forward_enabled(False) + assert 'pd_op.relu' not in whole_ops + else: + assert 'pd_op.relu' in whole_ops + return fwd, dx + + def test_relu_forward(self): + res_ref = self.base_net() + res = self.base_net("forward") + for ref, actual in zip(res_ref, res): + np.testing.assert_equal(ref, actual) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prim/pir_prim/test_vjp_prim.py b/test/prim/pir_prim/test_vjp_prim.py index 2755f2854487f..3a2075650f590 100644 --- a/test/prim/pir_prim/test_vjp_prim.py +++ b/test/prim/pir_prim/test_vjp_prim.py @@ -39,8 +39,8 @@ def get_ir_divide_program(): ) dout.stop_gradient = False out = paddle.divide(x, y) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program def get_ir_sum_program(): @@ -57,25 +57,31 @@ def get_ir_sum_program(): dout = paddle.tensor.fill_constant(shape=[], dtype='float32', value=1.0) dout.stop_gradient = False out = paddle.sum(x) - newir_program = pir.translate_to_new_ir(main_program.desc) - return newir_program + pir_program = pir.translate_to_pir(main_program.desc) + return pir_program class TestVjpPrim(unittest.TestCase): def test_divide_grad_prim_case1(self): - newir_program = get_ir_divide_program() + pir_program = get_ir_divide_program() paddle.framework.core._set_prim_backward_enabled(True) with paddle.pir_utils.IrGuard(): - dout = newir_program.global_block().ops[-2].result(0) + dout = pir_program.global_block().ops[-2].result(0) out_grads = [[dout]] stop_gradients = [[False], [False]] - divide_op = newir_program.global_block().ops[-1] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(divide_op, out_grads, stop_gradients) - reshape_op2 = newir_program.global_block().ops[-1] - reshape_op1 = newir_program.global_block().ops[-8] + divide_op = pir_program.global_block().ops[-1] + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + divide_op, + [[value] for value in divide_op.operands_source()], + [[value] for value in divide_op.results()], + out_grads, + stop_gradients, + ) + reshape_op2 = pir_program.global_block().ops[-1] + reshape_op1 = pir_program.global_block().ops[-8] self.assertEqual(len(grad_outs), 2) - self.assertEqual(len(newir_program.global_block().ops), 21) + self.assertEqual(len(pir_program.global_block().ops), 21) self.assertEqual(reshape_op2.result(0), grad_outs[0][0]) self.assertEqual(reshape_op1.result(0), grad_outs[1][0]) all_op_names = [ @@ -101,19 +107,25 @@ def test_divide_grad_prim_case1(self): "pd_op.full_int_array", "pd_op.reshape", ] - for idx, op in enumerate(newir_program.global_block().ops): + for idx, op in enumerate(pir_program.global_block().ops): self.assertEqual(op.name(), all_op_names[idx]) paddle.framework.core._set_prim_backward_enabled(False) def test_divide_grad_no_prim(self): - newir_program = get_ir_divide_program() + pir_program = get_ir_divide_program() paddle.framework.core._set_prim_backward_enabled(False) - dout = newir_program.global_block().ops[-2].result(0) + dout = pir_program.global_block().ops[-2].result(0) out_grads = [[dout]] stop_gradients = [[False], [False]] - divide_op = newir_program.global_block().ops[-1] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(divide_op, out_grads, stop_gradients) + divide_op = pir_program.global_block().ops[-1] + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + divide_op, + [[value] for value in divide_op.operands_source()], + [[value] for value in divide_op.results()], + out_grads, + stop_gradients, + ) self.assertEqual(len(grad_outs), 2) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.divide_grad" @@ -121,21 +133,27 @@ def test_divide_grad_no_prim(self): self.assertEqual( grad_outs[1][0].get_defining_op().name(), "pd_op.divide_grad" ) - self.assertEqual(len(newir_program.global_block().ops), 5) + self.assertEqual(len(pir_program.global_block().ops), 5) def test_sum_grad_prim(self): - newir_program = get_ir_sum_program() + pir_program = get_ir_sum_program() paddle.framework.core._set_prim_backward_enabled(True) with paddle.pir_utils.IrGuard(): - dout = newir_program.global_block().ops[-3].result(0) + dout = pir_program.global_block().ops[-3].result(0) out_grads = [[dout]] stop_gradients = [[False]] - sum_op = newir_program.global_block().ops[-1] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(sum_op, out_grads, stop_gradients) - expand_op = newir_program.global_block().ops[-1] + sum_op = pir_program.global_block().ops[-1] + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + sum_op, + [[value] for value in sum_op.operands_source()], + [[value] for value in sum_op.results()], + out_grads, + stop_gradients, + ) + expand_op = pir_program.global_block().ops[-1] self.assertEqual(len(grad_outs), 1) - self.assertEqual(len(newir_program.global_block().ops), 8) + self.assertEqual(len(pir_program.global_block().ops), 8) self.assertEqual(expand_op.result(0), grad_outs[0][0]) all_op_names = [ "pd_op.full", @@ -147,24 +165,30 @@ def test_sum_grad_prim(self): "pd_op.full_int_array", "pd_op.expand", ] - for idx, op in enumerate(newir_program.global_block().ops): + for idx, op in enumerate(pir_program.global_block().ops): self.assertEqual(op.name(), all_op_names[idx]) paddle.framework.core._set_prim_backward_enabled(False) def test_sum_grad_no_prim(self): - newir_program = get_ir_sum_program() + pir_program = get_ir_sum_program() paddle.framework.core._set_prim_backward_enabled(False) - dout = newir_program.global_block().ops[-2].result(0) + dout = pir_program.global_block().ops[-2].result(0) out_grads = [[dout]] stop_gradients = [[False]] - sum_op = newir_program.global_block().ops[-1] - with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(sum_op, out_grads, stop_gradients) + sum_op = pir_program.global_block().ops[-1] + with paddle.pir.core.program_guard(pir_program): + grad_outs = call_vjp( + sum_op, + [[value] for value in sum_op.operands_source()], + [[value] for value in sum_op.results()], + out_grads, + stop_gradients, + ) self.assertEqual(len(grad_outs), 1) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.sum_grad" ) - self.assertEqual(len(newir_program.global_block().ops), 5) + self.assertEqual(len(pir_program.global_block().ops), 5) if __name__ == "__main__": diff --git a/test/prim/prim/flags/test_prim_flags.py b/test/prim/prim/flags/test_prim_flags.py index c1164a5e626e4..9f6c84577697c 100644 --- a/test/prim/prim/flags/test_prim_flags.py +++ b/test/prim/prim/flags/test_prim_flags.py @@ -153,7 +153,7 @@ def train(self): x = paddle.randn([2, 4]) x.stop_gradient = False net = PrimeNet() - net = paddle.jit.to_static(net) + net = paddle.jit.to_static(net, full_graph=True) out = net(x) loss = paddle.mean(out) diff --git a/test/prim/prim/flags/test_prim_flags_case.py b/test/prim/prim/flags/test_prim_flags_case.py index 126c15de81fe2..a565732683821 100644 --- a/test/prim/prim/flags/test_prim_flags_case.py +++ b/test/prim/prim/flags/test_prim_flags_case.py @@ -23,7 +23,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/prim/process/test_check_inputs.py b/test/prim/process/test_check_inputs.py index 631da96cc8b23..b844f52ea81d8 100644 --- a/test/prim/process/test_check_inputs.py +++ b/test/prim/process/test_check_inputs.py @@ -43,7 +43,7 @@ def test_error_input(self): np_data = np.random.random([3, 4]).astype("float32") tensor_data = paddle.to_tensor(np_data) shape = paddle.to_tensor([2, 3, 4]) - net = paddle.jit.to_static(fn) + net = paddle.jit.to_static(fn, full_graph=True) with self.assertRaises(NotImplementedError): _ = net(tensor_data, shape).numpy() core._set_prim_all_enabled(False) diff --git a/test/prim/test_comp_custom_vjp.py b/test/prim/test_comp_custom_vjp.py index fb62fe80202a4..40638bc579cf9 100644 --- a/test/prim/test_comp_custom_vjp.py +++ b/test/prim/test_comp_custom_vjp.py @@ -70,7 +70,7 @@ def test_enable_prim_fwd(self): self.ops_fwd_enable_bwd_disable, tuple( op.type - for op in paddle.jit.to_static(self.f) + for op in paddle.jit.to_static(full_graph=True)(self.f) .get_concrete_program()[1] ._train_program.block(0) .ops @@ -86,7 +86,7 @@ def test_enable_prim_bwd(self): self.ops_fwd_disable_bwd_enable, tuple( op.type - for op in paddle.jit.to_static(self.f) + for op in paddle.jit.to_static(full_graph=True)(self.f) .get_concrete_program()[1] ._train_program.block(0) .ops @@ -101,7 +101,7 @@ def test_enable_prim_all(self): self.ops_all_enable, tuple( op.type - for op in paddle.jit.to_static(self.f) + for op in paddle.jit.to_static(full_graph=True)(self.f) .get_concrete_program()[1] ._train_program.block(0) .ops diff --git a/test/prim/test_comp_dispensable.py b/test/prim/test_comp_dispensable.py index be76ce92ce7f0..9c7d10b645d5e 100644 --- a/test/prim/test_comp_dispensable.py +++ b/test/prim/test_comp_dispensable.py @@ -25,11 +25,10 @@ def tearDown(self): paddle.base.core._set_prim_all_enabled(False) def test_dispensable(self): - @paddle.jit.to_static def f(x): return paddle.split(x, num_or_sections=2) - f = paddle.jit.to_static(f) + f = paddle.jit.to_static(full_graph=True)(f) x = paddle.rand((8,)) x.stop_gradient = False diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py index c98b62e3a600a..5371f05bbb040 100644 --- a/test/rnn/rnn_numpy.py +++ b/test/rnn/rnn_numpy.py @@ -38,12 +38,14 @@ def __init__( self, input_size, hidden_size, + weight=True, bias=True, nonlinearity="RNN_TANH", dtype="float64", ): self.input_size = input_size self.hidden_size = hidden_size + self.weight = weight self.bias = bias if nonlinearity == 'RNN_TANH': self.nonlinearity = np.tanh @@ -52,12 +54,16 @@ def __init__( self.parameters = {} std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = np.random.uniform( - -std, std, (hidden_size, input_size) - ).astype(dtype) - self.weight_hh = np.random.uniform( - -std, std, (hidden_size, hidden_size) - ).astype(dtype) + if weight: + self.weight_ih = np.random.uniform( + -std, std, (hidden_size, input_size) + ).astype(dtype) + self.weight_hh = np.random.uniform( + -std, std, (hidden_size, hidden_size) + ).astype(dtype) + else: + self.weight_ih = np.ones((hidden_size, input_size)).astype(dtype) + self.weight_hh = np.ones((hidden_size, hidden_size)).astype(dtype) self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh if bias: @@ -67,11 +73,11 @@ def __init__( self.bias_hh = np.random.uniform(-std, std, (hidden_size,)).astype( dtype ) - self.parameters['bias_ih'] = self.bias_ih - self.parameters['bias_hh'] = self.bias_hh else: - self.bias_ih = None - self.bias_hh = None + self.bias_ih = np.zeros(hidden_size).astype(dtype) + self.bias_hh = np.zeros(hidden_size).astype(dtype) + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh def init_state(self, inputs, batch_dim_index=0): batch_size = inputs.shape[batch_dim_index] @@ -92,18 +98,29 @@ def forward(self, inputs, hx=None): class GRUCell(LayerMixin): - def __init__(self, input_size, hidden_size, bias=True, dtype="float64"): + def __init__( + self, input_size, hidden_size, weight=True, bias=True, dtype="float64" + ): self.input_size = input_size self.hidden_size = hidden_size + self.weight = weight self.bias = bias self.parameters = {} std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = np.random.uniform( - -std, std, (3 * hidden_size, input_size) - ).astype(dtype) - self.weight_hh = np.random.uniform( - -std, std, (3 * hidden_size, hidden_size) - ).astype(dtype) + if weight: + self.weight_ih = np.random.uniform( + -std, std, (3 * hidden_size, input_size) + ).astype(dtype) + self.weight_hh = np.random.uniform( + -std, std, (3 * hidden_size, hidden_size) + ).astype(dtype) + else: + self.weight_ih = np.ones((3 * hidden_size, input_size)).astype( + dtype + ) + self.weight_hh = np.ones((3 * hidden_size, hidden_size)).astype( + dtype + ) self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh if bias: @@ -113,11 +130,11 @@ def __init__(self, input_size, hidden_size, bias=True, dtype="float64"): self.bias_hh = np.random.uniform( -std, std, (3 * hidden_size) ).astype(dtype) - self.parameters['bias_ih'] = self.bias_ih - self.parameters['bias_hh'] = self.bias_hh else: - self.bias_ih = None - self.bias_hh = None + self.bias_ih = np.zeros(3 * hidden_size).astype(dtype) + self.bias_hh = np.zeros(3 * hidden_size).astype(dtype) + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh def init_state(self, inputs, batch_dim_index=0): batch_size = inputs.shape[batch_dim_index] @@ -148,21 +165,31 @@ def __init__( self, input_size, hidden_size, + weight=True, bias=True, dtype="float64", proj_size=None, ): self.input_size = input_size self.hidden_size = hidden_size + self.weight = weight self.bias = bias self.parameters = {} std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = np.random.uniform( - -std, std, (4 * hidden_size, input_size) - ).astype(dtype) - self.weight_hh = np.random.uniform( - -std, std, (4 * hidden_size, proj_size or hidden_size) - ).astype(dtype) + if weight: + self.weight_ih = np.random.uniform( + -std, std, (4 * hidden_size, input_size) + ).astype(dtype) + self.weight_hh = np.random.uniform( + -std, std, (4 * hidden_size, proj_size or hidden_size) + ).astype(dtype) + else: + self.weight_ih = np.ones((4 * hidden_size, input_size)).astype( + dtype + ) + self.weight_hh = np.ones((4 * hidden_size, hidden_size)).astype( + dtype + ) self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh self.proj_size = proj_size @@ -178,11 +205,11 @@ def __init__( self.bias_hh = np.random.uniform( -std, std, (4 * hidden_size) ).astype(dtype) - self.parameters['bias_ih'] = self.bias_ih - self.parameters['bias_hh'] = self.bias_hh else: - self.bias_ih = None - self.bias_hh = None + self.bias_ih = np.zeros(4 * hidden_size).astype(dtype) + self.bias_hh = np.zeros(4 * hidden_size).astype(dtype) + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh def init_state(self, inputs, batch_dim_index=0): batch_size = inputs.shape[batch_dim_index] diff --git a/test/rnn/test_rnn_cells.py b/test/rnn/test_rnn_cells.py index 4bb6f49963f84..4b055fcf45d73 100644 --- a/test/rnn/test_rnn_cells.py +++ b/test/rnn/test_rnn_cells.py @@ -24,8 +24,9 @@ class TestSimpleRNNCell(unittest.TestCase): - def __init__(self, bias=True, place="cpu"): + def __init__(self, weight=True, bias=True, place="cpu"): super().__init__(methodName="runTest") + self.weight = weight self.bias = bias self.place = ( paddle.CPUPlace() if place == "cpu" else paddle.CUDAPlace(0) @@ -33,9 +34,14 @@ def __init__(self, bias=True, place="cpu"): def setUp(self): paddle.disable_static(self.place) - rnn1 = SimpleRNNCell(16, 32, bias=self.bias) + rnn1 = SimpleRNNCell(16, 32, weight=self.weight, bias=self.bias) rnn2 = paddle.nn.SimpleRNNCell( - 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias + 16, + 32, + weight_ih_attr=self.weight, + weight_hh_attr=self.weight, + bias_ih_attr=self.bias, + bias_hh_attr=self.bias, ) convert_params_for_cell(rnn1, rnn2) @@ -76,8 +82,9 @@ def runTest(self): class TestGRUCell(unittest.TestCase): - def __init__(self, bias=True, place="cpu"): + def __init__(self, weight=True, bias=True, place="cpu"): super().__init__(methodName="runTest") + self.weight = weight self.bias = bias self.place = ( paddle.CPUPlace() if place == "cpu" else paddle.CUDAPlace(0) @@ -85,9 +92,14 @@ def __init__(self, bias=True, place="cpu"): def setUp(self): paddle.disable_static(self.place) - rnn1 = GRUCell(16, 32, bias=self.bias) + rnn1 = GRUCell(16, 32, weight=self.weight, bias=self.bias) rnn2 = paddle.nn.GRUCell( - 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias + 16, + 32, + weight_ih_attr=self.weight, + weight_hh_attr=self.weight, + bias_ih_attr=self.bias, + bias_hh_attr=self.bias, ) convert_params_for_cell(rnn1, rnn2) @@ -128,17 +140,23 @@ def runTest(self): class TestLSTMCell(unittest.TestCase): - def __init__(self, bias=True, place="cpu"): + def __init__(self, weight=True, bias=True, place="cpu"): super().__init__(methodName="runTest") + self.weight = weight self.bias = bias self.place = ( paddle.CPUPlace() if place == "cpu" else paddle.CUDAPlace(0) ) def setUp(self): - rnn1 = LSTMCell(16, 32, bias=self.bias) + rnn1 = LSTMCell(16, 32, weight=self.weight, bias=self.bias) rnn2 = paddle.nn.LSTMCell( - 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias + 16, + 32, + weight_ih_attr=self.weight, + weight_hh_attr=self.weight, + bias_ih_attr=self.bias, + bias_hh_attr=self.bias, ) convert_params_for_cell(rnn1, rnn2) @@ -187,8 +205,13 @@ def runTest(self): def load_tests(loader, tests, pattern): suite = unittest.TestSuite() devices = ["cpu", "gpu"] if paddle.base.is_compiled_with_cuda() else ["cpu"] - for bias in [True, False]: - for device in devices: - for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]: - suite.addTest(test_class(bias, device)) + for weight in [True, False]: + for bias in [True, False]: + for device in devices: + for test_class in [ + TestSimpleRNNCell, + TestGRUCell, + TestLSTMCell, + ]: + suite.addTest(test_class(weight, bias, device)) return suite diff --git a/test/sot/CMakeLists.txt b/test/sot/CMakeLists.txt new file mode 100644 index 0000000000000..90047650507cf --- /dev/null +++ b/test/sot/CMakeLists.txt @@ -0,0 +1,15 @@ +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +set(SOT_ENVS SOT_LOG_LEVEL=0 COST_MODEL=False MIN_GRAPH_SIZE=0 STRICT_MODE=True + FLAGS_cudnn_deterministic=True) + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${SOT_ENVS}) +endforeach() + +if(WIN32 AND NOT WITH_GPU) + set_tests_properties(test_sot_resnet50_backward PROPERTIES TIMEOUT 420) +endif() diff --git a/test/sot/test_01_basic.py b/test/sot/test_01_basic.py index 8a03ea9fd3ae5..4a76cc2a2bdb5 100644 --- a/test/sot/test_01_basic.py +++ b/test/sot/test_01_basic.py @@ -14,9 +14,10 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def foo(x: int, y: paddle.Tensor): @@ -34,7 +35,7 @@ def numpy_add(x, y): class TestNumpyAdd(TestCaseBase): - @strict_mode_guard(0) + @strict_mode_guard(False) def test_numpy_add(self): x = paddle.to_tensor([2]) y = paddle.to_tensor([3]) diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py index 63e3fedace4bf..3d3b59043504e 100644 --- a/test/sot/test_12_for_loop.py +++ b/test/sot/test_12_for_loop.py @@ -19,7 +19,7 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle from paddle.jit import sot @@ -27,6 +27,7 @@ from paddle.jit.sot.opcode_translator.executor.executor_cache import ( OpcodeExecutorCache, ) +from paddle.jit.sot.utils import strict_mode_guard def gener(): @@ -185,6 +186,7 @@ def test_for_for_fallback(self): paddle_output = for_iter(a, gener()) self.assert_nest_match(sym_output, paddle_output) + @strict_mode_guard(False) def test_for_break(self): a = paddle.to_tensor(1) sym_output = symbolic_translate(for_break)(a, gener()) @@ -294,5 +296,4 @@ def test_undefined_var_case_1(self): if __name__ == "__main__": - with strict_mode_guard(0): - unittest.main() + unittest.main() diff --git a/test/sot/test_19_closure.py b/test/sot/test_19_closure.py index 6191141e07f39..ddfd36e2a6096 100644 --- a/test/sot/test_19_closure.py +++ b/test/sot/test_19_closure.py @@ -15,9 +15,10 @@ import inspect import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def foo(x: int, y: paddle.Tensor): @@ -180,7 +181,7 @@ def test_closure(self): self.assert_results(foo5, paddle.to_tensor(2)) self.assert_results(foo6, paddle.to_tensor(2)) self.assert_results(numpy_sum, paddle.to_tensor(1)) - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results( lambda_closure, paddle.to_tensor(2), paddle.to_tensor(1) ) diff --git a/test/sot/test_break_graph.py b/test/sot/test_break_graph.py index 532f1c7a4c497..cc1aca51caec3 100644 --- a/test/sot/test_break_graph.py +++ b/test/sot/test_break_graph.py @@ -153,5 +153,16 @@ def test_simple(self): self.assert_results(test_break_graph_repeat, x) +def break_graph_resume_pass_null(x, y): + return paddle.add(x, y[0:50] if y is not None else None) + + +class TestBreakGraphResumePassNull(TestCaseBase): + def test_break_graph_resume_pass_null(self): + x = paddle.rand([50, 50], dtype=paddle.float32) + y = paddle.rand([100, 50], dtype=paddle.float32) + self.assert_results(break_graph_resume_pass_null, x, y) + + if __name__ == "__main__": unittest.main() diff --git a/test/sot/test_map.py b/test/sot/test_builtin_map.py similarity index 96% rename from test/sot/test_map.py rename to test/sot/test_builtin_map.py index 812ab36673be4..f005ec10cdbe4 100644 --- a/test/sot/test_map.py +++ b/test/sot/test_builtin_map.py @@ -17,10 +17,11 @@ import unittest from typing import Iterable -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase from paddle.jit import sot from paddle.jit.sot.psdb import check_no_breakgraph +from paddle.jit.sot.utils import strict_mode_guard def double_num(num: float | int): @@ -110,7 +111,7 @@ def test_map_comprehension(self): ) def test_map_with_breakgraph(self): - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4]) def test_map_unpack(self): diff --git a/test/sot/test_range.py b/test/sot/test_builtin_range.py similarity index 100% rename from test/sot/test_range.py rename to test/sot/test_builtin_range.py diff --git a/test/sot/test_case_base.py b/test/sot/test_case_base.py index 03ce3c98227e8..f5a57f66c186b 100644 --- a/test/sot/test_case_base.py +++ b/test/sot/test_case_base.py @@ -136,23 +136,3 @@ def copy_fn(fn): sym_copied_fn.__globals__[key], paddle_fn.__globals__[key] ) self.assert_nest_match(sym_output, paddle_output) - - -@contextlib.contextmanager -def strict_mode_guard(value): - if "STRICT_MODE" not in os.environ: - os.environ["STRICT_MODE"] = "0" - old_value = os.environ["STRICT_MODE"] - os.environ["STRICT_MODE"] = str(value) - yield - os.environ["STRICT_MODE"] = old_value - - -@contextlib.contextmanager -def cost_model_guard(value): - if "COST_MODEL" not in os.environ: - os.environ["COST_MODEL"] = "True" - old_value = os.environ["COST_MODEL"] - os.environ["COST_MODEL"] = str(value) - yield - os.environ["COST_MODEL"] = old_value diff --git a/test/sot/test_code_status.py b/test/sot/test_code_status.py deleted file mode 100644 index 9fec5712c2293..0000000000000 --- a/test/sot/test_code_status.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from test_case_base import TestCaseBase, strict_mode_guard - -import paddle -from paddle.jit import sot -from paddle.jit.sot.opcode_translator.skip_files import skip_function -from paddle.jit.sot.utils.code_status import CodeState, CodeStatus - - -class SimpleNet1(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] - ) - - def forward(self, x): - for i in range(len(self.layers)): - sot.psdb.breakgraph() - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - return x - - -class SimpleNet2(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] - ) - - def forward(self, x): - sot.psdb.fallback() - for i in range(len(self.layers)): - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) - return x - - -def run_net(net, x): - for i in range(20): - x = net(x) - return x - - -class TestCodeInfo(TestCaseBase): - def test_case_1(self): - CodeStatus().clear() - net = SimpleNet1() - inp = paddle.rand((10, 10)) - self.assert_results(run_net, net, inp) - code_map = CodeStatus().code_map - states = [] - for k, v in code_map.items(): - if k.co_name.startswith("#") or k.co_name.startswith("$"): - states.append(v) - elif k in CodeStatus().WITH_GRAPH_API: - assert v.state == CodeState.WITH_GRAPH - else: - assert v.state == CodeState.WITHOUT_GRAPH - # run_net, forward, loop body, resumed part2 in loop body - assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 - # resumed part1 in loop body - assert ( - len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 - ) - - def test_case_2(self): - with strict_mode_guard(0): - CodeStatus().clear() - net = SimpleNet2() - inp = paddle.rand((10, 10)) - self.assert_results(run_net, net, inp) - code_map = CodeStatus().code_map - states = [] - for k, v in code_map.items(): - if k.co_name.startswith("#") or k.co_name.startswith("$"): - states.append(v) - elif k in CodeStatus().WITH_GRAPH_API: - assert v.state == CodeState.WITH_GRAPH - else: - assert v.state == CodeState.WITHOUT_GRAPH - # no graph found because fallback (paddle api will not enter simulate) - assert ( - len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 - ) - - -def no_skip_func_0(x): - return x + 1 - - -def skipped_func_0(): - pass - - -def skipped_func_1(x): - return x + 1 - - -def skipped_func_2(x): - return no_skip_func_0(x) - - -def call_skipped_func_0(x): - for i in range(15): - skipped_func_0() - x = skipped_func_1(x) - x = skipped_func_2(x) - return x - - -skip_function(skipped_func_0) -skip_function(skipped_func_1) -skip_function(skipped_func_2) -skip_function(call_skipped_func_0) - - -class TestDisableSkippedFrame(TestCaseBase): - def test_case_0(self): - CodeStatus().clear() - x = paddle.to_tensor([1]) - self.assert_results(call_skipped_func_0, x) - code_map = CodeStatus().code_map - assert ( - code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH - ) - assert ( - code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH - ) - assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH - - -if __name__ == "__main__": - unittest.main() diff --git a/test/sot/test_enumerate.py b/test/sot/test_enumerate.py index f81a451da55c9..236eece7560d2 100644 --- a/test/sot/test_enumerate.py +++ b/test/sot/test_enumerate.py @@ -14,9 +14,10 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def test_enumerate_1(x: int, y: int): @@ -100,13 +101,13 @@ def test_cases(self): self.assert_results(test_enumerate_4, ty) # TODO(zmh): support range for tensor - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results(test_enumerate_5, paddle.to_tensor([1, 2, 3])) self.assert_results(test_enumerate_6, paddle.to_tensor([1, 2, 3])) self.assert_results(test_enumerate_7, ty) # TODO(zmh): support -1 - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results(test_enumerate_8, ty) self.assert_results(test_enumerate_10, layer_list, paddle.randn((10,))) diff --git a/test/sot/test_envs.py b/test/sot/test_envs.py new file mode 100644 index 0000000000000..24f8b102679d2 --- /dev/null +++ b/test/sot/test_envs.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import unittest + +from paddle.utils.environments import ( + BooleanEnvironmentVariable, + EnvironmentVariableGuard, + IntegerEnvironmentVariable, + StringEnvironmentVariable, +) + + +class TestBooleanEnvironmentVariable(unittest.TestCase): + def test_bool_env_get(self): + env_name = "___TEST_ENV_BOOL_GET" + env_bool = BooleanEnvironmentVariable(env_name, False) + + self.assertIs(env_bool.get(), False) + + os.environ[env_name] = "False" + self.assertIs(env_bool.get(), False) + + os.environ[env_name] = "OFF" + self.assertIs(env_bool.get(), False) + + os.environ[env_name] = "0" + self.assertIs(env_bool.get(), False) + + os.environ[env_name] = "True" + self.assertIs(env_bool.get(), True) + + os.environ[env_name] = "ON" + self.assertIs(env_bool.get(), True) + + os.environ[env_name] = "1" + self.assertIs(env_bool.get(), True) + + def test_bool_env_set(self): + env_name = "___TEST_ENV_BOOL_SET" + env_bool = BooleanEnvironmentVariable(env_name, False) + + env_bool.set(True) + self.assertIs(env_bool.get(), True) + + env_bool.set(False) + self.assertIs(env_bool.get(), False) + + with self.assertRaises(AssertionError): + env_bool.set("True") + + with self.assertRaises(AssertionError): + env_bool.set("False") + + with self.assertRaises(AssertionError): + env_bool.set(0) + + with self.assertRaises(AssertionError): + env_bool.set(1) + + def test_bool_env_guard(self): + env_name = "___TEST_ENV_BOOL_GUARD" + env_bool = BooleanEnvironmentVariable(env_name, False) + + with EnvironmentVariableGuard(env_bool, True): + self.assertIs(env_bool.get(), True) + + with EnvironmentVariableGuard(env_bool, False): + self.assertIs(env_bool.get(), False) + + +class TestStringEnvironmentVariable(unittest.TestCase): + def test_str_env_get(self): + env_name = "___TEST_ENV_STR_GET" + env_str = StringEnvironmentVariable(env_name, "DEFAULT") + + self.assertEqual(env_str.get(), "DEFAULT") + + os.environ[env_name] = "CASE1" + self.assertEqual(env_str.get(), "CASE1") + + os.environ[env_name] = "CASE2" + self.assertEqual(env_str.get(), "CASE2") + + def test_str_env_set(self): + env_name = "___TEST_ENV_STR_SET" + env_str = StringEnvironmentVariable(env_name, "DEFAULT") + + self.assertEqual(env_str.get(), "DEFAULT") + + env_str.set("CASE1") + self.assertEqual(env_str.get(), "CASE1") + + env_str.set("CASE2") + self.assertEqual(env_str.get(), "CASE2") + + with self.assertRaises(AssertionError): + env_str.set(True) + + with self.assertRaises(AssertionError): + env_str.set(False) + + with self.assertRaises(AssertionError): + env_str.set(0) + + with self.assertRaises(AssertionError): + env_str.set(1) + + def test_str_env_guard(self): + env_name = "___TEST_ENV_STR_GUARD" + env_str = StringEnvironmentVariable(env_name, "DEFAULT") + + with EnvironmentVariableGuard(env_str, "CASE1"): + self.assertEqual(env_str.get(), "CASE1") + + with EnvironmentVariableGuard(env_str, "CASE2"): + self.assertEqual(env_str.get(), "CASE2") + + +class TestIntegerEnvironmentVariable(unittest.TestCase): + def test_int_env_get(self): + env_name = "___TEST_ENV_INT_GET" + env_int = IntegerEnvironmentVariable(env_name, 42) + + self.assertEqual(env_int.get(), 42) + + os.environ[env_name] = "10" + self.assertEqual(env_int.get(), 10) + + os.environ[env_name] = "99999" + self.assertEqual(env_int.get(), 99999) + + def test_int_env_set(self): + env_name = "___TEST_ENV_INT_SET" + env_int = IntegerEnvironmentVariable(env_name, 42) + + self.assertEqual(env_int.get(), 42) + + env_int.set(99) + self.assertEqual(env_int.get(), 99) + + env_int.set(1000) + self.assertEqual(env_int.get(), 1000) + + with self.assertRaises(AssertionError): + env_int.set(True) + + with self.assertRaises(AssertionError): + env_int.set(False) + + with self.assertRaises(AssertionError): + env_int.set("10") + + with self.assertRaises(AssertionError): + env_int.set("42") + + def test_int_env_guard(self): + env_name = "___TEST_ENV_INT_GUARD" + env_int = IntegerEnvironmentVariable(env_name, 42) + + with EnvironmentVariableGuard(env_int, 99): + self.assertEqual(env_int.get(), 99) + + with EnvironmentVariableGuard(env_int, 1000): + self.assertEqual(env_int.get(), 1000) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_error_handling.py b/test/sot/test_error_handling.py index c74436f0d44f4..4e5000cd0c50d 100644 --- a/test/sot/test_error_handling.py +++ b/test/sot/test_error_handling.py @@ -14,9 +14,10 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase from paddle.jit import sot +from paddle.jit.sot.utils import strict_mode_guard def fn_with_try_except(): @@ -30,7 +31,7 @@ def fn_with_try_except(): class TestErrorHandling(TestCaseBase): - @strict_mode_guard(0) + @strict_mode_guard(False) def test_fn_with_try_except(self): self.assert_results(fn_with_try_except) diff --git a/test/sot/test_min_graph_size.py b/test/sot/test_min_graph_size.py new file mode 100644 index 0000000000000..04a90f326d855 --- /dev/null +++ b/test/sot/test_min_graph_size.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GET_ITER (new) +# FOR_ITER (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit import sot +from paddle.jit.sot.utils import min_graph_size_guard + + +def case_for(x, vars): + x = x + 1 + sot.psdb.breakgraph() + for y in vars: + x += y + return x + + +def case_if(x): + x = x + 1 + if x > 5: + x += 3 + else: + x += 4 + return x + + +def case_call(x): + y = paddle.to_tensor(x.numpy()) + x += y + return x + + +def case_all(x, vars): + x = x + 1 + for y in vars: + z = paddle.to_tensor(x.numpy()) + x += z + x += y + if x > 5: + x += y + else: + x += 3 + return x + + +class TestMinGraphSize(TestCaseBase): + @min_graph_size_guard(10) + def test_cases(self): + x = paddle.to_tensor(1) + self.assert_results(case_for, x, [1, 2, 3]) + self.assert_results(case_if, x) + self.assert_results(case_call, x) + self.assert_results(case_all, x, [4, 5, 6]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_numpy.py b/test/sot/test_numpy.py index 3600d4df7cc45..eb47e86b03b20 100644 --- a/test/sot/test_numpy.py +++ b/test/sot/test_numpy.py @@ -15,9 +15,10 @@ import unittest import numpy as np -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def foo(x, y): @@ -32,7 +33,7 @@ def test_tensor_add_numpy_number(self): self.assert_results(foo, x, y) self.assert_results(foo, y, x) - @strict_mode_guard(0) + @strict_mode_guard(False) def test_tensor_add_numpy_array(self): x = paddle.to_tensor([1.0]) y = np.array(2.0) diff --git a/test/sot/test_numpy_var_if.py b/test/sot/test_numpy_var_if.py index 9d7c4a7048e25..6e098df70d3be 100644 --- a/test/sot/test_numpy_var_if.py +++ b/test/sot/test_numpy_var_if.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy as np @@ -20,8 +19,9 @@ import paddle from paddle.jit.sot.psdb import check_no_breakgraph, check_no_fallback +from paddle.jit.sot.utils import ENV_MIN_GRAPH_SIZE -os.environ['MIN_GRAPH_SIZE'] = '-1' +ENV_MIN_GRAPH_SIZE.set(-1) @check_no_breakgraph diff --git a/test/sot/test_segment_linear.py b/test/sot/test_segment_linear.py index ee3b7d70f8d36..9bd1b8b447137 100644 --- a/test/sot/test_segment_linear.py +++ b/test/sot/test_segment_linear.py @@ -19,6 +19,7 @@ import paddle from paddle import nn from paddle.jit import sot +from paddle.jit.sot.utils import strict_mode_guard class Head(nn.Layer): @@ -48,6 +49,7 @@ def getshape(self, x): return x def forward(self, x): + sot.psdb.fallback() shape = self.getshape(x) feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10]) logits = self.head(feat, shape[2:]) @@ -55,8 +57,8 @@ def forward(self, x): class TestExecutor(TestCaseBase): + @strict_mode_guard(False) def test_simple(self): - sot.skip_function(SimpleNet.forward) x = paddle.randn((1, 8, 8)) net = SimpleNet() net = paddle.jit.to_static( diff --git a/test/sot/test_side_effects.py b/test/sot/test_side_effects.py index 46bed6e8d3c4e..96ec9a7c5f6a7 100644 --- a/test/sot/test_side_effects.py +++ b/test/sot/test_side_effects.py @@ -16,12 +16,12 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle from paddle.jit import sot from paddle.jit.sot import symbolic_translate -from paddle.jit.sot.utils import InnerError +from paddle.jit.sot.utils import InnerError, strict_mode_guard def dict_setitem(x): @@ -275,7 +275,7 @@ def test_list_reverse(self): def test_slice_in_for_loop(self): x = 2 - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results_with_side_effects(slice_in_for_loop, x) def test_list_nested(self): diff --git a/test/sot/test_simulate_initialize.py b/test/sot/test_simulate_initialize.py index 495e06ac1dbda..08a30dfc5a696 100644 --- a/test/sot/test_simulate_initialize.py +++ b/test/sot/test_simulate_initialize.py @@ -31,6 +31,18 @@ def foo(x, y): return out +def foo2(x, y): + t = nn.Softmax() + out1 = t(paddle.to_tensor([x, y], dtype="float32")) + out2 = t(paddle.to_tensor([x, y], dtype="float32")) + return out1 + out2 + + +def error_foo(x): + t = nn.Linear(10, 10) + return t(x) + + def bar(x): a = A(x) t = paddle.to_tensor(x) @@ -40,12 +52,20 @@ def bar(x): class TestInit(TestCaseBase): def test_init_paddle_layer(self): self.assert_results(foo, 1, 2) + self.assert_results(foo2, 1, 2) def test_init_python_object(self): sot_output = symbolic_translate(bar)([1.0, 2.0]) dyn_output = bar([1.0, 2.0]) self.assert_nest_match(sot_output, dyn_output) + def test_error(self): + def run(): + inputs = paddle.randn((10, 10)) + symbolic_translate(error_foo)(inputs) + + self.assertRaises(paddle.jit.sot.utils.exceptions.InnerError, run) + if __name__ == "__main__": unittest.main() diff --git a/test/sot/test_cost_model.py b/test/sot/test_sot_cost_model.py similarity index 92% rename from test/sot/test_cost_model.py rename to test/sot/test_sot_cost_model.py index 07899a03efbfd..a3acec5942005 100644 --- a/test/sot/test_cost_model.py +++ b/test/sot/test_sot_cost_model.py @@ -15,11 +15,11 @@ import time import unittest -from test_case_base import TestCaseBase, cost_model_guard +from test_case_base import TestCaseBase import paddle from paddle.jit.sot import psdb, symbolic_translate -from paddle.jit.sot.utils import StepInfoManager, StepState +from paddle.jit.sot.utils import StepInfoManager, StepState, cost_model_guard def dyn_fast(x, net, iter_): @@ -58,7 +58,7 @@ def forward(self, x): class TestCostModel(TestCaseBase): - @cost_model_guard("True") + @cost_model_guard(True) def test_dyn_fast(self): x = paddle.rand([10]) net = paddle.nn.Linear(10, 10) @@ -69,7 +69,7 @@ def test_dyn_fast(self): state = StepInfoManager().step_record[dyn_fast.__code__].state assert state == StepState.RUN_DYN - @cost_model_guard("True") + @cost_model_guard(True) def test_sot_fast_with_multi_graph(self): x = paddle.rand([10]) net = paddle.nn.Linear(10, 10) @@ -84,7 +84,7 @@ def test_sot_fast_with_multi_graph(self): ) assert state == StepState.RUN_SOT - @cost_model_guard("True") + @cost_model_guard(True) def test_sot_fast_with_single_graph(self): x = paddle.rand([10]) net = paddle.nn.Linear(10, 10) @@ -98,7 +98,7 @@ def test_sot_fast_with_single_graph(self): ) assert state == StepState.RUN_SOT - @cost_model_guard("True") + @cost_model_guard(True) def test_net(self): x = paddle.rand([10]) net = Net() diff --git a/test/sot/test_exception.py b/test/sot/test_sot_exception.py similarity index 100% rename from test/sot/test_exception.py rename to test/sot/test_sot_exception.py diff --git a/test/sot/test_resnet.py b/test/sot/test_sot_resnet.py similarity index 100% rename from test/sot/test_resnet.py rename to test/sot/test_sot_resnet.py diff --git a/test/sot/test_resnet50_backward.py b/test/sot/test_sot_resnet50_backward.py similarity index 98% rename from test/sot/test_resnet50_backward.py rename to test/sot/test_sot_resnet50_backward.py index bd5aac0025e80..d1199dd421baf 100644 --- a/test/sot/test_resnet50_backward.py +++ b/test/sot/test_sot_resnet50_backward.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -os.environ["FLAGS_cudnn_deterministic"] = "True" - import random import unittest diff --git a/test/sot/test_tensor_dtype_in_guard.py b/test/sot/test_tensor_dtype_in_guard.py index d5d001b7038d0..47740b72862b1 100644 --- a/test/sot/test_tensor_dtype_in_guard.py +++ b/test/sot/test_tensor_dtype_in_guard.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest from test_case_base import ( @@ -21,6 +22,7 @@ import paddle from paddle.jit import sot +from paddle.jit.sot.utils import strict_mode_guard def foo(x, y): @@ -31,8 +33,8 @@ def foo(x, y): return out -@sot.skip_function def dtype_in_guard(x, y): + sot.psdb.fallback() with paddle.amp.auto_cast(level='O2'): for i in range(10): z = foo(x, y) @@ -47,8 +49,8 @@ def bar(x, y): return y - 1 -@sot.skip_function def dtype_as_input(x, y): + sot.psdb.fallback() with paddle.amp.auto_cast(level='O2'): for i in range(10): z = bar(x, y) @@ -57,19 +59,29 @@ def dtype_as_input(x, y): class TestDtypeInGuard(TestCaseBase): + @strict_mode_guard(False) def test_dtype_in_guard(self): with test_instruction_translator_cache_context() as ctx: x = paddle.to_tensor([2], dtype="float32") y = paddle.to_tensor([3], dtype="float32") self.assert_results(dtype_in_guard, x, y) - self.assertEqual(ctx.translate_count, 1) + if sys.version_info >= (3, 11): + # skipped with co_exceptiontable flag + self.assertEqual(ctx.translate_count, 1) + else: + self.assertEqual(ctx.translate_count, 2) + @strict_mode_guard(False) def test_input_dtype_in_guard(self): with test_instruction_translator_cache_context() as ctx: x = paddle.float32 y = paddle.to_tensor([3], dtype="float32") self.assert_results(dtype_as_input, x, y) - self.assertEqual(ctx.translate_count, 1) + if sys.version_info >= (3, 11): + # skipped with co_exceptiontable flag + self.assertEqual(ctx.translate_count, 1) + else: + self.assertEqual(ctx.translate_count, 2) if __name__ == "__main__": diff --git a/test/standalone_executor/test_standalone_cuda_graph_multi_stream.py b/test/standalone_executor/test_standalone_cuda_graph_multi_stream.py index b5c49313d87d4..48d73ad332cd1 100644 --- a/test/standalone_executor/test_standalone_cuda_graph_multi_stream.py +++ b/test/standalone_executor/test_standalone_cuda_graph_multi_stream.py @@ -28,6 +28,10 @@ def can_use_cuda_graph(): return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or float(paddle.version.cuda()) < 11.0, + "only support cuda >= 11.0", +) class TestCustomStream(unittest.TestCase): def setUp(self): self.steps = 10 diff --git a/test/standalone_executor/test_standalone_custom_event.py b/test/standalone_executor/test_standalone_custom_event.py index b87609841e6e4..18595eefe5a42 100644 --- a/test/standalone_executor/test_standalone_custom_event.py +++ b/test/standalone_executor/test_standalone_custom_event.py @@ -139,19 +139,14 @@ def create_standalone_exe(self, main_progs, startup_progs, fetch_list): # create jobs for program_id in range(prog_num): job = core.Job(f"prog_{program_id}") - # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch - if program_id == prog_num - 1: - for i in range(fetch_op_num): - job.set_col_attr_for_fetch_op( - fetch_op_indics[i], - i * micro_batch_num + micro_batch_id, - ) job_list.append(job) - type_to_program = {} + job_types = [] for program_id in range(prog_num): - type_to_program[f"prog_{program_id}"] = main_progs[program_id] - set_skip_gc_vars(micro_batch_num, type_to_program, job_list) + job_types.append(f"prog_{program_id}") + type_to_program = set_skip_gc_vars( + micro_batch_num, job_types, main_progs, job_list + ) for type in type_to_program.keys(): type_to_program[type] = type_to_program[type].desc diff --git a/test/standalone_executor/test_standalone_dist_attr_run_time_set_get.py b/test/standalone_executor/test_standalone_dist_attr_run_time_set_get.py new file mode 100644 index 0000000000000..3b1f74a6e47e7 --- /dev/null +++ b/test/standalone_executor/test_standalone_dist_attr_run_time_set_get.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import paddle +from paddle.static import Program, program_guard + +paddle.enable_static() + + +class TestOpProfiling(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def _build_startup_program_and_train_program(self): + startup_program = Program() + train_program = Program() + with program_guard(train_program, startup_program): + data = paddle.static.data( + name='X', shape=[1024, 1], dtype='float32' + ) + hidden = paddle.static.nn.fc(data, 10) + loss = paddle.mean(hidden) + paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) + return startup_program, train_program, loss + + def test_run_time_us_set_get_method(self): + ''' + * test if the newly added "run_time_us_" actually works (set then get) + ''' + ( + startup_program, + train_program, + loss, + ) = self._build_startup_program_and_train_program() + global_block = startup_program.global_block() + global_block.ops[0].dist_attr.run_time_us = 1.0 # set + dt = global_block.ops[0].dist_attr.run_time_us # get + if dt != 1.0: + raise RuntimeError("dist_attr set/get method failed!") + else: + sys.stdout.write("OK.") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/standalone_executor/test_standalone_executor_1f1b_plan.py b/test/standalone_executor/test_standalone_executor_1f1b_plan.py index 76ae03d842089..6911523f96424 100644 --- a/test/standalone_executor/test_standalone_executor_1f1b_plan.py +++ b/test/standalone_executor/test_standalone_executor_1f1b_plan.py @@ -37,7 +37,6 @@ def test_standalone_executor_1f1b_plan_stage0(self): job_type_list.append(job.type()) micro_batch_id_list.append(job.micro_batch_id()) expect_job_type_list = [ - "lr", "forward", "forward", "forward", @@ -57,7 +56,6 @@ def test_standalone_executor_1f1b_plan_stage0(self): "optimizer", ] expect_micro_batch_id_list = [ - 0, 0, 1, 2, @@ -97,7 +95,6 @@ def test_standalone_executor_1f1b_plan_stage1(self): job_type_list.append(job.type()) micro_batch_id_list.append(job.micro_batch_id()) expect_job_type_list = [ - "lr", "forward", "forward", "forward", @@ -117,7 +114,6 @@ def test_standalone_executor_1f1b_plan_stage1(self): "optimizer", ] expect_micro_batch_id_list = [ - 0, 0, 1, 2, @@ -157,7 +153,6 @@ def test_standalone_executor_1f1b_plan_stage2(self): job_type_list.append(job.type()) micro_batch_id_list.append(job.micro_batch_id()) expect_job_type_list = [ - "lr", "forward", "forward", "backward", @@ -177,7 +172,6 @@ def test_standalone_executor_1f1b_plan_stage2(self): "optimizer", ] expect_micro_batch_id_list = [ - 0, 0, 1, 0, @@ -217,7 +211,6 @@ def test_standalone_executor_1f1b_plan_stage3(self): job_type_list.append(job.type()) micro_batch_id_list.append(job.micro_batch_id()) expect_job_type_list = [ - "lr", "forward", "backward", "forward", @@ -237,7 +230,6 @@ def test_standalone_executor_1f1b_plan_stage3(self): "optimizer", ] expect_micro_batch_id_list = [ - 0, 0, 0, 1, diff --git a/test/standalone_executor/test_standalone_executor_fthenb_plan.py b/test/standalone_executor/test_standalone_executor_fthenb_plan.py index 76557231b83e4..10a7670a8f89f 100644 --- a/test/standalone_executor/test_standalone_executor_fthenb_plan.py +++ b/test/standalone_executor/test_standalone_executor_fthenb_plan.py @@ -36,7 +36,6 @@ def test_standalone_executor_fthenb_plan(self): for job in plan.job_list(): job_type_list.append(job.type()) expect_job_type_list = [ - "lr", "forward", "forward", "forward", diff --git a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py index b829a69fa7f1b..6222431d5cce9 100644 --- a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py +++ b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py @@ -185,20 +185,14 @@ def run_train(self, split=False, micro_batch_num=1): for program_id in range(program_num): job = Job(f"P{program_id}") job.set_micro_batch_id(micro_batch_id) - # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch - if program_id == program_num - 1: - fetch_op_id_to_col_attr = {} - for i in range(fetch_op_num): - job.set_col_attr_for_fetch_op( - fetch_op_indics[i], - i * micro_batch_num + micro_batch_id, - ) job_list.append(job) - type_to_program = {} + job_types = [] for program_id in range(program_num): - type_to_program[f"P{program_id}"] = programs[program_id] - set_skip_gc_vars(micro_batch_num, type_to_program, job_list) + job_types.append(f"P{program_id}") + type_to_program = set_skip_gc_vars( + micro_batch_num, job_types, programs, job_list + ) for type in type_to_program.keys(): type_to_program[type] = type_to_program[type].desc diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list new file mode 100644 index 0000000000000..8363980af0347 --- /dev/null +++ b/test/white_list/pir_op_test_no_check_list @@ -0,0 +1,3 @@ +test_exponential_op +test_randint_op +test_seed_op diff --git a/test/white_list/new_ir_op_test_precision_white_list b/test/white_list/pir_op_test_precision_white_list similarity index 100% rename from test/white_list/new_ir_op_test_precision_white_list rename to test/white_list/pir_op_test_precision_white_list diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/pir_op_test_white_list similarity index 97% rename from test/white_list/new_ir_op_test_white_list rename to test/white_list/pir_op_test_white_list index dea0398f9d5fa..d3ff77d26da66 100644 --- a/test/white_list/new_ir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -56,19 +56,19 @@ test_cumprod_op test_deformable_conv_op test_determinant_op test_diag_embed -test_diagonal_op test_diag_v2 +test_diagonal_op test_digamma_op test_dist_op test_dot_op test_dpsgd_op test_edit_distance_op -test_eigh_op -test_eigh_op_static_build test_eig_op test_eig_op_static_build -test_eigvalsh_op +test_eigh_op +test_eigh_op_static_build test_eigvals_op +test_eigvalsh_op test_einsum_op test_elementwise_div_op test_elementwise_floordiv_op @@ -79,6 +79,7 @@ test_elementwise_mul_op test_elementwise_pow_op test_erfinv_op test_expand_v2_op +test_exponential_op test_eye_op test_fill_any_op test_fill_constant_batch_size_like @@ -101,17 +102,18 @@ test_group_norm_op test_histogram_op test_hsigmoid_op test_huber_loss_op -test_i0e_op test_i0_op -test_i1e_op +test_i0e_op test_i1_op +test_i1e_op +test_imperative_lod_tensor_to_selected_rows test_index_add_op test_index_sample_op test_instance_norm_op_v2 test_inverse_op test_ir_pybind -test_isclose_op test_is_empty_op +test_isclose_op test_kldiv_loss_op test_kron_op test_kthvalue_op @@ -120,10 +122,10 @@ test_lerp_op test_lgamma_op test_linear_interp_v2_op test_linspace -test_logcumsumexp_op -test_logit_op test_log_loss_op test_log_softmax +test_logcumsumexp_op +test_logit_op test_logspace test_logsumexp test_lookup_table_v2_op @@ -136,7 +138,9 @@ test_matrix_nms_op test_matrix_power_op test_maxout_op test_mean_op +test_memcpy_op test_mode_op +test_mul_op test_multi_dot_op test_multiplex_op test_mv_op @@ -161,9 +165,11 @@ test_prelu_op test_prior_box_op test_psroi_pool_op test_put_along_axis_op +test_randint_op test_range test_reduce_op test_reduce_op_static_build +test_repeat_interleave_op test_reshape_op test_reverse_op test_roi_align_op @@ -171,6 +177,7 @@ test_roi_pool_op test_rrelu_op test_scale_op test_searchsorted_op +test_seed_op test_segment_ops test_segment_ops_static_build test_selu_op @@ -182,6 +189,7 @@ test_sign_op test_size_op test_slice_op test_solve_op +test_sparse_momentum_op test_spectral_norm_op test_spectral_op test_squared_l2_norm_op @@ -209,4 +217,3 @@ test_warprnnt_op test_where_op test_yolo_box_op test_yolov3_loss_op -test_imperative_lod_tensor_to_selected_rows diff --git a/test/xpu/test_cast_op_xpu.py b/test/xpu/test_cast_op_xpu.py index 0a8043d523f5f..521875fbdbaee 100644 --- a/test/xpu/test_cast_op_xpu.py +++ b/test/xpu/test_cast_op_xpu.py @@ -20,6 +20,7 @@ create_test_class, get_xpu_op_support_types, ) +from op_test import convert_float_to_uint16, convert_uint16_to_float from op_test_xpu import XPUOpTest import paddle @@ -31,6 +32,7 @@ 'int64': int(core.VarDesc.VarType.INT64), 'float32': int(core.VarDesc.VarType.FP32), 'float16': int(core.VarDesc.VarType.FP16), + 'bfloat16': int(core.VarDesc.VarType.BF16), 'bool': int(core.VarDesc.VarType.BOOL), 'int8': int(core.VarDesc.VarType.INT8), 'uint8': int(core.VarDesc.VarType.UINT8), @@ -48,6 +50,7 @@ def dynamic_create_class(self): classes = [] for out_type in { 'float16', + 'bfloat16', 'float32', 'int32', 'int64', @@ -71,8 +74,18 @@ def setUp(self): else self.out_typename ) - self.inputs = {'X': ipt.astype(in_typename)} - self.outputs = {'Out': ipt.astype(in_typename).astype(out_typename)} + if in_typename == "bfloat16": + ipt_x = convert_float_to_uint16(ipt) + else: + ipt_x = ipt.astype(in_typename) + + if out_typename == "bfloat16": + opt = convert_uint16_to_float(convert_float_to_uint16(ipt_x)) + else: + opt = ipt_x.astype(out_typename) + + self.inputs = {'X': ipt_x} + self.outputs = {'Out': opt} self.attrs = { 'in_dtype': typeid_dict[in_typename], 'out_dtype': typeid_dict[out_typename], diff --git a/test/xpu/test_fill_any_op_xpu.py b/test/xpu/test_fill_any_op_xpu.py index 2d71f78e05c34..22e493be70b07 100644 --- a/test/xpu/test_fill_any_op_xpu.py +++ b/test/xpu/test_fill_any_op_xpu.py @@ -111,6 +111,23 @@ def test_backward(self): ) +class TestFillAnyLikeOpSpecialValue(unittest.TestCase): + def setUp(self): + self.special_values = [float("nan"), float("+inf"), float("-inf")] + self.dtypes = ["float32", "float16"] + + def test_dygraph_api(self): + paddle.disable_static() + paddle.set_device("xpu") + for dtype in self.dtypes: + for value in self.special_values: + ref = paddle.empty([4, 4], dtype=dtype) + val_pd = paddle.full_like(ref, value, dtype=dtype) + val_np = np.full([4, 4], value, dtype=dtype) + np.testing.assert_equal(val_pd.numpy(), val_np) + paddle.enable_static() + + support_types = get_xpu_op_support_types('fill_any') for stype in support_types: create_test_class(globals(), XPUTestFillAnyOp, stype) diff --git a/test/xpu/test_gaussian_random_op_xpu.py b/test/xpu/test_gaussian_random_op_xpu.py index abdec498f0a62..7e80bd00ac586 100644 --- a/test/xpu/test_gaussian_random_op_xpu.py +++ b/test/xpu/test_gaussian_random_op_xpu.py @@ -26,8 +26,23 @@ from paddle import base paddle.enable_static() +from paddle.base import core from paddle.tensor import random +typeid_dict = { + 'int32': int(core.VarDesc.VarType.INT32), + 'int64': int(core.VarDesc.VarType.INT64), + 'float32': int(core.VarDesc.VarType.FP32), + 'float16': int(core.VarDesc.VarType.FP16), + 'bfloat16': int(core.VarDesc.VarType.BF16), + 'bool': int(core.VarDesc.VarType.BOOL), + 'int8': int(core.VarDesc.VarType.INT8), + 'uint8': int(core.VarDesc.VarType.UINT8), + 'float64': int(core.VarDesc.VarType.FP64), +} + +from op_test import convert_uint16_to_float + class XPUTestGaussianRandomOp(XPUOpTestWrapper): def __init__(self): @@ -52,6 +67,7 @@ def setUp(self): "std": self.std, "seed": 10, "use_mkldnn": self.use_mkldnn, + "dtype": typeid_dict[self.in_type_str], } paddle.seed(10) @@ -67,6 +83,10 @@ def test_check_output(self): ) def verify_output(self, outs): + # special for bf16 + if self.in_type_str == "bfloat16": + outs = convert_uint16_to_float(outs) + self.assertEqual(outs[0].shape, (123, 92)) hist, _ = np.histogram(outs[0], range=(-3, 5)) hist = hist.astype("float32") @@ -100,6 +120,7 @@ def setUp(self): 'std': self.std, 'seed': self.seed, 'use_mkldnn': self.use_mkldnn, + "dtype": typeid_dict[self.in_type_str], } self.inputs = {"ShapeTensorList": shape_tensor_list} @@ -165,6 +186,7 @@ def setUp(self): 'std': self.std, 'seed': self.seed, 'use_mkldnn': self.use_mkldnn, + "dtype": typeid_dict[self.in_type_str], } self.outputs = {'Out': np.zeros((123, 92), dtype=self.dtype)} @@ -265,6 +287,11 @@ def test_default_fp16(): out = paddle.tensor.random.gaussian([2, 3]) self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16) + def test_default_bf16(): + paddle.framework.set_default_dtype('bfloat16') + out = paddle.tensor.random.gaussian([2, 3]) + self.assertEqual(out.dtype, base.core.VarDesc.VarType.BF16) + def test_default_fp32(): paddle.framework.set_default_dtype('float32') out = paddle.tensor.random.gaussian([2, 3]) @@ -278,6 +305,7 @@ def test_default_fp64(): test_default_fp64() test_default_fp32() test_default_fp16() + test_default_bf16() paddle.enable_static() @@ -291,6 +319,11 @@ def test_default_fp16(): out = paddle.tensor.random.standard_normal([2, 3]) self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16) + def test_default_bf16(): + paddle.framework.set_default_dtype('bfloat16') + out = paddle.tensor.random.standard_normal([2, 3]) + self.assertEqual(out.dtype, base.core.VarDesc.VarType.BF16) + def test_default_fp32(): paddle.framework.set_default_dtype('float32') out = paddle.tensor.random.standard_normal([2, 3]) @@ -304,6 +337,7 @@ def test_default_fp64(): test_default_fp64() test_default_fp32() test_default_fp16() + test_default_bf16() paddle.enable_static() diff --git a/test/xpu/test_scatter_nd_add_op_xpu.py b/test/xpu/test_scatter_nd_add_op_xpu.py index f303cd9ce5150..6efb4fec3b0f7 100644 --- a/test/xpu/test_scatter_nd_add_op_xpu.py +++ b/test/xpu/test_scatter_nd_add_op_xpu.py @@ -91,6 +91,9 @@ def setUp(self): def test_check_output(self): self.check_output_with_place(self.place) + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X', 'Updates'], 'Out') + def init_data(self): self.x_np = np.random.random([100]).astype(self.dtype) self.index_np = np.random.randint(0, 100, [100, 1]).astype("int32") @@ -103,8 +106,10 @@ def infer_dtype_from_inputs_outputs(self, inputs, outputs): class TestScatterNdAddWithEmptyIndex(TestScatterNdAdd): def init_data(self): self.x_np = np.random.random((10, 10)).astype(self.dtype) - self.index_np = np.array([[], []]).astype("int32") - self.updates_np = np.random.random((2, 10, 10)).astype(self.dtype) + self.index_np = np.array([[[], []], [[], []]]).astype("int32") + self.updates_np = np.random.random((2, 2, 10, 10)).astype( + self.dtype + ) class TestScatterNdAddOpWithHighRankSame(TestScatterNdAdd): def init_data(self): @@ -138,6 +143,13 @@ def init_data(self): update_shape = judge_update_shape(self.x_np, self.index_np) self.updates_np = np.random.rand(*update_shape).astype(self.dtype) + class TestScatterNdAddWithZeroDimUpdates(TestScatterNdAdd): + def init_data(self): + shape = (10,) + self.x_np = np.random.rand(*shape).astype(self.dtype) + self.index_np = np.random.randint(0, 10, [1]).astype("int32") + self.updates_np = np.array(np.random.rand()).astype(self.dtype) + support_types = get_xpu_op_support_types('scatter_nd_add') for stype in support_types: diff --git a/test/xpu/test_scatter_op_xpu.py b/test/xpu/test_scatter_op_xpu.py index c8b627fce82e5..7ff92985b34b2 100644 --- a/test/xpu/test_scatter_op_xpu.py +++ b/test/xpu/test_scatter_op_xpu.py @@ -145,7 +145,7 @@ def test_check_output(self): self.check_output_with_place(self.place) def test_check_grad(self): - self.check_grad_with_place(self.place, ['X'], 'Out') + self.check_grad_with_place(self.place, ['X', 'Updates'], 'Out') support_types = get_xpu_op_support_types('scatter') diff --git a/test/xpu/test_uniform_random_op_xpu.py b/test/xpu/test_uniform_random_op_xpu.py index 24972d64b0eb6..a82f305b047a4 100644 --- a/test/xpu/test_uniform_random_op_xpu.py +++ b/test/xpu/test_uniform_random_op_xpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,32 +16,97 @@ import unittest import numpy as np -from test_uniform_random_op import ( - TestUniformRandomOp, - TestUniformRandomOpSelectedRows, +from get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, ) +from op_test_xpu import XPUOpTest import paddle paddle.enable_static() +from paddle.base import core +typeid_dict = { + 'int32': int(core.VarDesc.VarType.INT32), + 'int64': int(core.VarDesc.VarType.INT64), + 'float32': int(core.VarDesc.VarType.FP32), + 'float16': int(core.VarDesc.VarType.FP16), + 'bfloat16': int(core.VarDesc.VarType.BF16), + 'bool': int(core.VarDesc.VarType.BOOL), + 'int8': int(core.VarDesc.VarType.INT8), + 'uint8': int(core.VarDesc.VarType.UINT8), + 'float64': int(core.VarDesc.VarType.FP64), +} -class TestXPUUniformRandomOp(TestUniformRandomOp): - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - outs = self.calc_output(place) - outs = [np.array(out) for out in outs] - outs.sort(key=len) - self.verify_output(outs) +def output_hist(out): + if out.dtype == np.uint16: + out = convert_uint16_to_float(out) + hist, _ = np.histogram(out, range=(-5, 10)) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.1 * np.ones(10) + return hist, prob -class TestXPUUniformRandomOpSelectedRows(TestUniformRandomOpSelectedRows): - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_with_place(place) +from op_test import convert_uint16_to_float + + +class XPUTestUniformRandomOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'uniform_random' + self.use_dynamic_create_class = False + + class TestUniformRandomOp(XPUOpTest): + def init(self): + self.dtype = self.in_type + self.place = paddle.XPUPlace(0) + self.op_type = "uniform_random" + self.python_api = paddle.uniform + + def setUp(self): + self.init() + self.inputs = {} + self.use_mkldnn = False + self.set_attrs() + paddle.seed(10) + + self.outputs = {"Out": np.zeros((1000, 784), dtype=self.dtype)} + + def set_attrs(self): + self.attrs = { + "shape": [1000, 784], + "min": -5.0, + "max": 10.0, + "dtype": typeid_dict[self.in_type_str], + } + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_with_place_customized( + self.verify_output, self.place + ) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + np.testing.assert_allclose(hist, prob, rtol=0, atol=0.01) + + class TestMaxMinAreInt(TestUniformRandomOp): + def set_attrs(self): + self.attrs = { + "shape": [1000, 784], + "min": -5, + "max": 10, + "dtype": typeid_dict[self.in_type_str], + } + self.output_hist = output_hist + + +support_types = get_xpu_op_support_types('uniform_random') +for stype in support_types: + create_test_class(globals(), XPUTestUniformRandomOp, stype) if __name__ == "__main__": unittest.main() diff --git a/test/xpu/test_where_op_xpu.py b/test/xpu/test_where_op_xpu.py index 5a740f8dee5e9..70819a0e7687a 100644 --- a/test/xpu/test_where_op_xpu.py +++ b/test/xpu/test_where_op_xpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ create_test_class, get_xpu_op_support_types, ) +from op_test import convert_float_to_uint16 from op_test_xpu import XPUOpTest import paddle @@ -37,14 +38,23 @@ class TestXPUWhereOp(XPUOpTest): def setUp(self): self.init_config() self.init_data() + self.convert_data_if_bf16() self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y} self.outputs = {'Out': np.where(self.cond, self.x, self.y)} def init_data(self): - self.x = np.random.uniform(-3, 5, (100)).astype(self.dtype) - self.y = np.random.uniform(-3, 5, (100)).astype(self.dtype) + self.x = np.random.uniform(-3, 5, (100)) + self.y = np.random.uniform(-3, 5, (100)) self.cond = np.zeros(100).astype("bool") + def convert_data_if_bf16(self): + if self.dtype == np.uint16: + self.x = convert_float_to_uint16(self.x) + self.y = convert_float_to_uint16(self.y) + else: + self.x = self.x.astype(self.dtype) + self.y = self.y.astype(self.dtype) + def init_config(self): self.op_type = "where" self.dtype = self.in_type @@ -59,14 +69,14 @@ def test_check_grad(self): class TestXPUWhereOp2(TestXPUWhereOp): def init_data(self): - self.x = np.random.uniform(-5, 5, (60, 2)).astype(self.dtype) - self.y = np.random.uniform(-5, 5, (60, 2)).astype(self.dtype) + self.x = np.random.uniform(-5, 5, (60, 2)) + self.y = np.random.uniform(-5, 5, (60, 2)) self.cond = np.ones((60, 2)).astype("bool") class TestXPUWhereOp3(TestXPUWhereOp): def init_data(self): - self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype(self.dtype) - self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype(self.dtype) + self.x = np.random.uniform(-3, 5, (20, 2, 4)) + self.y = np.random.uniform(-3, 5, (20, 2, 4)) self.cond = np.array( np.random.randint(2, size=(20, 2, 4)), dtype=bool ) diff --git a/third_party/mkldnn b/third_party/mkldnn index 403667673f61a..01204edbda1c2 160000 --- a/third_party/mkldnn +++ b/third_party/mkldnn @@ -1 +1 @@ -Subproject commit 403667673f61a56289622fd5bc587b1856296fbc +Subproject commit 01204edbda1c2a4ff0cccd40476ed6bd2fb62d56 diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh index 37d75207cfb84..e89892b2bf02a 100644 --- a/tools/check_file_diff_approvals.sh +++ b/tools/check_file_diff_approvals.sh @@ -287,8 +287,8 @@ if [ "${HAS_MODIFIED_DECLARATIONS}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then HAS_USED_CCTESTOLD=`git diff -U0 upstream/$BRANCH |grep "cc_test_old" || true` if [ "${HAS_USED_CCTESTOLD}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then - echo_line="You must be approved by phlrain or risemeup1 or zhangbo9674 for using cc_test_old. Thanks!\n" - check_approval 1 phlrain risemeup1 zhangbo9674 + echo_line="You must be approved by phlrain or risemeup1 or zhangbo9674 or Galaxy1458 for using cc_test_old. Thanks!\n" + check_approval 1 phlrain risemeup1 zhangbo9674 Galaxy1458 fi HAS_MODIFIED_API_COMPAT_YAML=`git diff --name-only upstream/$BRANCH | grep "paddle/phi/api/yaml/op_compat.yaml" || true` @@ -443,6 +443,19 @@ if [ "${ALL_OPTEST_BAN_DYGRAPH_MESSAGE}" != "" ] && [ "${GIT_PR_ID}" != "" ]; th check_approval 1 phlrain fuyinno4 QingshuChen lanxianghit fi +ALL_CHANGE_YAML_FILES=`git diff --numstat upstream/$BRANCH | awk '{print $3}' | grep ".yaml"` +BAN_COMP_MESSAGE="" +for CHANGE_FILE in ${ALL_CHANGE_YAML_FILES}; do + ALL_ITEM_BAN_COMP=`git diff -U0 upstream/$BRANCH ${PADDLE_ROOT}/${CHANGE_FILE} | grep "composite" || true` + if [ "${ALL_ITEM_BAN_COMP}" != "" ]; then + BAN_COMP_MESSAGE="${BAN_COMP_MESSAGE} ${CHANGE_FILE} : \n${ALL_ITEM_BAN_COMP} \n" + fi +done +if [ "${BAN_COMP_MESSAGE}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then + echo_line="If you need to change the key composite, you must have one RD (Charles-hit(wanghao), cyber-pioneer(chenzhuo), cxxly(chenxiaoxu)) review and approve. \nThe code that do not meet the specification are as follows:\n${BAN_COMP_MESSAGE}\n" + check_approval 1 Charles-hit cyber-pioneer cxxly +fi + NEW_OP_ADDED=`git diff --name-only --diff-filter=A upstream/$BRANCH |grep -oE ".+_op..*" || true` if [ "${NEW_OP_ADDED}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then GET_KERNEL_TYPE_FUNC_CNT=`git diff -U0 --diff-filter=A upstream/$BRANCH |grep "+" |grep -czoE "GetExpectedKernelType[(][^(){}]+[)][^{]+[{][^}]+[}]" || true` diff --git a/tools/cinn/build.sh b/tools/cinn/build.sh index a32ae972e340b..b3a92a16e0f63 100755 --- a/tools/cinn/build.sh +++ b/tools/cinn/build.sh @@ -154,9 +154,9 @@ function run_test { export runtime_include_dir=$workspace/paddle/cinn/runtime/cuda if [ ${TESTING_DEBUG_MODE:-OFF} == "ON" ] ; then - ctest --parallel 10 -V + ctest --parallel 10 -V -E "test_frontend_interpreter|test_cinn_fake_resnet" else - ctest --parallel 10 --output-on-failure + ctest --parallel 10 --output-on-failure -E "test_frontend_interpreter|test_cinn_fake_resnet" fi } diff --git a/tools/codestyle/copyright.hook b/tools/codestyle/copyright.hook index 8985e3882cdd6..e007af33ce3cb 100644 --- a/tools/codestyle/copyright.hook +++ b/tools/codestyle/copyright.hook @@ -36,7 +36,7 @@ def _generate_copyright(comment_mark): copyright=COPYRIGHT.split(os.linesep) header = copyright[0].rstrip() - p = re.search('(\d{4})', header).group(0) + p = re.search(r'(\d{4})', header).group(0) now = datetime.datetime.now() header = header.replace(p,str(now.year)) diff --git a/test/sot/extract_errors.py b/tools/codestyle/sort_txt_file.py similarity index 53% rename from test/sot/extract_errors.py rename to tools/codestyle/sort_txt_file.py index b9d9e505724ef..f08c79eb36484 100644 --- a/test/sot/extract_errors.py +++ b/tools/codestyle/sort_txt_file.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import os import sys -runtime_error_msg = sys.stdin.read() -pattern = r'File "?(.*?)"?, line (\d+),.*\n(.*?)\n(.*?)$' -for match in re.finditer(pattern, runtime_error_msg, re.MULTILINE): - file = match.group(1) - if file.startswith("./"): - file = f"tests/{file[2:]}" - line = match.group(2) - error_info = match.group(4) - if "AssertionError" not in error_info: - # error_info = match.group(3) + '\n' + match.group(4) - output = f"::error file={file},line={line}::Error" - print(output) +def sort_by_dict_order(file_path): + with open(file_path, 'r') as f: + lines = f.readlines() + sorted_lines = sorted(lines) + with open(file_path, 'w') as f: + f.writelines(sorted_lines) + + +if __name__ == '__main__': + file_paths = sys.argv[1:] + for file_path in file_paths: + file_path = os.path.normpath(file_path) + sort_by_dict_order(file_path) diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 5467eb1d5925f..822f0a11fec21 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -29,7 +29,9 @@ function collect_failed_tests() { serial_list="^test_conv2d_op$|\ ^test_conv2d_transpose_op$|\ +^test_dygraph_dataparallel_bf16$|\ ^test_dygraph_sharding_stage1_fp16$|\ +^test_dygraph_sharding_stage1_bf16$|\ ^test_dygraph_sharding_stage2_bf16$|\ ^test_dygraph_sharding_stage3_bf16$|\ ^test_conv3d_op$" diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index a89dafff96ab6..b1a19e118e7e4 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -284,7 +284,6 @@ 'test_depthwise_conv_mkldnn_pass', 'test_fleet_metric', 'test_fc_fuse_pass_cc', - 'test_fleet_private_function', 'test_fleet', 'test_executor_check_feed', 'test_py_reader_lod_level_share', @@ -2121,7 +2120,6 @@ 'test_dgc_optimizer', 'heter_server_test', 'test_custom_conj', - 'test_fleet_private_function', 'test_fake_init_op', 'brpc_service_sparse_sgd_test', 'test_tf32_cudnn',