Merge branch 'PaddlePaddle:develop' into onecyclelr
Asthestarsfalll authored May 13, 2022
2 parents ebb04e2 + 1280f29 commit a58cab2
Showing 292 changed files with 8,935 additions and 2,155 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -256,8 +256,8 @@ option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF)
option(WITH_ARM_BRPC "Supprot Brpc in Arm" OFF)

if(WITH_RECORD_BUILDTIME)
-  set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh")
-  set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh")
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh ${CMAKE_CURRENT_BINARY_DIR}")
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh ${CMAKE_CURRENT_BINARY_DIR}")
else()
include(ccache) # set ccache for compilation ; if WITH_RECORD_BUILDTIME=ON can't use ccache
endif()
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ PaddlePaddle is originated from industrial practices with dedication and commitm

## Installation

- ### Latest PaddlePaddle Release: [v2.2](https://github.com/PaddlePaddle/Paddle/tree/release/2.2)
+ ### Latest PaddlePaddle Release: [v2.3](https://github.com/PaddlePaddle/Paddle/tree/release/2.3)

Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest features of PaddlePaddle.
2 changes: 1 addition & 1 deletion cmake/external/mkldnn.cmake
@@ -19,7 +19,7 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY ${GIT_URL}/oneapi-src/oneDNN.git)
- SET(MKLDNN_TAG 9a35435c18722ff17a48fb60bceac42bfdf78754)
+ SET(MKLDNN_TAG 9b186765dded79066e0cd9c17eb70b680b76fb8e)


# Introduce variables:
4 changes: 2 additions & 2 deletions cmake/external/xpu.cmake
@@ -9,15 +9,15 @@ SET(XPU_RT_LIB_NAME "libxpurt.so")

if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220425")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220510")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()

# ubuntu and centos: use output by XDNN API team
if(NOT DEFINED XPU_XDNN_BASE_URL)
SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220425")
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220510")
else()
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif()
59 changes: 55 additions & 4 deletions paddle/fluid/distributed/collective/reducer.cc
@@ -901,6 +901,9 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,

dev_ctx->Wait();

+  Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
+  std::vector<int64_t> dst_shape = src_value_tensor.shape();
+
if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_,
[&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) {
// During sparse communication, the number of each card is same.
@@ -940,8 +943,6 @@
&dst_rows_vector);
dev_ctx->Wait();

-    Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
-    std::vector<int64_t> dst_shape = src_value_tensor.shape();
dst_shape[dst_shape.size() - 2] = rows_num;
auto dst_dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
paddle::experimental::full(IntArray(dst_shape), 0,
@@ -971,8 +972,58 @@
*(src->mutable_value()) =
*(std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl()));
} else {
-    PADDLE_THROW(
-        platform::errors::Unimplemented("This case is not supported."));
+    std::vector<Tensor> rows_tensors;
+    std::vector<Tensor> values_tensors;
+
+    for (int i = 0; i < size_; ++i) {
+      std::vector<int64_t> value_tensor_shape = {
+          cpu_rows_num_ptr[i], dst_shape[dst_shape.size() - 1]};
+      Tensor rows_tensor = paddle::experimental::full(
+          IntArray({static_cast<int64_t>(cpu_rows_num_ptr[i])}), 0,
+          DataType::INT64, inner_place_);
+      Tensor values_tensor = paddle::experimental::full(
+          IntArray(value_tensor_shape), 0, src->value().dtype(), inner_place_);
+      std::vector<phi::DenseTensor> rows_dense_vector;
+      std::vector<phi::DenseTensor> values_dense_vector;
+
+      if (i == rank_) {
+        auto *rows_dense_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(rows_tensor.impl())
+                .get();
+        framework::TensorFromVector<int64_t>(src_rows, *dev_ctx,
+                                             rows_dense_tensor);
+        values_tensor.set_impl(
+            std::make_shared<phi::DenseTensor>(src->value()));
+      }
+      rows_dense_vector.push_back(
+          *std::dynamic_pointer_cast<phi::DenseTensor>(rows_tensor.impl()));
+      values_dense_vector.push_back(
+          *std::dynamic_pointer_cast<phi::DenseTensor>(values_tensor.impl()));
+
+      auto b_opts = BroadcastOptions();
+      b_opts.source_rank = i;
+      process_group_->Broadcast(rows_dense_vector, rows_dense_vector, b_opts);
+      process_group_
+          ->Broadcast(values_dense_vector, values_dense_vector, b_opts)
+          ->Wait();
+      rows_tensors.push_back(rows_tensor);
+      values_tensors.push_back(values_tensor);
+    }
+
+    Tensor dst_rows_tensor =
+        paddle::experimental::concat(rows_tensors, phi::Scalar(0));
+    framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
+    auto *dst_rows_dense_tensor =
+        std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl())
+            .get();
+    framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx,
+                                       &dst_rows_vector);
+    src->set_rows(dst_rows_vector);
+
+    Tensor dst_values_tensor =
+        paddle::experimental::concat(values_tensors, phi::Scalar(0));
+    *(src->mutable_value()) = *(
+        std::dynamic_pointer_cast<phi::DenseTensor>(dst_values_tensor.impl()));
}
}

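The rewritten else branch above replaces the former PADDLE_THROW(Unimplemented) path: when ranks hold different numbers of sparse rows, each rank i in turn broadcasts its rows and values tensors to everyone, and all ranks concatenate the received pieces along axis 0 so that every rank ends with identical gathered rows and values. Below is a minimal single-process sketch of that gather-by-broadcast pattern, assuming plain std::vector buffers in place of phi::DenseTensor and a direct read of each rank's local buffer in place of ProcessGroup::Broadcast; rank_rows and its contents are illustrative only.

// Single-process sketch of the gather-by-broadcast pattern used above.
// Assumptions: std::vector<int64_t> stands in for the rows tensor, and the
// "broadcast" from rank i is simulated by reading rank i's local buffer.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Per-rank sparse row indices; the lengths differ, which is exactly the
  // case the old code rejected as unimplemented.
  std::vector<std::vector<int64_t>> rank_rows = {{0, 3}, {1}, {2, 4, 5}};
  const int world_size = static_cast<int>(rank_rows.size());

  std::vector<int64_t> dst_rows;  // concatenated result, same on every rank
  for (int i = 0; i < world_size; ++i) {
    // The real code zero-fills a tensor sized by cpu_rows_num_ptr[i], lets
    // rank i fill it, then broadcasts from source_rank = i; here we copy.
    const std::vector<int64_t>& received = rank_rows[i];
    dst_rows.insert(dst_rows.end(), received.begin(), received.end());
  }

  for (int64_t row : dst_rows) std::cout << row << ' ';  // prints: 0 3 1 2 4 5
  std::cout << '\n';
  return 0;
}
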
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/service/ps_client.h
@@ -109,7 +109,7 @@ class PSClient {
size_t table_id) = 0;  // reserved

// firstly push dense param for parameter server
-  // this is neccessary because dense weight initialized in trainer on cold
+  // this is necessary because dense weight initialized in trainer on cold
// start
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
17 changes: 2 additions & 15 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
@@ -34,22 +34,9 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
*tensor = t;
} else {
// Accumulation
-    PADDLE_ENFORCE_EQ(t.initialized(), true,
-                      paddle::platform::errors::Fatal(
-                          "We can only accumulate initialized tensor, but we "
-                          "got tensor: %s is empty please check you network "
-                          "and make sure it creates grads.",
-                          t.name()));
-    PADDLE_ENFORCE_NOT_NULL(
-        tensor, paddle::platform::errors::Fatal(
-                    "We can only accumulate initialized tensor to non-nullptr "
-                    "tensor but we got nullptr please check you network "
-                    "and make sure it creates grads."));
-
-    if (t.is_dense_tensor()) {
-      if (tensor->is_dense_tensor()) {
+    if (LIKELY(t.is_dense_tensor())) {
+      if (LIKELY(tensor->is_dense_tensor())) {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t, tensor);
-
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
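Besides dropping the two PADDLE_ENFORCE checks, this hunk wraps both dense-tensor tests in LIKELY(), a branch-prediction hint that marks the dense/dense combination as the hot path. The sketch below shows one common way such a macro is defined on GCC/Clang; this is an illustrative assumption, not Paddle's actual definition, which lives elsewhere in the codebase.

// Illustrative LIKELY macro via __builtin_expect (GCC/Clang); on other
// compilers the hint degrades to a plain expression.
#if defined(__GNUC__) || defined(__clang__)
#define LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), true))
#else
#define LIKELY(expr) (expr)
#endif

#include <cstdio>

bool is_dense_tensor(int kind) { return kind == 0; }  // toy stand-in

int main() {
  int kind = 0;
  if (LIKELY(is_dense_tensor(kind))) {  // compiler lays this branch out first
    std::puts("dense fast path");
  } else {
    std::puts("fallback path");
  }
  return 0;
}
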
9 changes: 6 additions & 3 deletions paddle/fluid/eager/api/utils/global_utils.h
@@ -77,7 +77,8 @@ class Controller {
op_meta_info_map_.insert(map.begin(), map.end());
}

-  std::unordered_map<std::string, std::vector<std::unordered_map<int, int>>>&
+  std::unordered_map<std::string,
+                     std::vector<std::vector<std::unordered_map<int, int>>>>&
GetCustomEdgesSlotMap() {
return custom_edges_slot_map_;
}
@@ -89,8 +90,10 @@
new paddle::imperative::Tracer()};
std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>
op_meta_info_map_;
-  /* op_type : {{grad_outputs}, {grad_inputs}, {input}, {output}, {attrs}}*/
-  std::unordered_map<std::string, std::vector<std::unordered_map<int, int>>>
+  /* op_type : {{{grad_outputs}, {grad_inputs}, {input}, {output}, {attrs}},
+   * {{grad_outputs}, {grad_inputs}, {input}, {output}, {attrs}}}*/
+  std::unordered_map<std::string,
+                     std::vector<std::vector<std::unordered_map<int, int>>>>
custom_edges_slot_map_;
DISABLE_COPY_AND_ASSIGN(Controller);
};
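The change adds one level of nesting to custom_edges_slot_map_: each op type now keeps a separate {grad_outputs, grad_inputs, input, output, attrs} group of slot maps per OpBase, rather than a single shared group, as the updated comment spells out. A small self-contained sketch of populating and indexing the new shape, with a hypothetical custom op name:

#include <string>
#include <unordered_map>
#include <vector>

// New value type: op_type -> one list of five slot maps per OpBase.
using SlotMap = std::unordered_map<int, int>;
using CustomEdgesSlotMap =
    std::unordered_map<std::string, std::vector<std::vector<SlotMap>>>;

int main() {
  CustomEdgesSlotMap edges;
  // Hypothetical op "my_relu" with two OpBase entries, each carrying its own
  // {grad_outputs}, {grad_inputs}, {input}, {output}, {attrs} maps.
  edges["my_relu"] = {{{{0, 0}}, {}, {{0, 0}}, {}, {}},
                      {{{0, 1}}, {{1, 0}}, {}, {}, {}}};
  // Index as edges[op_type][op_base_index][slot_kind].
  const SlotMap& grad_outputs_of_second = edges["my_relu"][1][0];
  return static_cast<int>(grad_outputs_of_second.size());  // 1
}
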
31 changes: 10 additions & 21 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1156,28 +1156,20 @@ static std::string GenerateGradNodeCreationContent(
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
-    const std::unordered_set<std::string>& no_need_buffer_ins =
-        iter.GetNoNeedBufferInputs();
for (auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
-      std::string full_reserved = "false";
-      if (fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
-              fwd_outputs_name_pos_map.end() &&
-          !no_need_buffer_ins.count(tensor_wrapper_name)) {
-        full_reserved = "true";
-      }
      const char* SET_TENSOR_WRAPPER_TEMPLATE =
-          " grad_node->SetTensorWrapper%s(%s, %s);\n";
+          " grad_node->SetTensorWrapper%s(%s);\n";
// Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
auto inplace_input_name = inplace_map[tensor_wrapper_name];
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
-            LegalizeVarName(inplace_input_name), full_reserved);
+            LegalizeVarName(inplace_input_name));
} else {
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
-            LegalizeVarName(tensor_wrapper_name), full_reserved);
+            LegalizeVarName(tensor_wrapper_name));
}
}
}
@@ -2592,7 +2584,6 @@

std::string tensor_wrapper_arg_str;
std::string tensor_wrapper_body_str;
-      std::string full_reserved_str = "full_reserved";
std::string no_need_buffer_str = "false";
if (no_need_buffer_ins.count(tensor_wrapper_name)) {
no_need_buffer_str = "true";
@@ -2610,12 +2601,12 @@

const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"for(const auto& eager_tensor : %s) {\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor, %s "
"/*full_reserved*/, %s) );\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor "
", %s) );\n"
" }\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
-          struct_tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
+          struct_tensor_wrapper_name, no_need_buffer_str);

const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
"for (auto tw: %s) {\n"
@@ -2636,22 +2627,20 @@
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);

const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, %s /*full_reserved*/, %s);\n";
"%s = egr::TensorWrapper(%s, %s);\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
-          tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
+          tensor_wrapper_name, no_need_buffer_str);

const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n";
clear_tensor_wrappers_str += paddle::string::Sprintf(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
}
-      std::string full_reserved_signature_str = "bool full_reserved";
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" void SetTensorWrapper%s(%s, %s) {\n %s\n }\n";
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
set_tensor_wrappers_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
-          tensor_wrapper_arg_str, full_reserved_signature_str,
-          tensor_wrapper_body_str);
+          tensor_wrapper_arg_str, tensor_wrapper_body_str);
}
}
VLOG(6) << "Generated TensorWrapper";
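Taken together, these hunks remove the full_reserved argument from the generated SetTensorWrapper methods and from TensorWrapper construction. Expanding the new Sprintf templates by hand for a hypothetical wrapped input X, with no_need_buffer_str set to "false", gives roughly the class body below; the Tensor and egr::TensorWrapper stubs are stand-ins added so the sketch compiles on its own, not the real Paddle types.

struct Tensor {};  // stub for paddle::experimental::Tensor

namespace egr {
struct TensorWrapper {  // stub; the real class lives in paddle/fluid/eager
  TensorWrapper() = default;
  // Matches the new two-argument construction in the generated body.
  TensorWrapper(const Tensor& /*t*/, bool /*no_need_buffer*/) {}
};
}  // namespace egr

class GradNodeDemo {
 public:
  // SET_TENSOR_WRAPPER_TEMPLATE: " void SetTensorWrapper%s(%s) {\n %s\n }\n"
  void SetTensorWrapperX(const Tensor& X) {
    // SET_TENSOR_WRAPPER_BODY_TEMPLATE: "%s = egr::TensorWrapper(%s, %s);\n"
    X_ = egr::TensorWrapper(X, false);
  }

 private:
  egr::TensorWrapper X_;  // TENSOR_WRAPPER_MEMBER_TEMPLATE
};

int main() {
  Tensor t;
  GradNodeDemo node;
  node.SetTensorWrapperX(t);
  return 0;
}
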
@@ -16,9 +16,9 @@ add_custom_target(eager_final_state_codegen
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py"
"--api_yaml_path=${api_yaml_path}"
"--backward_yaml_path=${backward_yaml_path}"
"--forwards_cc_path=${tmp_forwards_cc_path}"
"--forwards_cc_path=${tmp_forwards_cc_path}"
"--forwards_h_path=${tmp_forwards_h_path}"
"--nodes_cc_path=${tmp_nodes_cc_path}"
"--nodes_cc_path=${tmp_nodes_cc_path}"
"--nodes_h_path=${tmp_nodes_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_forwards_cc_path} ${forwards_cc_path}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_forwards_h_path} ${forwards_h_path}

1 comment on commit a58cab2

@paddle-bot-old


Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
