From eb0293ef9ad61552032fdc65964684d05af8f79a Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Fri, 7 May 2021 07:53:05 +0000 Subject: [PATCH 01/10] add use_global_beta_pow --- paddle/fluid/operators/optimizers/adam_op.cc | 21 +++++++- paddle/fluid/operators/optimizers/adam_op.cu | 54 +++++++++++--------- paddle/fluid/operators/optimizers/adam_op.h | 22 ++++---- 3 files changed, 64 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index a7886cdd670d4..973da19cf7fd6 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -198,6 +198,13 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) " "Whether to use multi-precision during weight updating.") .SetDefault(false); + // TODO(zhiqiu): We could set Beta1Pow, Beta2Pow, Beta1PowOut, Beta2PowOut + // as dispensable since they are not used when use_global_beta_pow is true. + AddAttr("use_global_beta_pow", + "(bool, default false) " + "Whether to use global beta_pow for whole model instead of " + "creating beta_pow for each parameter.") + .SetDefault(false); AddComment(R"DOC( Adam Optimizer. @@ -246,4 +253,16 @@ REGISTER_OP_VERSION(adam) "EpsilonTensor", "If provided, Adam will use this as epsilon, " "this has a higher priority than attr(epsilon). " - "For better performance in npu kernel. ")); + "For better performance in npu kernel. ")) + .AddCheckpoint( + R"ROC( + Upgrade adam, add 1 attribute [use_global_beta_pow]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_global_beta_pow", + "If true, Adam will use global beta_pow for whole model " + "instead of creating beta_pow for each parameter." + "In that case, the inputs(Beta1Pow, Beta12Pow) and the " + "outputs(Beta1PowOut, Beta2PowOut) will not be used in adam op, " + "and beta_pow will be updated after all adam op in the model.", + false)); diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu index 3d6f0f99a52df..e3f5523694073 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -154,6 +154,7 @@ class AdamOpCUDAKernel : public framework::OpKernel { int64_t min_row_size_to_use_multithread = ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); @@ -254,11 +255,13 @@ class AdamOpCUDAKernel : public framework::OpKernel { lr->data(), grad->data(), param->data(), param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, param->numel()); - // Cpu update - beta1_pow_out->mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow->data()[0]; - beta2_pow_out->mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow->data()[0]; + if (!use_global_beta_pow) { + // Cpu update + beta1_pow_out->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow->data()[0]; + } } else { AdamKernelMEM<<>>( beta1, beta2, epsilon, beta1_pow->data(), @@ -269,14 +272,15 @@ class AdamOpCUDAKernel : public framework::OpKernel { lr->data(), grad->data(), param->data(), param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, param->numel()); - // Update with gpu - UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( - beta1, beta2, 
beta1_pow->data(), - beta2_pow->data(), - beta1_pow_out->mutable_data(ctx.GetPlace()), - beta2_pow_out->mutable_data(ctx.GetPlace())); + if (!use_global_beta_pow) { + // Update with gpu + UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( + beta1, beta2, beta1_pow->data(), + beta2_pow->data(), + beta1_pow_out->mutable_data(ctx.GetPlace()), + beta2_pow_out->mutable_data(ctx.GetPlace())); + } } - } else if (grad_var->IsType()) { auto* grad = ctx.Input("Grad"); if (grad->rows().size() == 0) { @@ -328,11 +332,13 @@ class AdamOpCUDAKernel : public framework::OpKernel { param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, rows, row_numel, grad_merge.rows().size(), lazy_mode, ndim); - // Update with cpu - beta1_pow_out->mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow->data()[0]; - beta2_pow_out->mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow->data()[0]; + if (!use_global_beta_pow) { + // Update with cpu + beta1_pow_out->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow->data()[0]; + } } else { SparseAdamFunctor functor( beta1, beta2, epsilon, beta1_pow->data(), @@ -351,12 +357,14 @@ class AdamOpCUDAKernel : public framework::OpKernel { ctx.device_context()), param->numel()); for_range(functor); - // update beta1 and beta2 - UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( - beta1, beta2, beta1_pow->data(), - beta2_pow->data(), - beta1_pow_out->mutable_data(ctx.GetPlace()), - beta2_pow_out->mutable_data(ctx.GetPlace())); + if (!use_global_beta_pow) { + // update beta1 and beta2 + UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( + beta1, beta2, beta1_pow->data(), + beta2_pow->data(), + beta1_pow_out->mutable_data(ctx.GetPlace()), + beta2_pow_out->mutable_data(ctx.GetPlace())); + } } } else { PADDLE_THROW(platform::errors::InvalidArgument( diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 9667db8055b90..3daae58572815 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -406,6 +406,7 @@ class AdamOpKernel : public framework::OpKernel { int64_t min_row_size_to_use_multithread = ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); @@ -475,11 +476,12 @@ class AdamOpKernel : public framework::OpKernel { lr->data(), grad->data(), param->data(), param_out->mutable_data(ctx.GetPlace())); functor(param->numel()); - beta1_pow_out->mutable_data(ctx.GetPlace())[0] = - beta1 * beta1_pow->data()[0]; - beta2_pow_out->mutable_data(ctx.GetPlace())[0] = - beta2 * beta2_pow->data()[0]; - + if (!use_global_beta_pow) { + beta1_pow_out->mutable_data(ctx.GetPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(ctx.GetPlace())[0] = + beta2 * beta2_pow->data()[0]; + } } else if (grad_var->IsType()) { auto* grad = ctx.Input("Grad"); if (grad->rows().size() == 0) { @@ -523,10 +525,12 @@ class AdamOpKernel : public framework::OpKernel { param_out->mutable_data(ctx.GetPlace()), rows, row_numel, grad_merge.rows().size(), lazy_mode); // update beta1 and beta2 - beta1_pow_out->mutable_data(ctx.GetPlace())[0] = - beta1 * beta1_pow->data()[0]; - beta2_pow_out->mutable_data(ctx.GetPlace())[0] = - beta2 * beta2_pow->data()[0]; + if (!use_global_beta_pow) { + 
beta1_pow_out->mutable_data(ctx.GetPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(ctx.GetPlace())[0] = + beta2 * beta2_pow->data()[0]; + } if (lazy_mode) { VLOG(3) << "run cpu lazy mode"; size_t row_count = grad_merge.rows().size(); From 113898488814206108467564498a80f929daa27d Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Fri, 7 May 2021 07:57:02 +0000 Subject: [PATCH 02/10] add use_global_beta_pow --- paddle/fluid/operators/optimizers/adam_op.cu | 2 +- paddle/fluid/operators/optimizers/adam_op.h | 2 +- .../fluid/operators/optimizers/adam_op_xpu.cc | 81 ++++++++++--------- 3 files changed, 45 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu index e3f5523694073..acaa28bde2020 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -154,7 +154,7 @@ class AdamOpCUDAKernel : public framework::OpKernel { int64_t min_row_size_to_use_multithread = ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); - bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 3daae58572815..d4859895f1610 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -406,7 +406,7 @@ class AdamOpKernel : public framework::OpKernel { int64_t min_row_size_to_use_multithread = ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); - bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc index 09f117374499b..f229cdcd093f8 100644 --- a/paddle/fluid/operators/optimizers/adam_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_xpu.cc @@ -73,6 +73,8 @@ class AdamOpXPUKernel : public framework::OpKernel { "value is:%d.", beta2_pow_out->numel())); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + T beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { auto* beta1_tensor = ctx.Input("Beta1Tensor"); @@ -111,45 +113,48 @@ class AdamOpXPUKernel : public framework::OpKernel { mom1_out.template mutable_data(ctx.GetPlace()), mom2_out.template mutable_data(ctx.GetPlace()), param_out.template mutable_data(ctx.GetPlace()), param.numel()); - - // update in cpu and then copy to xpu - if (beta1_pow.place() == platform::CPUPlace() && - beta2_pow.place() == platform::CPUPlace()) { - const T* beta1_pow_p = beta1_pow.template data(); - beta1_pow_out->mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow_p[0]; - const T* beta2_pow_p = beta2_pow.template data(); - beta2_pow_out->mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow_p[0]; - } else { - T cpu_beta1_pow_out_data; - T cpu_beta2_pow_out_data; - memory::Copy(platform::CPUPlace(), &cpu_beta1_pow_out_data, - BOOST_GET_CONST(platform::XPUPlace, beta1_pow.place()), - beta1_pow_ptr, sizeof(T)); - - cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1; - memory::Copy(platform::CPUPlace(), &cpu_beta2_pow_out_data, - BOOST_GET_CONST(platform::XPUPlace, beta2_pow.place()), 
- beta2_pow_ptr, sizeof(T)); - - cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2; - - T* beta1_pow_out_p = beta1_pow_out->mutable_data(ctx.GetPlace()); - T* beta2_pow_out_p = beta2_pow_out->mutable_data(ctx.GetPlace()); - memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), - beta1_pow_out_p, platform::CPUPlace(), - &cpu_beta1_pow_out_data, sizeof(T)); - memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), - beta2_pow_out_p, platform::CPUPlace(), - &cpu_beta2_pow_out_data, sizeof(T)); + if (!use_global_beta_pow) { + // update in cpu and then copy to xpu + if (beta1_pow.place() == platform::CPUPlace() && + beta2_pow.place() == platform::CPUPlace()) { + const T* beta1_pow_p = beta1_pow.template data(); + beta1_pow_out->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow_p[0]; + const T* beta2_pow_p = beta2_pow.template data(); + beta2_pow_out->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow_p[0]; + + } else { + T cpu_beta1_pow_out_data; + T cpu_beta2_pow_out_data; + + memory::Copy(platform::CPUPlace(), &cpu_beta1_pow_out_data, + BOOST_GET_CONST(platform::XPUPlace, beta1_pow.place()), + beta1_pow_ptr, sizeof(T)); + + cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1; + memory::Copy(platform::CPUPlace(), &cpu_beta2_pow_out_data, + BOOST_GET_CONST(platform::XPUPlace, beta2_pow.place()), + beta2_pow_ptr, sizeof(T)); + + cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2; + + T* beta1_pow_out_p = beta1_pow_out->mutable_data(ctx.GetPlace()); + T* beta2_pow_out_p = beta2_pow_out->mutable_data(ctx.GetPlace()); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), + beta1_pow_out_p, platform::CPUPlace(), + &cpu_beta1_pow_out_data, sizeof(T)); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), + beta2_pow_out_p, platform::CPUPlace(), + &cpu_beta2_pow_out_data, sizeof(T)); + } + + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::External( + "XPU API return wrong value[%d], please check " + "where Baidu Kunlun Card is properly installed.", + r)); } - - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External( - "XPU API return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); } else { PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument( "Variable type not supported by adam_op")); From b587a4de7bdcc6d2fba7a73e347887b3fa4c04b5 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Fri, 7 May 2021 08:03:03 +0000 Subject: [PATCH 03/10] update npu kernel --- paddle/fluid/operators/optimizers/adam_op.cc | 6 +++--- paddle/fluid/operators/optimizers/adam_op_npu.cc | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index 973da19cf7fd6..7536654c5f5cc 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -198,7 +198,7 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) " "Whether to use multi-precision during weight updating.") .SetDefault(false); - // TODO(zhiqiu): We could set Beta1Pow, Beta2Pow, Beta1PowOut, Beta2PowOut + // TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut // as dispensable since they are not used when use_global_beta_pow is true. 
AddAttr("use_global_beta_pow", "(bool, default false) " @@ -262,7 +262,7 @@ REGISTER_OP_VERSION(adam) "use_global_beta_pow", "If true, Adam will use global beta_pow for whole model " "instead of creating beta_pow for each parameter." - "In that case, the inputs(Beta1Pow, Beta12Pow) and the " - "outputs(Beta1PowOut, Beta2PowOut) will not be used in adam op, " + "In that case, the outputs(Beta1PowOut, Beta2PowOut) will not be " + "used in adam op, " "and beta_pow will be updated after all adam op in the model.", false)); diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index 343a670438862..29e1e7872b697 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -59,6 +59,8 @@ class AdamNPUKernel : public framework::OpKernel { auto* beta1_pow_out = ctx.Output("Beta1PowOut"); auto* beta2_pow_out = ctx.Output("Beta2PowOut"); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + param_out->mutable_data(ctx.GetPlace()); mom1_out->mutable_data(ctx.GetPlace()); mom2_out->mutable_data(ctx.GetPlace()); @@ -174,12 +176,14 @@ class AdamNPUKernel : public framework::OpKernel { *mom2, ctx.GetPlace(), ctx.template device_context(), mom2_out); } - auto runner_m1 = - NpuOpRunner("Mul", {*beta1_pow, *beta1_tensor}, {*beta1_pow_out}, {}); - runner_m1.Run(stream); - auto runner_m2 = - NpuOpRunner("Mul", {*beta2_pow, *beta2_tensor}, {*beta2_pow_out}, {}); - runner_m2.Run(stream); + if (!use_global_beta_pow) { + auto runner_m1 = + NpuOpRunner("Mul", {*beta1_pow, *beta1_tensor}, {*beta1_pow_out}, {}); + runner_m1.Run(stream); + auto runner_m2 = + NpuOpRunner("Mul", {*beta2_pow, *beta2_tensor}, {*beta2_pow_out}, {}); + runner_m2.Run(stream); + } } }; From acc1b5540acbd846e37f6b44cd14b49496a913e3 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Fri, 7 May 2021 10:35:19 +0000 Subject: [PATCH 04/10] update python api --- python/paddle/fluid/optimizer.py | 221 ++++++++++++++++++++++++------- 1 file changed, 175 insertions(+), 46 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 41b2843ea33e9..bd348176e3bf7 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -125,6 +125,8 @@ def __init__(self, # to train. These variables are called accumulators. 
# {accum_name : { paramter_name : accumulator_for_parameter, ...}, ...} self._accumulators = defaultdict(lambda: dict()) + # global_accumulator dict, {accum_name : acc_variable, ...} + self._global_accumulators = {} self.helper = None self._opti_name_list = [] self._accumulators_holder = {} @@ -157,6 +159,8 @@ def state_dict(self): for k, v in self._accumulators.items(): for para_name, var_tmp in v.items(): state_dict[var_tmp.name] = var_tmp + for k, v in self._global_accumulators.items(): + state_dict[v.name] = v # global step if use lr decay if isinstance(self._learning_rate, LRScheduler): state_dict["LR_Scheduler"] = self._learning_rate.state_dict() @@ -236,36 +240,42 @@ def set_state_dict(self, state_dict): "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ", type(global_step)) + def _load_state_para(state_dict, param): + var = param.value() + tensor = var.get_tensor() + model_np = np.array(tensor) + load_para = state_dict[param.name] + if isinstance(load_para, Variable): + load_para_np = load_para.numpy() + elif isinstance(load_para, core.VarBase): + load_para_np = load_para.numpy() + elif isinstance(load_para, np.ndarray): + load_para_np = load_para + else: + raise RuntimeError("State dict type {} not supprt".format( + str(type(load_para)))) + + assert model_np.shape == load_para_np.shape, \ + "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format( + item.name, model_np.shape, load_para_np.shape) + + assert model_np.dtype == load_para_np.dtype, \ + "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( + item.name, model_np.dtype, load_para_np.dtype) + + tensor.set(load_para_np, framework._current_expected_place()) + self._accumulators_holder = state_dict for k, v in self._accumulators.items(): for para_name, var_tmp in v.items(): assert var_tmp.name in state_dict, \ "optimizer variable {} not found".format( var_tmp.name ) - var = var_tmp.value() - tensor = var.get_tensor() - model_np = np.array(tensor) - - load_para = state_dict[var_tmp.name] - - if isinstance(load_para, Variable): - load_para_np = load_para.numpy() - elif isinstance(load_para, core.VarBase): - load_para_np = load_para.numpy() - elif isinstance(load_para, np.ndarray): - load_para_np = load_para - else: - raise RuntimeError("State dict type {} not supprt".format( - str(type(load_para)))) - - assert model_np.shape == load_para_np.shape, \ - "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format( - item.name, model_np.shape, load_para_np.shape) - - assert model_np.dtype == load_para_np.dtype, \ - "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( - item.name, model_np.dtype, load_para_np.dtype) + _load_state_para(state_dict, var_tmp) - tensor.set(load_para_np, framework._current_expected_place()) + for k, v in self._global_accumulators.items(): + assert v.name in state_dict, \ + "optimizer variable {} not found".format( v.name ) + _load_state_para(state_dict, v) # [aliases] Compatible with old method names set_dict = set_state_dict @@ -589,6 +599,60 @@ def _add_accumulator(self, self._accumulators[name][param.name] = var return var + def _add_global_accumulator(self, + name, + dtype=None, + fill_value=0.0, + shape=None, + type=None, + device=None): + """Utility function to add a global accumulator for all parameters in the model + + 
Args: + block: the block in which the loss variable is present + name: name of the accumulator + dtype: data type of the accumulator variable + fill_value: value to initialize the accumulator variable + shape: the shape of the accumulator + type: the variable type of the accumulator + device: the target place of the accumulator + """ + if self._name is not None: + name = self._name + "_" + name + if (name in self._global_accumulators): + if framework.in_dygraph_mode(): + return self._global_accumulators[name] + raise Exception("Global accumulator {} already exists".format(name)) + if shape == None: + shape = [1] # most case, global accumulator is of shape [1] + assert isinstance(self.helper, LayerHelper) + + var_name = name + var_name = unique_name.generate(var_name) + self._opti_name_list.append(var_name) + + var = self.helper.create_global_variable( + name=var_name, + persistable=True, + dtype=dtype if dtype else self._dtype, + type=type, + shape=shape, + belong_to_optimizer=True) + if device is None: + device = 'cpu' + with device_guard(device): + self.helper.set_variable_initializer( + var, initializer=Constant(value=float(fill_value))) + + if framework.in_dygraph_mode(): + if len(self._accumulators_holder) > 0: + assert var_name in self._accumulators_holder, \ + "Optimizer set error, {} should in state dict".format( var_name ) + var.set_value(self._accumulators_holder[var_name]) + + self._global_accumulators[var_name] = var + return var + def _get_accumulator(self, name, param): """Utility function to fetch an accumulator for a parameter @@ -597,7 +661,7 @@ def _get_accumulator(self, name, param): param: parameter variable for which accumulator is to be fetched Returns: - accumulator variable for the parameter + accumulator variable """ if self._name is not None: name = self._name + "_" + name @@ -607,6 +671,21 @@ def _get_accumulator(self, name, param): format(name, param.name)) return self._accumulators[name][param.name] + def _get_global_accumulator(self, name): + """Utility function to fetch a global accumulator + + Args: + name: name of the accumulator + + Returns: + accumulator variable + """ + if self._name is not None: + name = self._name + "_" + name + if (name not in self._global_accumulators): + raise Exception("Global accumulator {} does not exist".format(name)) + return self._global_accumulators[name] + def _update_param_device_map(self, parameters_and_grads, target_block): for param_and_grad in parameters_and_grads: if param_and_grad[0].trainable is True: @@ -1915,6 +1994,8 @@ class AdamOptimizer(Optimizer): gradient in current mini-batch, so it will be much more faster. But this mode has different semantics with the original Adam algorithm and may lead to different result. The default value is False. + use_global_beta_pow (bool, optional): Whether to use global beta_pow. If true, Adam will use global beta_pow + for whole model instead of creating beta_pow for each parameter. Default is false. Examples: .. 
code-block:: python @@ -2024,7 +2105,8 @@ def __init__(self, regularization=None, grad_clip=None, name=None, - lazy_mode=False): + lazy_mode=False, + use_global_beta_pow=False): assert learning_rate is not None assert beta1 is not None assert beta2 is not None @@ -2040,6 +2122,7 @@ def __init__(self, self._beta2 = beta2 self._epsilon = epsilon self._lazy_mode = lazy_mode + self._use_global_beta_pow = use_global_beta_pow def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -2048,20 +2131,21 @@ def _create_accumulators(self, block, parameters): for p in parameters: self._add_accumulator(self._moment1_acc_str, p) self._add_accumulator(self._moment2_acc_str, p) - self._add_accumulator( - name=self._beta1_pow_acc_str, - param=p, - fill_value=0.9 if isinstance(self._beta1, Variable) \ - else self._beta1, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') - self._add_accumulator( - name=self._beta2_pow_acc_str, - param=p, - fill_value=0.999 if isinstance(self._beta2, Variable) \ - else self._beta2, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + if not self._use_global_beta_pow: + self._add_accumulator( + name=self._beta1_pow_acc_str, + param=p, + fill_value=0.9 if isinstance(self._beta1, Variable) \ + else self._beta1, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + self._add_accumulator( + name=self._beta2_pow_acc_str, + param=p, + fill_value=0.999 if isinstance(self._beta2, Variable) \ + else self._beta2, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -2070,10 +2154,16 @@ def _append_optimize_op(self, block, param_and_grad): param_and_grad[0]) moment2 = self._get_accumulator(self._moment2_acc_str, param_and_grad[0]) - beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, - param_and_grad[0]) - beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, - param_and_grad[0]) + if self._use_global_beta_pow: + beta1_pow_acc = self._get_global_accumulator( + self._beta1_pow_acc_str) + beta2_pow_acc = self._get_global_accumulator( + self._beta2_pow_acc_str) + else: + beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, + param_and_grad[0]) + beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, + param_and_grad[0]) lr = self._create_param_lr(param_and_grad) # create the adam optimize op @@ -2087,7 +2177,8 @@ def _append_optimize_op(self, block, param_and_grad): beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', - 1000, 'beta1', _beta1, 'beta2', _beta2) + 1000, 'beta1', _beta1, 'beta2', _beta2, 'use_global_beta_pow', + self._use_global_beta_pow) return None @@ -2109,7 +2200,8 @@ def _append_optimize_op(self, block, param_and_grad): } attrs = { "lazy_mode": self._lazy_mode, - "min_row_size_to_use_multithread": 1000 + "min_row_size_to_use_multithread": 1000, + 'use_global_beta_pow': self._use_global_beta_pow } if isinstance(self._beta1, Variable): @@ -2134,6 +2226,43 @@ def _append_optimize_op(self, block, param_and_grad): return adam_op + def _finish_update(self, block, parameters_and_grads): + r"""Update beta1_pow and beta2_pow accumulator + """ + assert isinstance(block, framework.Block) + if self._use_global_beta_pow: + beta1_pow_acc = self._get_global_accumulator( + self._beta1_pow_acc_str) + beta2_pow_acc = 
self._get_global_accumulator( + self._beta2_pow_acc_str) + + with block.program._optimized_guard([]): + inputs = {"X": beta1_pow_acc} + attrs = {} + if isinstance(self._beta1, Variable): + inputs['ScaleTensor'] = self._beta1 + else: + attrs['scale'] = self._beta1 + block.append_op( + type="scale", + inputs=inputs, + outputs={"Out": beta1_pow_acc}, + attrs=attrs, + stop_gradient=True) + + inputs = {"X": beta2_pow_acc} + attrs = {} + if isinstance(self._beta2, Variable): + inputs['ScaleTensor'] = self._beta1 + else: + attrs['scale'] = self._beta2 + block.append_op( + type="scale", + inputs=inputs, + outputs={"Out": beta2_pow_acc}, + attrs=attrs, + stop_gradient=True) + class AdamaxOptimizer(Optimizer): r""" From f09b6cf56f5a623fa49f4d5a5da4736151959c05 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Sat, 8 May 2021 07:35:44 +0000 Subject: [PATCH 05/10] refine code --- python/paddle/fluid/optimizer.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index bd348176e3bf7..362b5d92281ed 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -650,7 +650,7 @@ def _add_global_accumulator(self, "Optimizer set error, {} should in state dict".format( var_name ) var.set_value(self._accumulators_holder[var_name]) - self._global_accumulators[var_name] = var + self._global_accumulators[name] = var return var def _get_accumulator(self, name, param): @@ -2146,6 +2146,19 @@ def _create_accumulators(self, block, parameters): else self._beta2, shape=[1], type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + if self._use_global_beta_pow: + self._add_global_accumulator( + name=self._beta1_pow_acc_str, + fill_value=0.9 if isinstance(self._beta1, Variable) \ + else self._beta1, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + self._add_global_accumulator( + name=self._beta2_pow_acc_str, + fill_value=0.999 if isinstance(self._beta2, Variable) \ + else self._beta2, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -2253,7 +2266,7 @@ def _finish_update(self, block, parameters_and_grads): inputs = {"X": beta2_pow_acc} attrs = {} if isinstance(self._beta2, Variable): - inputs['ScaleTensor'] = self._beta1 + inputs['ScaleTensor'] = self._beta2 else: attrs['scale'] = self._beta2 block.append_op( From d9430dd90183f9f853bfa0c7ea89758da51432e1 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Mon, 10 May 2021 07:28:05 +0000 Subject: [PATCH 06/10] add ut for use_global_beta_pow --- paddle/fluid/operators/optimizers/adam_op.cu | 1 + paddle/fluid/operators/optimizers/adam_op.h | 1 + .../fluid/operators/optimizers/adam_op_npu.cc | 38 ++++++------ .../fluid/operators/optimizers/adam_op_xpu.cc | 1 + .../tests/unittests/npu/test_adam_op_npu.py | 59 +++++++++++++++++++ .../paddle/fluid/tests/unittests/op_test.py | 6 ++ .../fluid/tests/unittests/test_adam_op.py | 59 ++++++++++++++++++- 7 files changed, 146 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu index acaa28bde2020..2ee2a08bf3bc6 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -155,6 +155,7 @@ class AdamOpCUDAKernel : public framework::OpKernel { ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); bool use_global_beta_pow = 
ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index d4859895f1610..bbd4179d84d89 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -407,6 +407,7 @@ class AdamOpKernel : public framework::OpKernel { ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index 29e1e7872b697..5391324ddd954 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -36,7 +36,6 @@ class AdamNPUKernel : public framework::OpKernel { "but the received is %s", ctx.InputNames("Param").front(), framework::ToTypeName(param_var->Type()))); - T epsilon = static_cast(ctx.Attr("epsilon")); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); PADDLE_ENFORCE_EQ(grad_var->IsType(), true, @@ -60,28 +59,12 @@ class AdamNPUKernel : public framework::OpKernel { auto* beta2_pow_out = ctx.Output("Beta2PowOut"); bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; param_out->mutable_data(ctx.GetPlace()); mom1_out->mutable_data(ctx.GetPlace()); mom2_out->mutable_data(ctx.GetPlace()); - // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place. - if (beta1_pow->place() == platform::CPUPlace()) { - T beta1 = *beta1_pow->data(); - // `mutable_data` operation needs to be done after getting data - beta1_pow_out->mutable_data(ctx.GetPlace()); - FillNpuTensorWithConstant(beta1_pow_out, beta1); - } else { - beta1_pow_out->mutable_data(ctx.GetPlace()); - } - if (beta2_pow->place() == platform::CPUPlace()) { - T beta2 = *beta2_pow->data(); - beta2_pow_out->mutable_data(ctx.GetPlace()); - FillNpuTensorWithConstant(beta2_pow_out, beta2); - } else { - beta2_pow_out->mutable_data(ctx.GetPlace()); - } - const Tensor* beta1_tensor = nullptr; const Tensor* beta2_tensor = nullptr; const Tensor* epsilon_tensor = nullptr; @@ -177,6 +160,25 @@ class AdamNPUKernel : public framework::OpKernel { ctx.template device_context(), mom2_out); } if (!use_global_beta_pow) { + // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform + // place. 
+ Tensor beta1_pow_tmp; + Tensor beta2_pow_tmp; + if (beta1_pow->place() == platform::CPUPlace()) { + T beta1 = *beta1_pow->data(); + beta1_pow_tmp.mutable_data(ctx.GetPlace()); + FillNpuTensorWithConstant(&beta1_pow_tmp, beta1); + beta1_pow = &beta1_pow_tmp; + } + if (beta2_pow->place() == platform::CPUPlace()) { + T beta2 = *beta2_pow->data(); + beta2_pow_tmp.mutable_data(ctx.GetPlace()); + FillNpuTensorWithConstant(&beta2_pow_tmp, beta2); + beta2_pow = &beta2_pow_tmp; + } + + beta1_pow_out->mutable_data(ctx.GetPlace()); + beta2_pow_out->mutable_data(ctx.GetPlace()); auto runner_m1 = NpuOpRunner("Mul", {*beta1_pow, *beta1_tensor}, {*beta1_pow_out}, {}); runner_m1.Run(stream); diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc index f229cdcd093f8..0f5706e428e15 100644 --- a/paddle/fluid/operators/optimizers/adam_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_xpu.cc @@ -74,6 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel { beta2_pow_out->numel())); bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; T beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { diff --git a/python/paddle/fluid/tests/unittests/npu/test_adam_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_adam_op_npu.py index ec616070b63ab..a3b4242f39d36 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_adam_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_adam_op_npu.py @@ -134,6 +134,65 @@ def test_check_output(self): self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestAdamOpWithGlobalBetaPow(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + 'Beta1Tensor': np.array([beta1]).astype("float32"), + 'Beta2Tensor': np.array([beta2]).astype("float32"), + 'EpsilonTensor': np.array([epsilon]).astype("float32"), + } + + attributes = {'epsilon': epsilon} + + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, attributes) + + self.attrs = {'use_global_beta_pow': True} + + # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty. 
+ self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out, + 'Beta1PowOut': np.array([]), + 'Beta2PowOut': np.array([]) + } + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False) + + @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") class TestNet(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 25717b7967712..a2e467ad74769 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1087,6 +1087,7 @@ def check_output_with_place(self, dygraph_outs = self._calc_dygraph_output( place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) + for out_name, out_dup in Operator.get_op_outputs(self.op_type): if out_name not in self.outputs: continue @@ -1177,6 +1178,11 @@ def find_actual(target_name, fetch_list): actual_t = convert_uint16_to_float(actual_t) atol = 0.03 + # NOTE(zhiqiu): np.allclose([], [1.]) returns True + # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng + if expect_t.size == 0: + self.assertTrue(actual_t.size == 0) + self.assertTrue( np.allclose( actual_t, expect_t, atol=atol, equal_nan=equal_nan), diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index cb646ef0b9321..279c888184c1a 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -404,7 +404,7 @@ def test_check_output(self): class TestAdamOpBetaEpsilonVariable(OpTest): def setUp(self): - '''Test Adam Op with beta as Variable + '''Test Adam Op with beta/epsilon as Variable ''' self.op_type = "adam" param = np.random.uniform(-1, 1, (102, 105)).astype("float32") @@ -450,6 +450,57 @@ def test_check_output(self): self.check_output() +class TestAdamOpWithGlobalBetaPow(OpTest): + def setUp(self): + '''Test Adam Op with global_beta_pow + ''' + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + beta1 = 0.85 + beta2 = 0.95 + + learning_rate = 0.001 + epsilon = 1e-8 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + "Beta1Tensor": np.array([beta1]).astype("float32"), + "Beta2Tensor": np.array([beta2]).astype("float32"), + "EpsilonTensor": np.array([epsilon]).astype("float32"), + } + + attributes = {'epsilon': epsilon} + + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, attributes) + + self.attrs = {'use_global_beta_pow': True} + + # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty. 
+ self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out, + 'Beta1PowOut': np.array([]), + 'Beta2PowOut': np.array([]) + } + + def test_check_output(self): + self.check_output() + + class TestAdamOpV2(unittest.TestCase): def test_adam_op(self): place = fluid.CPUPlace() @@ -493,6 +544,7 @@ def test_adam_op_dygraph(self): out.backward() adam.step() adam.clear_gradients() + paddle.enable_static() def test_adam_op_with_state_dict(self): @@ -523,6 +575,7 @@ def test_adam_op_with_state_dict(self): params = adam.get_opti_var_name_list() assert (params is not None) + paddle.enable_static() def test_adam_with_grad_clip(self): paddle.disable_static() @@ -536,6 +589,7 @@ def test_adam_with_grad_clip(self): out.backward() adam.step() adam.clear_gradients() + paddle.enable_static() def test_adam_op_with_set_lr(self): paddle.disable_static() @@ -550,6 +604,7 @@ def test_adam_op_with_set_lr(self): lr_var = paddle.fluid.layers.create_global_var( shape=[1], value=lr, dtype='float32') adam.set_lr(lr_var) + paddle.enable_static() def test_adam_op_invalid_input(self): paddle.disable_static() @@ -563,6 +618,7 @@ def test_adam_op_invalid_input(self): with self.assertRaises(ValueError): adam = paddle.optimizer.Adam( 0.1, epsilon=-1, parameters=linear.parameters()) + paddle.enable_static() def test_adam_op_with_sparse_input_and_weight_decay(self): @@ -577,6 +633,7 @@ def test_adam_op_with_sparse_input_and_weight_decay(self): out = emb(x) out.backward() adam.step() + paddle.enable_static() class TestNetWithEpsilonTensor(unittest.TestCase): From df81ac634d0f59e8bf8dfcef2080f781b4f7ee18 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Tue, 11 May 2021 11:13:01 +0000 Subject: [PATCH 07/10] fix npu kernel --- .../fluid/operators/optimizers/adam_op_npu.cc | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index 5391324ddd954..e5fe7f20a42e0 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -49,8 +49,8 @@ class AdamNPUKernel : public framework::OpKernel { auto* mom2 = ctx.Input("Moment2"); auto* lr = ctx.Input("LearningRate"); - auto* beta1_pow = ctx.Input("Beta1Pow"); - auto* beta2_pow = ctx.Input("Beta2Pow"); + auto* beta1_pow = ctx.Input("Beta1Pow"); + auto* beta2_pow = ctx.Input("Beta2Pow"); auto* param_out = ctx.Output("ParamOut"); auto* mom1_out = ctx.Output("Moment1Out"); @@ -65,6 +65,23 @@ class AdamNPUKernel : public framework::OpKernel { mom1_out->mutable_data(ctx.GetPlace()); mom2_out->mutable_data(ctx.GetPlace()); + // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform + // place. 
+ LoDTensor beta1_pow_tmp; + LoDTensor beta2_pow_tmp; + if (beta1_pow->place() == platform::CPUPlace()) { + T beta1 = *beta1_pow->data(); + beta1_pow_tmp.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&beta1_pow_tmp, beta1); + beta1_pow = &beta1_pow_tmp; + } + if (beta2_pow->place() == platform::CPUPlace()) { + T beta2 = *beta2_pow->data(); + beta2_pow_tmp.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&beta2_pow_tmp, beta2); + beta2_pow = &beta2_pow_tmp; + } + const Tensor* beta1_tensor = nullptr; const Tensor* beta2_tensor = nullptr; const Tensor* epsilon_tensor = nullptr; @@ -160,23 +177,6 @@ class AdamNPUKernel : public framework::OpKernel { ctx.template device_context(), mom2_out); } if (!use_global_beta_pow) { - // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform - // place. - Tensor beta1_pow_tmp; - Tensor beta2_pow_tmp; - if (beta1_pow->place() == platform::CPUPlace()) { - T beta1 = *beta1_pow->data(); - beta1_pow_tmp.mutable_data(ctx.GetPlace()); - FillNpuTensorWithConstant(&beta1_pow_tmp, beta1); - beta1_pow = &beta1_pow_tmp; - } - if (beta2_pow->place() == platform::CPUPlace()) { - T beta2 = *beta2_pow->data(); - beta2_pow_tmp.mutable_data(ctx.GetPlace()); - FillNpuTensorWithConstant(&beta2_pow_tmp, beta2); - beta2_pow = &beta2_pow_tmp; - } - beta1_pow_out->mutable_data(ctx.GetPlace()); beta2_pow_out->mutable_data(ctx.GetPlace()); auto runner_m1 = From 84273c204b78e3440327f0de1c52bd24d6ea1559 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Tue, 11 May 2021 11:39:43 +0000 Subject: [PATCH 08/10] add ut for api --- .../fluid/tests/unittests/test_adam_op.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 279c888184c1a..4ed1643ac2c69 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -637,7 +637,11 @@ def test_adam_op_with_sparse_input_and_weight_decay(self): class TestNetWithEpsilonTensor(unittest.TestCase): - def _test(self, place, use_tensor=True, use_fluid_api=True): + def _test(self, + place, + use_tensor=True, + use_fluid_api=True, + use_global_beta_pow=False): paddle.enable_static() main_prog = paddle.static.Program() startup_prog = paddle.static.Program() @@ -690,7 +694,8 @@ def _test(self, place, use_tensor=True, use_fluid_api=True): learning_rate=0.01, beta1=beta1, beta2=beta2, - epsilon=epsilon) + epsilon=epsilon, + use_global_beta_pow=use_global_beta_pow) else: adam = paddle.optimizer.Adam( learning_rate=0.01, @@ -703,7 +708,8 @@ def _test(self, place, use_tensor=True, use_fluid_api=True): learning_rate=0.01, beta1=beta1_init, beta2=beta2_init, - epsilon=epsilon_init) + epsilon=epsilon_init, + use_global_beta_pow=use_global_beta_pow) else: adam = fluid.optimizer.Adam( learning_rate=0.01, @@ -737,9 +743,11 @@ def _test_with_place(self, place): for use_tensor in [True, False]: for use_fluid_api in [True, False]: - pred, loss = self._test(place, use_tensor, use_fluid_api) - preds.append(pred) - losses.append(loss) + for use_global_beta_pow in [True, False]: + pred, loss = self._test(place, use_tensor, use_fluid_api, + use_global_beta_pow) + preds.append(pred) + losses.append(loss) for pred in preds: self.assertTrue(np.allclose(pred, preds[0])) for loss in losses: From b6d5a698ee747403976eebda9558b78c65f61010 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Wed, 12 May 2021 05:57:10 +0000 Subject: [PATCH 09/10] 
add ut for exception --- .../fluid/tests/unittests/test_adam_op.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 4ed1643ac2c69..961593c04ea2a 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -709,7 +709,8 @@ def _test(self, beta1=beta1_init, beta2=beta2_init, epsilon=epsilon_init, - use_global_beta_pow=use_global_beta_pow) + use_global_beta_pow=use_global_beta_pow, + name='a') else: adam = fluid.optimizer.Adam( learning_rate=0.01, @@ -759,6 +760,33 @@ def test_adam_api(self): if core.is_compiled_with_cuda(): self._test_with_place(paddle.CUDAPlace(0)) + def test_adam_exception(self): + paddle.enable_static() + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data(name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.pow(sum, 2.0) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + adam = fluid.optimizer.Adam(use_global_beta_pow=True) + adam.minimize(loss) + self.assertRaises(Exception, adam._get_global_accumulator, 'tmp') + adam._add_global_accumulator( + 'tmp', type=core.VarDesc.VarType.LOD_TENSOR) + adam._get_global_accumulator('tmp') + self.assertRaises( + Exception, + adam._add_global_accumulator, + adam._beta1_pow_acc_str, + type=core.VarDesc.VarType.LOD_TENSOR) + paddle.disable_static() + if __name__ == "__main__": unittest.main() From 8b12bdca8d0ecad02d0a2fa934bcbe018f1900b7 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Wed, 12 May 2021 08:22:29 +0000 Subject: [PATCH 10/10] add ut for save/load --- .../fluid/tests/unittests/test_adam_op.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 961593c04ea2a..1e316c3383ea7 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -787,6 +787,28 @@ def test_adam_exception(self): type=core.VarDesc.VarType.LOD_TENSOR) paddle.disable_static() + def test_adam_save_load(self): + paddle.disable_static() + a = paddle.rand([4, 10]) + linear = paddle.nn.Linear(10, 10) + b = linear(a) + state_dict = linear.state_dict() + fluid.save_dygraph(state_dict, "paddle_dy") + + scheduler = paddle.optimizer.lr.NoamDecay( + d_model=0.01, warmup_steps=100, verbose=True) + adam = paddle.fluid.optimizer.Adam( + learning_rate=scheduler, + parameter_list=linear.parameters(), + use_global_beta_pow=True) + adam.minimize(b) + state_dict = adam.state_dict() + fluid.save_dygraph(state_dict, "paddle_dy") + para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy") + adam.set_state_dict(opti_state_dict) + + paddle.enable_static() + if __name__ == "__main__": unittest.main()
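
Note for readers of the series: below is a minimal end-to-end sketch of the flag this series adds, adapted from the unit tests in the later patches (TestNetWithEpsilonTensor and test_adam_exception). It is illustrative only; the network sizes, feed data, and CPUPlace are arbitrary choices, and it assumes a Paddle build that already includes these patches.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
        label = paddle.static.data(name="label", shape=[32, 1], dtype='int64')
        fc_1 = fluid.layers.fc(input=a, size=128)
        prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
        cost = fluid.layers.cross_entropy(input=prediction, label=label)
        loss = fluid.layers.reduce_mean(cost)

        # With use_global_beta_pow=True the optimizer creates one shared pair of
        # beta1_pow/beta2_pow accumulators for the whole model (via
        # _add_global_accumulator) and updates them once per step with the scale
        # ops appended in _finish_update(), instead of letting every adam op
        # write its own Beta1PowOut/Beta2PowOut.
        adam = fluid.optimizer.Adam(
            learning_rate=0.01,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-8,
            use_global_beta_pow=True)
        adam.minimize(loss)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    a_np = np.random.random(size=(32, 32)).astype('float32')
    label_np = np.random.randint(2, size=(32, 1)).astype('int64')
    loss_val, = exe.run(main_prog,
                        feed={"a": a_np, "label": label_np},
                        fetch_list=[loss])

The numerical result is expected to match the default (per-parameter beta_pow) path, which is exactly what TestNetWithEpsilonTensor asserts by comparing predictions and losses across use_global_beta_pow in [True, False].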