Skip to content

Commit

Permalink
Merge pull request #5 from PaddlePaddle/develop
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
esythan authored Sep 26, 2021
2 parents a0edcd4 + 3fabc80 commit e1f0559
Show file tree
Hide file tree
Showing 615 changed files with 45,679 additions and 6,520 deletions.
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ project(paddle CXX C)
# enable language CUDA
# TODO(Shibo Tao): remove find_package(CUDA) completely.
find_package(CUDA QUIET)
find_package(MKL CONFIG QUIET)
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
Expand Down Expand Up @@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)

# PY_VERSION
if(NOT PY_VERSION)
Expand Down Expand Up @@ -373,6 +376,10 @@ if (WITH_MIPS)
add_definitions(-DPADDLE_WITH_MIPS)
endif()

# When oneMKL support is requested, expose it to the sources through the
# PADDLE_WITH_ONEMKL preprocessor macro.
if(WITH_ONEMKL)
  add_definitions(-DPADDLE_WITH_ONEMKL)
endif()

if (WITH_HETERPS)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
Expand Down
2 changes: 1 addition & 1 deletion cmake/FindGperftools.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
find_library(GPERFTOOLS_TCMALLOC
NAMES tcmalloc
HINTS ${Gperftools_ROOT_DIR}/lib)

