diff --git a/backends/npu/kernels/adam_kernel.cc b/backends/npu/kernels/adam_kernel.cc
index 400a6b2ad7d..72289153d7d 100644
--- a/backends/npu/kernels/adam_kernel.cc
+++ b/backends/npu/kernels/adam_kernel.cc
@@ -233,29 +233,38 @@ void CastFP64toFP32Kernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void AdamKernel(const Context& dev_ctx,
-                const phi::DenseTensor& param,
-                const phi::DenseTensor& grad,
-                const phi::DenseTensor& learning_rate,
-                const phi::DenseTensor& moment1,
-                const phi::DenseTensor& moment2,
-                const phi::DenseTensor& beta1_pow_in,
-                const phi::DenseTensor& beta2_pow_in,
-                const paddle::optional<phi::DenseTensor>& master_param,
-                const paddle::optional<phi::DenseTensor>& skip_update,
-                const phi::Scalar& beta1_in,
-                const phi::Scalar& beta2_in,
-                const phi::Scalar& epsilon_in,
-                bool lazy_mode,
-                int64_t min_row_size_to_use_multithread,
-                bool multi_precision,
-                bool use_global_beta_pow,
-                phi::DenseTensor* param_out,
-                phi::DenseTensor* moment1_out,
-                phi::DenseTensor* moment2_out,
-                phi::DenseTensor* beta1_pow_out,
-                phi::DenseTensor* beta2_pow_out,
-                phi::DenseTensor* master_param_out) {
+void AdamKernel(
+    const Context& dev_ctx,
+    const phi::DenseTensor& param,
+    const phi::DenseTensor& grad,
+    const phi::DenseTensor& learning_rate,
+    const phi::DenseTensor& moment1,
+    const phi::DenseTensor& moment2,
+    const paddle::optional<phi::DenseTensor>& moment2_max,  // UNUSED
+    const phi::DenseTensor& beta1_pow_in,
+    const phi::DenseTensor& beta2_pow_in,
+    const paddle::optional<phi::DenseTensor>& master_param,
+    const paddle::optional<phi::DenseTensor>& skip_update,
+    const phi::Scalar& beta1_in,
+    const phi::Scalar& beta2_in,
+    const phi::Scalar& epsilon_in,
+    bool lazy_mode,
+    int64_t min_row_size_to_use_multithread,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    bool amsgrad,  // UNUSED
+    phi::DenseTensor* param_out,
+    phi::DenseTensor* moment1_out,
+    phi::DenseTensor* moment2_out,
+    phi::DenseTensor* moment2_max_out,  // UNUSED
+    phi::DenseTensor* beta1_pow_out,
+    phi::DenseTensor* beta2_pow_out,
+    phi::DenseTensor* master_param_out) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   bool skip_update_ = false;
   if (skip_update.is_initialized()) {
     PADDLE_ENFORCE_EQ(skip_update->numel(),
@@ -358,32 +367,41 @@ void AdamKernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void AdamwKernel(const Context& dev_ctx,
-                 const phi::DenseTensor& param,
-                 const phi::DenseTensor& grad,
-                 const phi::DenseTensor& learning_rate,
-                 const phi::DenseTensor& moment1,
-                 const phi::DenseTensor& moment2,
-                 const phi::DenseTensor& beta1_pow,
-                 const phi::DenseTensor& beta2_pow,
-                 const paddle::optional<phi::DenseTensor>& master_param,
-                 const paddle::optional<phi::DenseTensor>& skip_update,
-                 const phi::Scalar& beta1,
-                 const phi::Scalar& beta2,
-                 const phi::Scalar& epsilon,
-                 float lr_ratio,
-                 float coeff,
-                 bool with_decay,
-                 bool lazy_mode,
-                 int64_t min_row_size_to_use_multithread,
-                 bool multi_precision,
-                 bool use_global_beta_pow,
-                 phi::DenseTensor* param_out,
-                 phi::DenseTensor* moment1_out,
-                 phi::DenseTensor* moment2_out,
-                 phi::DenseTensor* beta1_pow_out,
-                 phi::DenseTensor* beta2_pow_out,
-                 phi::DenseTensor* master_param_outs) {
+void AdamwKernel(
+    const Context& dev_ctx,
+    const phi::DenseTensor& param,
+    const phi::DenseTensor& grad,
+    const phi::DenseTensor& learning_rate,
+    const phi::DenseTensor& moment1,
+    const phi::DenseTensor& moment2,
+    const paddle::optional<phi::DenseTensor>& moment2_max,  // UNUSED
+    const phi::DenseTensor& beta1_pow,
+    const phi::DenseTensor& beta2_pow,
+    const paddle::optional<phi::DenseTensor>& master_param,
+    const paddle::optional<phi::DenseTensor>& skip_update,
+    const phi::Scalar& beta1,
+    const phi::Scalar& beta2,
+    const phi::Scalar& epsilon,
+    float lr_ratio,
+    float coeff,
+    bool with_decay,
+    bool lazy_mode,
+    int64_t min_row_size_to_use_multithread,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    bool amsgrad,  // UNUSED
+    phi::DenseTensor* param_out,
+    phi::DenseTensor* moment1_out,
+    phi::DenseTensor* moment2_out,
+    phi::DenseTensor* moment2_max_out,  // UNUSED
+    phi::DenseTensor* beta1_pow_out,
+    phi::DenseTensor* beta2_pow_out,
+    phi::DenseTensor* master_param_outs) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
 
   bool skip_update_ = false;
@@ -514,18 +532,19 @@ PD_REGISTER_PLUGIN_KERNEL(adam,
                           float,
                           double) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
-  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
-  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
   }
-  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
   kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
 }
 
 PD_REGISTER_PLUGIN_KERNEL(adamw,
@@ -537,16 +556,17 @@ PD_REGISTER_PLUGIN_KERNEL(adamw,
                           float,
                           double) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
-  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
-  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
   }
-  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
   kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
 }
diff --git a/backends/npu/tests/unittests/test_adam_op_npu.py b/backends/npu/tests/unittests/test_adam_op_npu.py
index 5037e27f007..d919a23662d 100644
--- a/backends/npu/tests/unittests/test_adam_op_npu.py
+++ b/backends/npu/tests/unittests/test_adam_op_npu.py
@@ -54,7 +54,9 @@ def adam_step(inputs, attributes):
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
     lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
     param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
-    return param_out, moment1_out, moment2_out
+
+    moment2_max_out = np.empty_like(moment2_out)
+    return param_out, moment1_out, moment2_out, moment2_max_out
 
 
 class TestAdam(OpTest):
@@ -67,6 +69,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
         # The second moment is positive
         moment2 = np.random.random((102, 105)).astype(self.dtype)
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((102, 105)).astype(self.dtype)
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -80,18 +84,27 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype(self.dtype),
             "Beta1Pow": np.array([beta1_pow]).astype(self.dtype),
             "Beta2Pow": np.array([beta2_pow]).astype(self.dtype),
         }
 
-        self.attrs = {"epsilon": epsilon, "beta1": beta1, "beta2": beta2}
+        self.attrs = {
+            "epsilon": epsilon,
+            "beta1": beta1,
+            "beta2": beta2,
+            "amsgrad": False,
+        }
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, self.attrs)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, self.attrs
+        )
 
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([beta1_pow]).astype("float32") * beta1,
             "Beta2PowOut": np.array([beta2_pow]).astype("float32") * beta2,
@@ -104,7 +117,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 @check_run_big_shape_test()
@@ -119,6 +134,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
         # The second moment is positive
         moment2 = np.random.random(self.shape).astype(self.dtype)
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros(self.shape).astype(self.dtype)
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -132,18 +149,27 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype(self.dtype),
             "Beta1Pow": np.array([beta1_pow]).astype(self.dtype),
             "Beta2Pow": np.array([beta2_pow]).astype(self.dtype),
         }
 
-        self.attrs = {"epsilon": epsilon, "beta1": beta1, "beta2": beta2}
+        self.attrs = {
+            "epsilon": epsilon,
+            "beta1": beta1,
+            "beta2": beta2,
+            "amsgrad": False,
+        }
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, self.attrs)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, self.attrs
+        )
 
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([beta1_pow]).astype("float32") * beta1,
             "Beta2PowOut": np.array([beta2_pow]).astype("float32") * beta2,
@@ -159,7 +185,9 @@ def init_shape(self):
         self.shape = (4000, 8192)
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 @check_run_big_shape_test()
@@ -202,6 +230,9 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
         # The second moment is positive
         moment2 = np.random.random((102, 105)).astype(self.dtype)
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((102, 105)).astype(self.dtype)
+
         learning_rate = 0.004
         beta1 = 0.78
         beta2 = 0.836
@@ -214,6 +245,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype(self.dtype),
             "Beta1Pow": np.array([beta1_pow]).astype(self.dtype),
             "Beta2Pow": np.array([beta2_pow]).astype(self.dtype),
@@ -224,11 +256,14 @@ def setUp(self):
 
         self.attrs = {"epsilon": epsilon}
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, self.attrs)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, self.attrs
+        )
 
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([beta1_pow]).astype("float32") * beta1,
             "Beta2PowOut": np.array([beta2_pow]).astype("float32") * beta2,
@@ -241,7 +276,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 @unittest.skip(reason="disable_ut in Paddle CI")
@@ -255,6 +292,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
         # The second moment is positive
         moment2 = np.random.random((102, 105)).astype(self.dtype)
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((102, 105)).astype(self.dtype)
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -268,6 +307,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype(self.dtype),
             "Beta1Pow": np.array([beta1_pow]).astype(self.dtype),
             "Beta2Pow": np.array([beta2_pow]).astype(self.dtype),
@@ -277,11 +317,12 @@ def setUp(self):
             "SkipUpdate": np.array([True]).astype("bool"),
         }
 
-        self.attrs = {"epsilon": epsilon}
+        self.attrs = {"epsilon": epsilon, "amsgrad": False}
 
         self.outputs = {
             "Moment1Out": moment1,
             "Moment2Out": moment2,
+            "Moment2MaxOut": moment2_max,
             "ParamOut": param,
             "Beta1PowOut": self.inputs["Beta1Pow"],
             "Beta2PowOut": self.inputs["Beta2Pow"],
@@ -294,7 +335,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 class TestAdamOpWithGlobalBetaPow(OpTest):
@@ -307,6 +350,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
         # The second moment is positive
         moment2 = np.random.random((102, 105)).astype(self.dtype)
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((102, 105)).astype(self.dtype)
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -320,6 +365,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype(self.dtype),
             "Beta1Pow": np.array([beta1_pow]).astype(self.dtype),
             "Beta2Pow": np.array([beta2_pow]).astype(self.dtype),
@@ -328,16 +374,19 @@ def setUp(self):
             "EpsilonTensor": np.array([epsilon]).astype(self.dtype),
         }
 
-        attributes = {"epsilon": epsilon}
+        attributes = {"epsilon": epsilon, "amsgrad": False}
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, attributes)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, attributes
+        )
 
-        self.attrs = {"use_global_beta_pow": True}
+        self.attrs = {"use_global_beta_pow": True, "amsgrad": False}
 
         # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty.
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([]),
             "Beta2PowOut": np.array([]),
@@ -350,7 +399,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 def create_test_fp64_class(parent):
@@ -359,7 +410,9 @@ def init_dtype(self):
             self.dtype = np.float64
 
         def test_check_output(self):
-            self.check_output_with_place(self.place, atol=1e-5, rtol=1e-4)
+            self.check_output_with_place(
+                no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5, rtol=1e-4
+            )
 
     cls_name = "{0}_{1}".format(parent.__name__, "Fp64")
     TestAdamOpFp64Case.__name__ = cls_name
@@ -378,7 +431,9 @@ def init_dtype(self):
             self.dtype = np.float16
 
         def test_check_output(self):
-            self.check_output_with_place(self.place, atol=1e-5, rtol=1e-4)
+            self.check_output_with_place(
+                no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5, rtol=1e-4
+            )
 
     cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
     TestAdamOpFp16Case.__name__ = cls_name
diff --git a/backends/npu/tests/unittests/test_adam_op_npu_eager.py b/backends/npu/tests/unittests/test_adam_op_npu_eager.py
index 6af5c256ed7..7970f85d58c 100644
--- a/backends/npu/tests/unittests/test_adam_op_npu_eager.py
+++ b/backends/npu/tests/unittests/test_adam_op_npu_eager.py
@@ -53,7 +53,9 @@ def adam_step(inputs, attributes):
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
     lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
     param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
-    return param_out, moment1_out, moment2_out
+
+    moment2_max_out = np.empty_like(moment2_out)
+    return param_out, moment1_out, moment2_out, moment2_max_out
 
 
 class TestAdam(OpTest):
@@ -74,6 +76,8 @@ def setUp(self):
         moment2 = convert_float_to_uint16(
             np.random.random((102, 105)).astype(self.dtype)
         )
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = convert_float_to_uint16(np.zeros((102, 105)).astype(self.dtype))
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -87,6 +91,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": convert_float_to_uint16(
                 np.array([learning_rate]).astype(self.dtype)
             ),
@@ -98,13 +103,21 @@ def setUp(self):
             ),
         }
 
-        self.attrs = {"epsilon": epsilon, "beta1": beta1, "beta2": beta2}
+        self.attrs = {
+            "epsilon": epsilon,
+            "beta1": beta1,
+            "beta2": beta2,
+            "amsgrad": False,
+        }
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, self.attrs)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, self.attrs
+        )
 
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([beta1_pow]).astype("float32") * beta1,
             "Beta2PowOut": np.array([beta2_pow]).astype("float32") * beta2,
@@ -118,7 +131,9 @@ def init_dtype(self):
 
     @check_soc_version
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=4e-3)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=4e-3
+        )
 
 
 class TestAdamWithEpsilonTensor(OpTest):
@@ -139,6 +154,8 @@ def setUp(self):
         moment2 = convert_float_to_uint16(
             np.random.random((102, 105)).astype(self.dtype)
         )
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = convert_float_to_uint16(np.zeros((102, 105)).astype(self.dtype))
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -152,6 +169,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": convert_float_to_uint16(
                 np.array([learning_rate]).astype(self.dtype)
             ),
@@ -172,13 +190,16 @@ def setUp(self):
             ),
         }
 
-        self.attrs = {"epsilon": epsilon}
+        self.attrs = {"epsilon": epsilon, "amsgrad": False}
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, self.attrs)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, self.attrs
+        )
 
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([beta1_pow]).astype("float32") * beta1,
             "Beta2PowOut": np.array([beta2_pow]).astype("float32") * beta2,
@@ -192,7 +213,9 @@ def init_dtype(self):
 
     @check_soc_version
    def test_check_output(self):
-        self.check_output_with_place(self.place, atol=4e-3)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=4e-3
+        )
 
 
 class TestAdamOpWithSkipUpdate(OpTest):
@@ -213,6 +236,8 @@ def setUp(self):
         moment2 = convert_float_to_uint16(
             np.random.random((102, 105)).astype(self.dtype)
         )
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = convert_float_to_uint16(np.zeros((102, 105)).astype(self.dtype))
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -226,6 +251,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": convert_float_to_uint16(
                 np.array([learning_rate]).astype(self.dtype)
             ),
@@ -247,11 +273,12 @@ def setUp(self):
             "SkipUpdate": np.array([True]).astype("bool"),
         }
 
-        self.attrs = {"epsilon": epsilon}
+        self.attrs = {"epsilon": epsilon, "amsgrad": False}
 
         self.outputs = {
             "Moment1Out": moment1,
             "Moment2Out": moment2,
+            "Moment2MaxOut": moment2_max,
             "ParamOut": param,
             "Beta1PowOut": self.inputs["Beta1Pow"],
             "Beta2PowOut": self.inputs["Beta2Pow"],
@@ -265,7 +292,9 @@ def init_dtype(self):
 
     @check_soc_version
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=4e-3)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=4e-3
+        )
 
 
 class TestAdamOpWithGlobalBetaPow(OpTest):
@@ -286,6 +315,8 @@ def setUp(self):
         moment2 = convert_float_to_uint16(
             np.random.random((102, 105)).astype(self.dtype)
         )
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = convert_float_to_uint16(np.zeros((102, 105)).astype(self.dtype))
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -299,6 +330,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": convert_float_to_uint16(
                 np.array([learning_rate]).astype(self.dtype)
             ),
@@ -319,9 +351,11 @@ def setUp(self):
             ),
         }
 
-        attributes = {"epsilon": epsilon}
+        attributes = {"epsilon": epsilon, "amsgrad": False}
 
-        param_out, moment1_out, moment2_out = adam_step(self.inputs, attributes)
+        param_out, moment1_out, moment2_out, moment2_max_out = adam_step(
+            self.inputs, attributes
+        )
 
         self.attrs = {"use_global_beta_pow": True}
 
@@ -329,6 +363,7 @@ def setUp(self):
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([]),
             "Beta2PowOut": np.array([]),
@@ -342,7 +377,9 @@ def init_dtype(self):
 
     @check_soc_version
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=4e-3)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=4e-3
+        )
 
 
 if __name__ == "__main__":
diff --git a/backends/npu/tests/unittests/test_adamw_op_npu.py b/backends/npu/tests/unittests/test_adamw_op_npu.py
index da78258b2a3..a340b354e32 100644
--- a/backends/npu/tests/unittests/test_adamw_op_npu.py
+++ b/backends/npu/tests/unittests/test_adamw_op_npu.py
@@ -58,7 +58,8 @@ def adamw_step(inputs, attributes):
 
     lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
     param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
-    return param_out, moment1_out, moment2_out
+    moment2_max_out = np.empty_like(moment2_out)
+    return param_out, moment1_out, moment2_out, moment2_max_out
 
 
 class TestAdamW(OpTest):
@@ -71,6 +72,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (105, 102)).astype("float32")
         # The second moment is positive
         moment2 = np.random.random((105, 102)).astype("float32")
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((105, 102)).astype("float32")
 
         learning_rate = 0.5
         beta1 = 0.78
@@ -84,6 +87,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype("float32"),
             "Beta1Pow": np.array([beta1_pow]).astype("float32"),
             "Beta2Pow": np.array([beta2_pow]).astype("float32"),
@@ -95,13 +99,17 @@ def setUp(self):
             "beta2": beta2,
             "coeff": 0.9,
             "with_decay": True,
+            "amsgrad": False,
         }
 
-        param_out, moment1_out, moment2_out = adamw_step(self.inputs, self.attrs)
+        param_out, moment1_out, moment2_out, moment2_max_out = adamw_step(
+            self.inputs, self.attrs
+        )
 
         self.outputs = {
             "Moment1Out": moment1_out,
             "Moment2Out": moment2_out,
+            "Moment2MaxOut": moment2_max_out,
             "ParamOut": param_out,
             "Beta1PowOut": np.array([beta1_pow]).astype("float32") * beta1,
             "Beta2PowOut": np.array([beta2_pow]).astype("float32") * beta2,
@@ -114,7 +122,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 class TestAdamOpWithSkipUpdate(OpTest):
@@ -127,6 +137,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
         # The second moment is positive
         moment2 = np.random.random((102, 105)).astype("float32")
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((102, 105)).astype("float32")
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -140,6 +152,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype("float32"),
             "Beta1Pow": np.array([beta1_pow]).astype("float32"),
             "Beta2Pow": np.array([beta2_pow]).astype("float32"),
@@ -149,11 +162,17 @@ def setUp(self):
             "SkipUpdate": np.array([True]).astype("bool"),
         }
 
-        self.attrs = {"epsilon": epsilon, "coeff": 0.02, "with_decay": True}
+        self.attrs = {
+            "epsilon": epsilon,
+            "coeff": 0.02,
+            "with_decay": True,
+            "amsgrad": False,
+        }
 
         self.outputs = {
             "Moment1Out": moment1,
             "Moment2Out": moment2,
+            "Moment2MaxOut": moment2_max,
             "ParamOut": param,
             "Beta1PowOut": self.inputs["Beta1Pow"],
             "Beta2PowOut": self.inputs["Beta2Pow"],
@@ -166,7 +185,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
    def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 class TestAdamOpWithoutDecay(OpTest):
@@ -179,6 +200,8 @@ def setUp(self):
         moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
         # The second moment is positive
         moment2 = np.random.random((102, 105)).astype("float32")
+        # npu not support amsgrad, `moment2_max` is useless
+        moment2_max = np.zeros((102, 105)).astype("float32")
 
         learning_rate = 0.004
         beta1 = 0.78
@@ -192,6 +215,7 @@ def setUp(self):
             "Grad": grad,
             "Moment1": moment1,
             "Moment2": moment2,
+            "Moment2Max": moment2_max,
             "LearningRate": np.array([learning_rate]).astype("float32"),
             "Beta1Pow": np.array([beta1_pow]).astype("float32"),
             "Beta2Pow": np.array([beta2_pow]).astype("float32"),
@@ -201,11 +225,17 @@ def setUp(self):
             "SkipUpdate": np.array([True]).astype("bool"),
         }
 
-        self.attrs = {"epsilon": epsilon, "coeff": 0.02, "with_decay": False}
+        self.attrs = {
+            "epsilon": epsilon,
+            "coeff": 0.02,
+            "with_decay": False,
+            "amsgrad": False,
+        }
 
         self.outputs = {
             "Moment1Out": moment1,
             "Moment2Out": moment2,
+            "Moment2MaxOut": moment2_max,
             "ParamOut": param,
             "Beta1PowOut": self.inputs["Beta1Pow"],
             "Beta2PowOut": self.inputs["Beta2Pow"],
@@ -218,7 +248,9 @@ def init_dtype(self):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(
+            no_check_set=["Moment2MaxOut"], place=self.place, atol=1e-5
+        )
 
 
 class TestNet(unittest.TestCase):
diff --git a/python/tests/white_list/no_check_set_white_list.py b/python/tests/white_list/no_check_set_white_list.py
index a7db6333131..fe08d8ef240 100644
--- a/python/tests/white_list/no_check_set_white_list.py
+++ b/python/tests/white_list/no_check_set_white_list.py
@@ -37,4 +37,6 @@
     "class_center_sample",
     "einsum",
     "rmsprop",
+    "adam",  # AMSGrad variant no check moment2 max output
+    "adamw",  # AMSGrad variant no check moment2 max output
 ]
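For orientation, the reference update that the patched `adam_step`/`adamw_step` helpers encode can be sketched in plain NumPy as below. This is a minimal illustration, not part of the patch: the function name `adam_reference_step` is ours, and the fourth output is only a placeholder, reflecting the assumption stated in the kernels that `amsgrad=True` is rejected on NPU.

```python
import numpy as np


def adam_reference_step(param, grad, moment1, moment2, lr, beta1, beta2, epsilon, beta1_pow, beta2_pow):
    """Plain (non-AMSGrad) Adam step mirroring adam_step() in the tests above."""
    moment1_out = beta1 * moment1 + (1 - beta1) * grad
    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
    # amsgrad is unsupported on NPU, so the max-of-moment2 slot is never
    # computed; a placeholder keeps the four-output signature aligned with
    # the kernel's new Moment2MaxOut output.
    moment2_max_out = np.empty_like(moment2_out)
    return param_out, moment1_out, moment2_out, moment2_max_out
```

Because this placeholder output is never computed by the kernels, the OpTest classes pass `no_check_set=["Moment2MaxOut"]` to `check_output_with_place`, and `adam`/`adamw` are added to `no_check_set_white_list.py` so the test framework accepts that exclusion.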