@@ -233,29 +233,38 @@ void CastFP64toFP32Kernel(const Context& dev_ctx,
233233}
234234
235235template <typename T, typename Context>
236- void AdamKernel (const Context& dev_ctx,
237- const phi::DenseTensor& param,
238- const phi::DenseTensor& grad,
239- const phi::DenseTensor& learning_rate,
240- const phi::DenseTensor& moment1,
241- const phi::DenseTensor& moment2,
242- const phi::DenseTensor& beta1_pow_in,
243- const phi::DenseTensor& beta2_pow_in,
244- const paddle::optional<phi::DenseTensor>& master_param,
245- const paddle::optional<phi::DenseTensor>& skip_update,
246- const phi::Scalar& beta1_in,
247- const phi::Scalar& beta2_in,
248- const phi::Scalar& epsilon_in,
249- bool lazy_mode,
250- int64_t min_row_size_to_use_multithread,
251- bool multi_precision,
252- bool use_global_beta_pow,
253- phi::DenseTensor* param_out,
254- phi::DenseTensor* moment1_out,
255- phi::DenseTensor* moment2_out,
256- phi::DenseTensor* beta1_pow_out,
257- phi::DenseTensor* beta2_pow_out,
258- phi::DenseTensor* master_param_out) {
236+ void AdamKernel (
237+ const Context& dev_ctx,
238+ const phi::DenseTensor& param,
239+ const phi::DenseTensor& grad,
240+ const phi::DenseTensor& learning_rate,
241+ const phi::DenseTensor& moment1,
242+ const phi::DenseTensor& moment2,
243+ const paddle::optional<phi::DenseTensor>& moment2_max, // UNUSED
244+ const phi::DenseTensor& beta1_pow_in,
245+ const phi::DenseTensor& beta2_pow_in,
246+ const paddle::optional<phi::DenseTensor>& master_param,
247+ const paddle::optional<phi::DenseTensor>& skip_update,
248+ const phi::Scalar& beta1_in,
249+ const phi::Scalar& beta2_in,
250+ const phi::Scalar& epsilon_in,
251+ bool lazy_mode,
252+ int64_t min_row_size_to_use_multithread,
253+ bool multi_precision,
254+ bool use_global_beta_pow,
255+ bool amsgrad, // UNUSED
256+ phi::DenseTensor* param_out,
257+ phi::DenseTensor* moment1_out,
258+ phi::DenseTensor* moment2_out,
259+ phi::DenseTensor* moment2_max_out, // UNUSED
260+ phi::DenseTensor* beta1_pow_out,
261+ phi::DenseTensor* beta2_pow_out,
262+ phi::DenseTensor* master_param_out) {
263+ PADDLE_ENFORCE_NE (
264+ amsgrad,
265+ true ,
266+ phi::errors::Unimplemented (" Operation amsgrad is not supported yet." ));
267+
259268 bool skip_update_ = false ;
260269 if (skip_update.is_initialized ()) {
261270 PADDLE_ENFORCE_EQ (skip_update->numel (),
@@ -358,32 +367,41 @@ void AdamKernel(const Context& dev_ctx,
358367}
359368
360369template <typename T, typename Context>
361- void AdamwKernel (const Context& dev_ctx,
362- const phi::DenseTensor& param,
363- const phi::DenseTensor& grad,
364- const phi::DenseTensor& learning_rate,
365- const phi::DenseTensor& moment1,
366- const phi::DenseTensor& moment2,
367- const phi::DenseTensor& beta1_pow,
368- const phi::DenseTensor& beta2_pow,
369- const paddle::optional<phi::DenseTensor>& master_param,
370- const paddle::optional<phi::DenseTensor>& skip_update,
371- const phi::Scalar& beta1,
372- const phi::Scalar& beta2,
373- const phi::Scalar& epsilon,
374- float lr_ratio,
375- float coeff,
376- bool with_decay,
377- bool lazy_mode,
378- int64_t min_row_size_to_use_multithread,
379- bool multi_precision,
380- bool use_global_beta_pow,
381- phi::DenseTensor* param_out,
382- phi::DenseTensor* moment1_out,
383- phi::DenseTensor* moment2_out,
384- phi::DenseTensor* beta1_pow_out,
385- phi::DenseTensor* beta2_pow_out,
386- phi::DenseTensor* master_param_outs) {
370+ void AdamwKernel (
371+ const Context& dev_ctx,
372+ const phi::DenseTensor& param,
373+ const phi::DenseTensor& grad,
374+ const phi::DenseTensor& learning_rate,
375+ const phi::DenseTensor& moment1,
376+ const phi::DenseTensor& moment2,
377+ const paddle::optional<phi::DenseTensor>& moment2_max, // UNUSED
378+ const phi::DenseTensor& beta1_pow,
379+ const phi::DenseTensor& beta2_pow,
380+ const paddle::optional<phi::DenseTensor>& master_param,
381+ const paddle::optional<phi::DenseTensor>& skip_update,
382+ const phi::Scalar& beta1,
383+ const phi::Scalar& beta2,
384+ const phi::Scalar& epsilon,
385+ float lr_ratio,
386+ float coeff,
387+ bool with_decay,
388+ bool lazy_mode,
389+ int64_t min_row_size_to_use_multithread,
390+ bool multi_precision,
391+ bool use_global_beta_pow,
392+ bool amsgrad, // UNUSED
393+ phi::DenseTensor* param_out,
394+ phi::DenseTensor* moment1_out,
395+ phi::DenseTensor* moment2_out,
396+ phi::DenseTensor* moment2_max_out, // UNUSED
397+ phi::DenseTensor* beta1_pow_out,
398+ phi::DenseTensor* beta2_pow_out,
399+ phi::DenseTensor* master_param_outs) {
400+ PADDLE_ENFORCE_NE (
401+ amsgrad,
402+ true ,
403+ phi::errors::Unimplemented (" Operation amsgrad is not supported yet." ));
404+
387405 using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
388406
389407 bool skip_update_ = false ;
@@ -514,18 +532,19 @@ PD_REGISTER_PLUGIN_KERNEL(adam,
514532 float ,
515533 double ) {
516534 // Skip beta1_pow, beta2_pow, skip_update data transform
517- kernel->InputAt (5 ).SetBackend (phi::Backend::ALL_BACKEND);
518535 kernel->InputAt (6 ).SetBackend (phi::Backend::ALL_BACKEND);
519- kernel->InputAt (8 ).SetBackend (phi::Backend::ALL_BACKEND);
536+ kernel->InputAt (7 ).SetBackend (phi::Backend::ALL_BACKEND);
537+ kernel->InputAt (9 ).SetBackend (phi::Backend::ALL_BACKEND);
520538 if (kernel_key.dtype () == phi::DataType::FLOAT16) {
521539 kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
522540 kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
523541 kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
524542 kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
525543 kernel->OutputAt (5 ).SetDataType (phi::DataType::FLOAT32);
544+ kernel->OutputAt (6 ).SetDataType (phi::DataType::FLOAT32);
526545 }
527- kernel->OutputAt (3 ).SetBackend (phi::Backend::UNDEFINED);
528546 kernel->OutputAt (4 ).SetBackend (phi::Backend::UNDEFINED);
547+ kernel->OutputAt (5 ).SetBackend (phi::Backend::UNDEFINED);
529548}
530549
531550PD_REGISTER_PLUGIN_KERNEL (adamw,
@@ -537,16 +556,17 @@ PD_REGISTER_PLUGIN_KERNEL(adamw,
537556 float ,
538557 double ) {
539558 // Skip beta1_pow, beta2_pow, skip_update data transform
540- kernel->InputAt (5 ).SetBackend (phi::Backend::ALL_BACKEND);
541559 kernel->InputAt (6 ).SetBackend (phi::Backend::ALL_BACKEND);
542- kernel->InputAt (8 ).SetBackend (phi::Backend::ALL_BACKEND);
560+ kernel->InputAt (7 ).SetBackend (phi::Backend::ALL_BACKEND);
561+ kernel->InputAt (9 ).SetBackend (phi::Backend::ALL_BACKEND);
543562 if (kernel_key.dtype () == phi::DataType::FLOAT16) {
544563 kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
545564 kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
546565 kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
547566 kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
548567 kernel->OutputAt (5 ).SetDataType (phi::DataType::FLOAT32);
568+ kernel->OutputAt (6 ).SetDataType (phi::DataType::FLOAT32);
549569 }
550- kernel->OutputAt (3 ).SetBackend (phi::Backend::UNDEFINED);
551570 kernel->OutputAt (4 ).SetBackend (phi::Backend::UNDEFINED);
571+ kernel->OutputAt (5 ).SetBackend (phi::Backend::UNDEFINED);
552572}
0 commit comments