find_library(GPERFTOOLS_PROFILER
NAMES profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
Expand Down
69 changes: 69 additions & 0 deletions cmake/external/lapack.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

INCLUDE (ExternalProject)

# Directory layout of the LAPACK package, rooted at THIRD_PARTY_PATH.
#   LAPACK_PREFIX_DIR  - ExternalProject prefix
#   LAPACK_SOURCE_DIR  - download directory of the package
#   LAPACK_INSTALL_DIR - install root; libraries end up in LAPACK_LIB_DIR
#   LAPACK_INCLUDE_DIR - same directory as LAPACK_SOURCE_DIR
set(LAPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/lapack)
set(LAPACK_SOURCE_DIR ${LAPACK_PREFIX_DIR}/src/extern_lapack)
set(LAPACK_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lapack)
set(LAPACK_LIB_DIR ${LAPACK_INSTALL_DIR}/lib)
set(LAPACK_INCLUDE_DIR ${LAPACK_SOURCE_DIR})

# Note(zhouwei): LAPACK needs a Fortran compiler, which many machines do not
# have, so a precompiled library is downloaded instead.
# The packages are built from LAPACK tag v3.10.0 (06/28/2021):
# https://github.com/Reference-LAPACK/lapack
#
# Each branch selects the package name, its URL and MD5 checksum, and the
# paths (inside LAPACK_LIB_DIR) of the BLAS/LAPACK libraries and the GNU
# runtime libraries listed next to them. The Linux branch defines no
# GNU_RT_LIB_2.
if(LINUX)
  set(LAPACK_VER "lapack_lnx_v3.10.0.20210628" CACHE STRING "" FORCE)
  set(LAPACK_URL "https://paddlepaddledeps.bj.bcebos.com/${LAPACK_VER}.tar.gz" CACHE STRING "" FORCE)
  set(LAPACK_URL_MD5 71f8cc8237a8571692f3e07f9a4f25f6)
  set(GNU_RT_LIB_1 "${LAPACK_LIB_DIR}/libquadmath.so.0")
  set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran.so.3")
  set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.so.3")
  set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.so.3")
elseif(WIN32)
  # Refer to [lapack-for-windows] http://icl.cs.utk.edu/lapack-for-windows/lapack/#lapacke
  set(LAPACK_VER "lapack_win_v3.10.0.20210628" CACHE STRING "" FORCE)
  set(LAPACK_URL "https://paddlepaddledeps.bj.bcebos.com/${LAPACK_VER}.zip" CACHE STRING "" FORCE)
  set(LAPACK_URL_MD5 590d080392dcd5abbd5dca767a50b63a)
  set(GNU_RT_LIB_1 "${LAPACK_LIB_DIR}/libquadmath-0.dll")
  set(GNU_RT_LIB_2 "${LAPACK_LIB_DIR}/libgcc_s_seh-1.dll")
  set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran-3.dll")
  set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.dll")
  set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.dll")
else()
  # Every remaining platform uses the macOS package.
  set(LAPACK_VER "lapack_mac_v3.10.0.20210628" CACHE STRING "" FORCE)
  set(LAPACK_URL "https://paddlepaddledeps.bj.bcebos.com/${LAPACK_VER}.tar.gz" CACHE STRING "" FORCE)
  set(LAPACK_URL_MD5 427aecf8dee8523de3566ca8e47944d7)
  set(GNU_RT_LIB_1 "${LAPACK_LIB_DIR}/libquadmath.0.dylib")
  set(GNU_RT_LIB_2 "${LAPACK_LIB_DIR}/libgcc_s.1.dylib")
  set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran.5.dylib")
  set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.3.dylib")
  set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.3.dylib")
endif()

# Fetch the precompiled LAPACK package selected above. Nothing is compiled:
# the patch, update, configure and build steps are empty, and the install
# step copies the contents of LAPACK_SOURCE_DIR into LAPACK_LIB_DIR.
ExternalProject_Add(
  extern_lapack
  ${EXTERNAL_PROJECT_LOG_ARGS}
  PREFIX                ${LAPACK_PREFIX_DIR}
  URL                   ${LAPACK_URL}
  URL_MD5               ${LAPACK_URL_MD5}
  DOWNLOAD_DIR          ${LAPACK_SOURCE_DIR}
  DOWNLOAD_NO_PROGRESS  1
  # Steps that are intentionally no-ops.
  UPDATE_COMMAND        ""
  PATCH_COMMAND         ""
  CONFIGURE_COMMAND     ""
  BUILD_COMMAND         ""
  INSTALL_COMMAND       ${CMAKE_COMMAND} -E copy_directory ${LAPACK_SOURCE_DIR} ${LAPACK_LIB_DIR}
  # Files produced by this project that other build rules may reference.
  BUILD_BYPRODUCTS      ${BLAS_LIB} ${LAPACK_LIB}
)
33 changes: 32 additions & 1 deletion cmake/external/lite.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,22 @@ if (LITE_WITH_XPU)
ENDIF()
endif()

# NNAdapter switches for the Lite subgraph code:
#  - LITE_WITH_NNADAPTER defines LITE_SUBGRAPH_WITH_NNADAPTER.
#  - NNADAPTER_WITH_HUAWEI_ASCEND_NPU additionally defines
#    LITE_SUBGRAPH_WITH_NPU and provides a cached default for NPU_SDK_ROOT
#    (a value already present in the cache is kept).
if(LITE_WITH_NNADAPTER)
  add_definitions(-DLITE_SUBGRAPH_WITH_NNADAPTER)
  if(NNADAPTER_WITH_HUAWEI_ASCEND_NPU)
    add_definitions(-DLITE_SUBGRAPH_WITH_NPU)
    set(NPU_SDK_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE STRING "default NPU SDK ROOT")
  endif()
endif()

if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
include(ExternalProject)
set(LITE_PROJECT extern_lite)
set(LITE_SOURCES_DIR ${THIRD_PARTY_PATH}/lite)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)

if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204)
set(LITE_GIT_TAG 4ab64daecc11fbf74fffdc6a4733f388472e7d5d)
endif()

