130 changes: 75 additions & 55 deletions backends/npu/kernels/adam_kernel.cc
@@ -233,29 +233,38 @@ void CastFP64toFP32Kernel(const Context& dev_ctx,
}

template <typename T, typename Context>
void AdamKernel(const Context& dev_ctx,
                const phi::DenseTensor& param,
                const phi::DenseTensor& grad,
                const phi::DenseTensor& learning_rate,
                const phi::DenseTensor& moment1,
                const phi::DenseTensor& moment2,
                const phi::DenseTensor& beta1_pow_in,
                const phi::DenseTensor& beta2_pow_in,
                const paddle::optional<phi::DenseTensor>& master_param,
                const paddle::optional<phi::DenseTensor>& skip_update,
                const phi::Scalar& beta1_in,
                const phi::Scalar& beta2_in,
                const phi::Scalar& epsilon_in,
                bool lazy_mode,
                int64_t min_row_size_to_use_multithread,
                bool multi_precision,
                bool use_global_beta_pow,
                phi::DenseTensor* param_out,
                phi::DenseTensor* moment1_out,
                phi::DenseTensor* moment2_out,
                phi::DenseTensor* beta1_pow_out,
                phi::DenseTensor* beta2_pow_out,
                phi::DenseTensor* master_param_out) {
void AdamKernel(
    const Context& dev_ctx,
    const phi::DenseTensor& param,
    const phi::DenseTensor& grad,
    const phi::DenseTensor& learning_rate,
    const phi::DenseTensor& moment1,
    const phi::DenseTensor& moment2,
    const paddle::optional<phi::DenseTensor>& moment2_max,  // UNUSED
    const phi::DenseTensor& beta1_pow_in,
    const phi::DenseTensor& beta2_pow_in,
    const paddle::optional<phi::DenseTensor>& master_param,
    const paddle::optional<phi::DenseTensor>& skip_update,
    const phi::Scalar& beta1_in,
    const phi::Scalar& beta2_in,
    const phi::Scalar& epsilon_in,
    bool lazy_mode,
    int64_t min_row_size_to_use_multithread,
    bool multi_precision,
    bool use_global_beta_pow,
    bool amsgrad,  // UNUSED
    phi::DenseTensor* param_out,
    phi::DenseTensor* moment1_out,
    phi::DenseTensor* moment2_out,
    phi::DenseTensor* moment2_max_out,  // UNUSED
    phi::DenseTensor* beta1_pow_out,
    phi::DenseTensor* beta2_pow_out,
    phi::DenseTensor* master_param_out) {
  PADDLE_ENFORCE_NE(
      amsgrad,
      true,
      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));

  bool skip_update_ = false;
  if (skip_update.is_initialized()) {
    PADDLE_ENFORCE_EQ(skip_update->numel(),
@@ -358,32 +367,41 @@ void AdamKernel(const Context& dev_ctx,
}
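For orientation, the plain Adam path these arguments feed follows Paddle's documented update rule; below is a minimal host-side sketch of that rule (my own helper names, illustrative only — the real kernel composes NPU operator calls rather than looping on the host, and the AMSGrad branch that would consume moment2_max is exactly what the PADDLE_ENFORCE_NE above rejects).

#include <cmath>

// Illustrative per-element Adam step, assuming beta1_pow/beta2_pow hold
// beta^t for the current step t, as in the documented Paddle formula.
inline void adam_step(float& p, float& m, float& v, float g,
                      float lr, float beta1, float beta2, float eps,
                      float beta1_pow, float beta2_pow) {
  m = beta1 * m + (1.0f - beta1) * g;      // moment1_out
  v = beta2 * v + (1.0f - beta2) * g * g;  // moment2_out
  // AMSGrad would additionally track v_max = max(v_max, v) and use it in the
  // denominator; this backend does not implement that, hence the amsgrad check.
  const float lr_t = lr * std::sqrt(1.0f - beta2_pow) / (1.0f - beta1_pow);
  p -= lr_t * m / (std::sqrt(v) + eps);    // param_out
}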

template <typename T, typename Context>
void AdamwKernel(const Context& dev_ctx,
                 const phi::DenseTensor& param,
                 const phi::DenseTensor& grad,
                 const phi::DenseTensor& learning_rate,
                 const phi::DenseTensor& moment1,
                 const phi::DenseTensor& moment2,
                 const phi::DenseTensor& beta1_pow,
                 const phi::DenseTensor& beta2_pow,
                 const paddle::optional<phi::DenseTensor>& master_param,
                 const paddle::optional<phi::DenseTensor>& skip_update,
                 const phi::Scalar& beta1,
                 const phi::Scalar& beta2,
                 const phi::Scalar& epsilon,
                 float lr_ratio,
                 float coeff,
                 bool with_decay,
                 bool lazy_mode,
                 int64_t min_row_size_to_use_multithread,
                 bool multi_precision,
                 bool use_global_beta_pow,
                 phi::DenseTensor* param_out,
                 phi::DenseTensor* moment1_out,
                 phi::DenseTensor* moment2_out,
                 phi::DenseTensor* beta1_pow_out,
                 phi::DenseTensor* beta2_pow_out,
                 phi::DenseTensor* master_param_outs) {
void AdamwKernel(
    const Context& dev_ctx,
    const phi::DenseTensor& param,
    const phi::DenseTensor& grad,
    const phi::DenseTensor& learning_rate,
    const phi::DenseTensor& moment1,
    const phi::DenseTensor& moment2,
    const paddle::optional<phi::DenseTensor>& moment2_max,  // UNUSED
    const phi::DenseTensor& beta1_pow,
    const phi::DenseTensor& beta2_pow,
    const paddle::optional<phi::DenseTensor>& master_param,
    const paddle::optional<phi::DenseTensor>& skip_update,
    const phi::Scalar& beta1,
    const phi::Scalar& beta2,
    const phi::Scalar& epsilon,
    float lr_ratio,
    float coeff,
    bool with_decay,
    bool lazy_mode,
    int64_t min_row_size_to_use_multithread,
    bool multi_precision,
    bool use_global_beta_pow,
    bool amsgrad,  // UNUSED
    phi::DenseTensor* param_out,
    phi::DenseTensor* moment1_out,
    phi::DenseTensor* moment2_out,
    phi::DenseTensor* moment2_max_out,  // UNUSED
    phi::DenseTensor* beta1_pow_out,
    phi::DenseTensor* beta2_pow_out,
    phi::DenseTensor* master_param_outs) {
  PADDLE_ENFORCE_NE(
      amsgrad,
      true,
      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));

  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

  bool skip_update_ = false;
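What distinguishes this path from the Adam kernel above is the decoupled weight decay governed by with_decay, coeff and lr_ratio. A minimal sketch of that step follows (my own names, illustrative only; the kernel itself drives NPU operators and handles the multi-precision MPDType path).

#include <cmath>

// Illustrative AdamW step: the decay is applied to the weight itself rather
// than folded into the gradient, and lr_ratio scales the learning rate.
inline void adamw_step(float& p, float& m, float& v, float g,
                       float lr, float lr_ratio, float coeff, bool with_decay,
                       float beta1, float beta2, float eps,
                       float beta1_pow, float beta2_pow) {
  const float scaled_lr = lr * lr_ratio;
  if (with_decay) {
    p *= (1.0f - scaled_lr * coeff);  // decoupled weight decay
  }
  m = beta1 * m + (1.0f - beta1) * g;
  v = beta2 * v + (1.0f - beta2) * g * g;
  const float lr_t =
      scaled_lr * std::sqrt(1.0f - beta2_pow) / (1.0f - beta1_pow);
  p -= lr_t * m / (std::sqrt(v) + eps);
}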
@@ -514,18 +532,19 @@ PD_REGISTER_PLUGIN_KERNEL(adam,
                          float,
                          double) {
  // Skip beta1_pow, beta2_pow, skip_update data transform
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
  }
  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
}
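The index changes in this registration follow directly from the new signature: moment2_max now occupies input slot 5, pushing beta1_pow, beta2_pow and skip_update from 5/6/8 to 6/7/9, and moment2_max_out takes output slot 3, pushing the remaining outputs down by one. A sketch of the assumed ordering (names lifted from the AdamKernel signature above, purely for orientation):

// Assumed tensor-argument ordering after the moment2_max insertion.
enum AdamInputIdx { kParam = 0, kGrad, kLearningRate, kMoment1, kMoment2,
                    kMoment2Max /*5*/, kBeta1Pow /*6*/, kBeta2Pow /*7*/,
                    kMasterParam /*8*/, kSkipUpdate /*9*/ };
enum AdamOutputIdx { kParamOut = 0, kMoment1Out, kMoment2Out,
                     kMoment2MaxOut /*3*/, kBeta1PowOut /*4*/,
                     kBeta2PowOut /*5*/, kMasterParamOut /*6*/ };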

PD_REGISTER_PLUGIN_KERNEL(adamw,
@@ -537,16 +556,17 @@ PD_REGISTER_PLUGIN_KERNEL(adamw,
                          float,
                          double) {
  // Skip beta1_pow, beta2_pow, skip_update data transform
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
  }
  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
}
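The same index shift applies to this adamw registration. As in the adam case, the FLOAT16 branch pins every state output to FLOAT32 so the moments, beta powers and the optional master copy of the parameters keep full precision; a rough sketch of the idea (hypothetical helper, not code from this backend):

// Sketch of a multi-precision update: the fp32 master copy absorbs the step
// and the fp16 parameter is re-derived from it, which is why master_param_out
// and the moment outputs are registered as FLOAT32 even for float16 kernels.
template <typename HalfT>
void apply_update_with_master(HalfT* param_low, float* param_master,
                              float update) {
  *param_master -= update;                         // full-precision state
  *param_low = static_cast<HalfT>(*param_master);  // low-precision working copy
}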