support sparse of adam, *test=kunlun (PaddlePaddle#38483)
* support sparse of adam, *test=kunlun

* add pre-commit-config.yaml

* support sparse of adam in KL2,*test=kunlun

* support sparse of adam in KL2, *test=kunlun

* modify xpu.cmake, *test=kunlun

* support sparse of adam, rm some wait, *test=kunlun

* support sparse of adam, rm some wait, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun
helen88 authored Jan 24, 2022
1 parent c379606 commit e106901
Showing 4 changed files with 411 additions and 9 deletions.
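For context, here is a minimal, illustrative sketch (not part of this commit) of how the new SelectedRows path is reached from Python, assuming a PaddlePaddle build with XPU (Kunlun) support; the sizes and names below are placeholders only:

    import paddle

    paddle.set_device("xpu")                                   # run on the Kunlun card
    emb = paddle.nn.Embedding(1000, 64, sparse=True)           # sparse=True yields SelectedRows gradients
    opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=emb.parameters())

    ids = paddle.randint(0, 1000, shape=[32])
    loss = emb(ids).sum()
    loss.backward()
    opt.step()          # the sparse gradient dispatches to the SelectedRows branch of adam_op_xpu
    opt.clear_grad()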
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -36,7 +36,7 @@ ENDIF()

if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220104")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220116")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
153 changes: 153 additions & 0 deletions paddle/fluid/operators/math/selected_rows_functor.cc
@@ -477,6 +477,155 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
}
};

#ifdef PADDLE_WITH_XPU
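// XPU (Kunlun) specialization of MergeAdd: sums the values of duplicate rows
// of a SelectedRows so that every row index appears exactly once in the output.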
template <typename T>
struct MergeAdd<platform::XPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::XPUDeviceContext& context,
const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out;
(*this)(context, input, &out, sorted_result);
return out;
}
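// Single-input merge: collect the unique row indices, zero-fill the output
// buffer with xpu::constant, then accumulate each input row into its merged
// slot with xpu::add.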

void operator()(const platform::XPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output,
const bool sorted_result = false) {
framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
}

framework::SelectedRows& out = *output;
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];

out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));

std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}

auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
auto r = xpu::add(context.x_context(), &input_data[i * input_width],
&out_data[out_i * input_width],
&out_data[out_i * input_width], n);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API return wrong value[%d %s], ", r,
XPUAPIErrorMsg[r]));
}
}
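// Multi-input merge: checks that all non-empty inputs share the same value
// width and height, then accumulates them row by row into a single output.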

void operator()(const platform::XPUDeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output,
const bool sorted_result = false) {
if (inputs.size() == 0) {
VLOG(3) << "no input! return";
return;
}
const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) {
if (in->rows().size() > 0) {
has_value_input = in;
break;
}
}
if (has_value_input == nullptr) {
VLOG(3) << "no input has value! just return" << std::endl;
return;
}
auto input_width = has_value_input->value().dims()[1];
auto input_height = has_value_input->height();
framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set;
size_t row_num = 0;
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
platform::errors::InvalidArgument(
"All inputs should have the same "
"dimension except for the first one."));
PADDLE_ENFORCE_EQ(input_height, input->height(),
platform::errors::InvalidArgument(
"All inputs should have the same height."));
row_num += input->rows().size();
merged_row_set.insert(input->rows().begin(), input->rows().end());
}

std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());

if (sorted_result) {
std::sort(merge_rows.begin(), merge_rows.end());
}

out.set_rows(merge_rows);
out.set_height(input_height);
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merged_row_set.size()), input_width}),
context.GetPlace());

int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));

T* out_data = out.mutable_value()->data<T>();

std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}

for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
auto& input_rows = input->rows();

int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
auto r = xpu::add(
context.x_context(), input->value().data<T>() + i * input_width,
&out_data[out_i * input_width], &out_data[out_i * input_width], n);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API return wrong value[%d %s], ", r,
XPUAPIErrorMsg[r]));
}
}
}
};

#endif
template <typename T>
struct MergeAverage<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
@@ -589,6 +738,10 @@ template struct MergeAdd<platform::CPUDeviceContext,
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::bfloat16>;

#ifdef PADDLE_WITH_XPU
template struct MergeAdd<platform::XPUDeviceContext, float>;
#endif

template struct MergeAverage<platform::CPUDeviceContext, int>;
template struct MergeAverage<platform::CPUDeviceContext, int64_t>;
template struct MergeAverage<platform::CPUDeviceContext, float>;
127 changes: 119 additions & 8 deletions paddle/fluid/operators/optimizers/adam_op_xpu.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "gflags/gflags.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"

namespace paddle {
namespace operators {
@@ -155,6 +156,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
mom2_out.template mutable_data<float>(ctx.GetPlace()),
param_out.template mutable_data<float>(ctx.GetPlace()),
beta1, beta2, epsilon, param.numel());
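// Block on the XPU stream so the dense Adam kernel has completed before
// beta_pow is updated below.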

xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU API return wrong value[%d],", r));
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == platform::CPUPlace() &&
@@ -165,7 +171,6 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
const float* beta2_pow_p = beta2_pow.template data<float>();
beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta2 * beta2_pow_p[0];
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else {
float* beta1_pow_out_p =
beta1_pow_out->mutable_data<float>(ctx.GetPlace());
@@ -177,23 +182,129 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adamw error code ", r,
"XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p,
beta2_pow.numel(), false, beta2, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adamw error code ", r,
"XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r]));

xpu_wait(dev_ctx.x_context()->xpu_stream);
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
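// Sparse gradient path: "Grad" is a SelectedRows holding only the rows
// touched in this step.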
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
auto& dev_ctx = ctx.template device_context<DeviceContext>();

if (grad->rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}

std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
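// If the row indices are not strictly increasing there may be duplicates;
// merge them on the device first so every row is updated exactly once.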

framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = grad;
} else {
scatter::MergeAdd<platform::XPUDeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::XPUDeviceContext>(),
*grad, &tmp_grad_merge, true);

xpu_wait(dev_ctx.x_context()->xpu_stream);
grad_merge_ptr = &tmp_grad_merge;
}
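// beta1_pow / beta2_pow may still live on the CPU; copy them to the XPU so
// the kernel can read them through device pointers.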
const T* beta1_pow_ptr = beta1_pow.template data<T>();
const T* beta2_pow_ptr = beta2_pow.template data<T>();
Tensor xpu_beta1_pow;
Tensor xpu_beta2_pow;
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
paddle::framework::TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx,
&xpu_beta1_pow);
paddle::framework::TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx,
&xpu_beta2_pow);
dev_ctx.Wait();
beta1_pow_ptr = xpu_beta1_pow.template data<T>();
beta2_pow_ptr = xpu_beta2_pow.template data<T>();
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
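// The XPU kernel takes 32-bit row indices in device memory, so cast the
// merged int64_t rows to int and copy them over.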
int row_count = grad_merge.rows().size();
std::vector<int> rows(row_count);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* xpu_rows = RAII_GUARD.alloc_l3_or_gm<int>(row_count);
std::vector<int64_t> merge_rows(grad_merge.rows().begin(),
grad_merge.rows().end());
for (size_t i = 0; i < grad_merge.rows().size(); ++i) {
rows[i] = static_cast<int>(merge_rows[i]);
}
xpu_wait(dev_ctx.x_context()->xpu_stream);
memory::Copy(ctx.GetPlace(), xpu_rows, platform::CPUPlace(), rows.data(),
row_count * sizeof(int));
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
auto ori_rows = param.numel() / row_numel;

int lazy_mode = static_cast<int>(ctx.Attr<bool>("lazy_mode"));
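// Launch the fused sparse Adam update over the merged rows; lazy_mode is
// forwarded from the op attribute.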
int r = xpu::sparse_adam(
dev_ctx.x_context(), grad_data, mom1.template data<T>(),
mom2.template data<T>(), param.template data<T>(), beta1_pow_ptr,
beta2_pow_ptr, lr.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()), beta1, beta2,
epsilon, ori_rows, xpu_rows, row_numel, grad_merge.rows().size(),
lazy_mode);

PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External(
"XPU API returned wrong value[%d], please check whether the "
"Baidu Kunlun card is properly installed.",
r));

if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
const float* beta1_pow_p = beta1_pow.template data<float>();
beta1_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta1 * beta1_pow_p[0];
const float* beta2_pow_p = beta2_pow.template data<float>();
beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta2 * beta2_pow_p[0];
} else {
float* beta1_pow_out_p =
beta1_pow_out->mutable_data<float>(ctx.GetPlace());
float* beta2_pow_out_p =
beta2_pow_out->mutable_data<float>(ctx.GetPlace());
int r =
xpu::scale(dev_ctx.x_context(), beta1_pow_ptr, beta1_pow_out_p,
beta1_pow.numel(), false, beta1, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p,
beta2_pow.numel(), false, beta2, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r]));
}
}
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else {
PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument(
"Variable type not supported by adam_op"));