if(NOT CUDA_ARCH_NAME)
Expand All @@ -67,6 +75,9 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
-DLITE_WITH_XPU=${LITE_WITH_XPU}
-DXPU_SDK_URL=${XPU_BASE_URL}
-DXPU_SDK_ENV=${XPU_SDK_ENV}
-DLITE_WITH_NNADAPTER=${LITE_WITH_NNADAPTER}
-DNNADAPTER_WITH_HUAWEI_ASCEND_NPU=${NNADAPTER_WITH_HUAWEI_ASCEND_NPU}
-DNNADAPTER_HUAWEI_ASCEND_NPU_SDK_ROOT=${NPU_SDK_ROOT}
-DLITE_WITH_CODE_META_INFO=OFF
-DLITE_WITH_ARM=ON)
ExternalProject_Add(
Expand Down Expand Up @@ -110,6 +121,9 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
-DLITE_WITH_XPU=${LITE_WITH_XPU}
-DXPU_SDK_URL=${XPU_BASE_URL}
-DXPU_SDK_ENV=${XPU_SDK_ENV}
-DLITE_WITH_NNADAPTER=${LITE_WITH_NNADAPTER}
-DNNADAPTER_WITH_HUAWEI_ASCEND_NPU=${NNADAPTER_WITH_HUAWEI_ASCEND_NPU}
-DNNADAPTER_HUAWEI_ASCEND_NPU_SDK_ROOT=${NPU_SDK_ROOT}
-DLITE_WITH_CODE_META_INFO=OFF
-DLITE_WITH_ARM=OFF)

Expand All @@ -120,6 +134,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
GIT_TAG ${LITE_GIT_TAG}
PREFIX ${LITE_SOURCES_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND sed -i "s?NNadapter_bridges_path = os.path.abspath('..')+\"\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?NNadapter_bridges_path = os.path.abspath(\'..\')+\"\/extern_lite\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?" ${LITE_SOURCES_DIR}/src/extern_lite//lite/tools/cmake_tools/record_supported_kernel_op.py && sed -i "/general::ssa::ConvertToSSA(cpp_prog)$<SEMICOLON>/d" ${LITE_SOURCES_DIR}/src/extern_lite/lite/model_parser/model_parser.cc
BUILD_COMMAND ${LITE_BUILD_COMMAND}
INSTALL_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
Expand All @@ -146,6 +161,11 @@ endif()
if (WITH_ARM)
if(LITE_WITH_XPU)
set(LITE_OUTPUT_BIN_DIR inference_lite_lib.armlinux.armv8.xpu)
elseif(LITE_WITH_NNADAPTER)
message("Enable LITE_WITH_NNADAPTER")
if (NNADAPTER_WITH_HUAWEI_ASCEND_NPU)
set(LITE_OUTPUT_BIN_DIR inference_lite_lib.armlinux.armv8.nnadapter)
endif()
else()
set(LITE_OUTPUT_BIN_DIR inference_lite_lib.armlinux.armv8)
endif()
Expand Down Expand Up @@ -174,5 +194,16 @@ endfunction()
external_lite_libs(lite_full_static ${LITE_BINARY_DIR}/${LITE_OUTPUT_BIN_DIR}/cxx/lib/libpaddle_full_api_shared.so)
set(LITE_SHARED_LIB ${LITE_BINARY_DIR}/${LITE_OUTPUT_BIN_DIR}/cxx/lib/libpaddle_full_api_shared.so)

# Collect the Lite targets in LITE_DEPS.
#
# lite_full_static is part of LITE_DEPS in every configuration, so it is set
# up front. Previously LITE_DEPS was left unset when LITE_WITH_NNADAPTER was ON
# but NNADAPTER_WITH_HUAWEI_ASCEND_NPU was OFF.
set(LITE_DEPS lite_full_static)

if(LITE_WITH_NNADAPTER)
  # Path of the NNAdapter runtime library inside the Lite output tree.
  set(LITE_NNADAPTER_LIB ${LITE_BINARY_DIR}/${LITE_OUTPUT_BIN_DIR}/cxx/lib/libnnadapter.so)
  if(NNADAPTER_WITH_HUAWEI_ASCEND_NPU)
    # Register both shared libraries under the lite_nnadapter name and add
    # that name to the dependency list.
    external_lite_libs(lite_nnadapter ${LITE_BINARY_DIR}/${LITE_OUTPUT_BIN_DIR}/cxx/lib/libnnadapter.so ${LITE_BINARY_DIR}/${LITE_OUTPUT_BIN_DIR}/cxx/lib/libhuawei_ascend_npu.so)
    list(APPEND LITE_DEPS lite_nnadapter)
    set(LITE_NNADAPTER_NPU_LIB ${LITE_BINARY_DIR}/${LITE_OUTPUT_BIN_DIR}/cxx/lib/libhuawei_ascend_npu.so)
  endif()
endif()

add_definitions(-DPADDLE_WITH_LITE)
add_definitions(-DLITE_WITH_LOG)
44 changes: 44 additions & 0 deletions cmake/external/pocketfft.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)


# Root directory of the pocketfft checkout. It is a cache entry without FORCE,
# so a value supplied by the user takes precedence over this default.
set(POCKETFFT_PATH "${THIRD_PARTY_PATH}/pocketfft" CACHE STRING "A path setting for external_pocketfft path.")
set(POCKETFFT_PREFIX_DIR ${POCKETFFT_PATH})

# Upstream repository and the ref to check out.
set(POCKETFFT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git)
set(POCKETFFT_TAG release_for_eigen)

# Add <prefix>/src to the directory-wide include path and report it at
# configure time.
set(POCKETFFT_INCLUDE_DIR ${POCKETFFT_PREFIX_DIR}/src)
message("POCKETFFT_INCLUDE_DIR is ${POCKETFFT_INCLUDE_DIR}")
include_directories(${POCKETFFT_INCLUDE_DIR})

# Check out pocketfft from git. Only the download step does any work: the
# update, configure, build, install and test steps are all empty.
ExternalProject_Add(
  extern_pocketfft
  ${EXTERNAL_PROJECT_LOG_ARGS}
  ${SHALLOW_CLONE}
  PREFIX            ${POCKETFFT_PREFIX_DIR}
  GIT_REPOSITORY    ${POCKETFFT_REPOSITORY}
  GIT_TAG           ${POCKETFFT_TAG}
  UPDATE_COMMAND    ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND     ""
  INSTALL_COMMAND   ""
  TEST_COMMAND      ""
)

# INTERFACE target named pocketfft: it has no sources of its own and exists so
# that building it first triggers the extern_pocketfft checkout.
add_library(pocketfft INTERFACE)
add_dependencies(pocketfft extern_pocketfft)
2 changes: 1 addition & 1 deletion cmake/external/xbyak.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ExternalProject_Add(
DEPENDS ""
PREFIX ${XBYAK_PREFIX_DIR}
SOURCE_DIR ${XBYAK_SOURCE_DIR}
# UPDATE_COMMAND ""
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${XBYAK_INSTALL_ROOT}
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${XBYAK_INSTALL_ROOT}
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ELSE ()
ENDIF()

SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210909")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210921")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
Expand Down
1 change: 1 addition & 0 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
list(REMOVE_ITEM hip_srcs "svd_op.cu")
list(REMOVE_ITEM hip_srcs "eigh_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
Expand Down
9 changes: 8 additions & 1 deletion cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ include(external/threadpool)# download threadpool
include(external/dlpack) # download dlpack
include(external/xxhash) # download, build, install xxhash
include(external/warpctc) # download, build, install warpctc
include(external/lapack) # download, build, install lapack

list(APPEND third_party_deps extern_eigen3 extern_gflags extern_glog extern_boost extern_xxhash)
list(APPEND third_party_deps extern_zlib extern_dlpack extern_warpctc extern_threadpool)
list(APPEND third_party_deps extern_zlib extern_dlpack extern_warpctc extern_threadpool extern_lapack)

include(cblas) # find first, then download, build, install openblas

Expand Down Expand Up @@ -361,4 +362,10 @@ if (WITH_CRYPTO)
add_definitions(-DPADDLE_WITH_CRYPTO)
endif (WITH_CRYPTO)

# Optional pocketfft support: pull in its external-project script, add the
# checkout to the third-party dependency list, and define
# PADDLE_WITH_POCKETFFT for the sources.
if(WITH_POCKETFFT)
  include(external/pocketfft)
  list(APPEND third_party_deps extern_pocketfft)
  add_definitions(-DPADDLE_WITH_POCKETFFT)
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
Binary file added log
Binary file not shown.
2 changes: 0 additions & 2 deletions paddle/fluid/extension/src/ext_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
PADDLE_THROW(platform::errors::Unavailable(
"Only GPU related Copy can reach this func."));
}
cudaStreamSynchronize(dev_ctx->stream());
#elif defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
Expand All @@ -110,7 +109,6 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
PADDLE_THROW(platform::errors::Unavailable(
"Only GPU related Copy can reach this func."));
}
hipStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"This function can only be used if compiled with"
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_subdirectory(io)
add_subdirectory(new_executor)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto)

proto_library(op_def_proto SRCS op_def.proto DEPS framework_proto)
cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto boost)
Expand Down Expand Up @@ -223,6 +224,7 @@ if(WITH_PYTHON)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto)
py_proto_compile(pass_desc_py_proto SRCS pass_desc.proto)
#Generate an empty \
#__init__.py to make framework_py_proto as a valid python module.
add_custom_target(fleet_proto_init ALL
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/framework/data_layout_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,

if ((in_format != out_format) || always_copy) {
void* in_data = GetDataFromTensor(in, in_type);
std::string key =
platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type);

platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type,
cpu_engine);

auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
auto reorder_dst_memory_p =
Expand Down
18 changes: 17 additions & 1 deletion paddle/fluid/framework/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <string>
#include <typeindex>

Expand Down Expand Up @@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) {
return proto::VarType::COMPLEX128;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support float32 and "
"Unknown real value data type (%s), now only support float32 and "
"float64.",
DataTypeToString(t)));
}
}

// Maps COMPLEX64 to FP32 and COMPLEX128 to FP64. Any other input type is
// rejected with an Unimplemented error raised through PADDLE_THROW.
extern inline proto::VarType::Type ToRealType(proto::VarType::Type t) {
  switch (t) {
    case proto::VarType::COMPLEX64:
      return proto::VarType::FP32;
    case proto::VarType::COMPLEX128:
      return proto::VarType::FP64;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unknown complex value data type (%s), now only support "
          "complex64 and complex128.",
          DataTypeToString(t)));
  }
}

} // namespace framework
} // namespace paddle
28 changes: 28 additions & 0 deletions paddle/fluid/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,34 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
return os;
}

// Collapses `src` into a rank-3 DDim. Each output extent is the product of one
// consecutive group of source dimensions:
//   [0, num_row_dims), [num_row_dims, num_col_dims), [num_col_dims, src.size())
//
// Preconditions, each reported as InvalidArgument through PADDLE_ENFORCE_*:
//   - src.size() >= 3
//   - 1 <= num_row_dims < src.size()
//   - 2 <= num_col_dims <= src.size()
//   - num_row_dims <= num_col_dims
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) {
  PADDLE_ENFORCE_GE(src.size(), 3,
                    platform::errors::InvalidArgument(
                        "The rank of src dim should be at least 3 "
                        "in flatten_to_3d, but received %d.",
                        src.size()));
  PADDLE_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()), true,
                    platform::errors::InvalidArgument(
                        "The num_row_dims should be inside [1, %d] "
                        "in flatten_to_3d, but received %d.",
                        src.size() - 1, num_row_dims));
  PADDLE_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()), true,
                    platform::errors::InvalidArgument(
                        "The num_col_dims should be inside [2, %d] "
                        "in flatten_to_3d, but received %d.",
                        src.size(), num_col_dims));
  // The check accepts num_row_dims == num_col_dims, so the message says
  // "not be greater than" rather than "less than". The space after the comma
  // keeps the two concatenated literals from running together.
  PADDLE_ENFORCE_GE(
      num_col_dims, num_row_dims,
      platform::errors::InvalidArgument(
          "The num_row_dims should not be greater than num_col_dims in "
          "flatten_to_3d, but received num_row_dims = %d, num_col_dims = %d.",
          num_row_dims, num_col_dims));

  return DDim({product(slice_ddim(src, 0, num_row_dims)),
               product(slice_ddim(src, num_row_dims, num_col_dims)),
               product(slice_ddim(src, num_col_dims, src.size()))});
}

DDim flatten_to_2d(const DDim& src, int num_col_dims) {
return DDim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
Expand Down
Loading

0 comments on commit e1f0559

Please sign in to comment.