From d4c4336c8f1f6e304cc1c866317d6ecd8207c453 Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Tue, 29 Aug 2023 13:21:23 +0800
Subject: [PATCH 1/2] [Fluid] move lars_momentum_xpu to phi

---
 .../optimizers/lars_momentum_op_xpu.cc       | 125 ------------------
 .../phi/kernels/xpu/lars_momentum_kernel.cc  | 113 ++++++++++++++++
 2 files changed, 113 insertions(+), 125 deletions(-)
 delete mode 100644 paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc
 create mode 100644 paddle/phi/kernels/xpu/lars_momentum_kernel.cc

diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc
deleted file mode 100644
index 266ce2e57ca798..00000000000000
--- a/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/backends/xpu/enforce_xpu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T, typename DeviceContext>
-class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    bool multi_precision = ctx.Attr<bool>("multi_precision");
-    auto param_out = ctx.MultiOutput<phi::DenseTensor>("ParamOut");
-    auto velocity_out = ctx.MultiOutput<phi::DenseTensor>("VelocityOut");
-    auto param = ctx.MultiInput<phi::DenseTensor>("Param");
-    auto velocity = ctx.MultiInput<phi::DenseTensor>("Velocity");
-    auto learning_rate = ctx.MultiInput<phi::DenseTensor>("LearningRate");
-    auto grad = ctx.MultiInput<phi::DenseTensor>("Grad");
-    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
-    auto master_param = ctx.MultiInput<phi::DenseTensor>("MasterParam");
-    auto master_param_out = ctx.MultiOutput<phi::DenseTensor>("MasterParamOut");
-    float mu = static_cast<float>(ctx.Attr<float>("mu"));
-    float lars_coeff = ctx.Attr<float>("lars_coeff");
-    float epsilon = ctx.Attr<float>("epsilon");
-    float rescale_grad = ctx.Attr<float>("rescale_grad");
-
-    std::vector<XPUType*> param_list;
-    std::vector<XPUType*> grad_list;
-    std::vector<XPUType*> param_out_list;
-    std::vector<float*> velocity_list;
-    std::vector<float*> velocity_out_list;
-    std::vector<float*> lrs;
-    std::vector<int> param_sizes;
-
-    std::vector<float*> master_param_list;
-    std::vector<float*> master_param_out_list;
-    int op_num = param.size();
-    for (int i = 0; i < op_num; ++i) {
-      param_list.push_back(
-          reinterpret_cast<XPUType*>(const_cast<T*>((param[i]->data<T>()))));
-      grad_list.push_back(
-          reinterpret_cast<XPUType*>(const_cast<T*>(grad[i]->data<T>())));
-      param_out_list.push_back(reinterpret_cast<XPUType*>(
-          param_out[i]->mutable_data<T>(ctx.GetPlace())));
-      velocity_list.push_back(const_cast<float*>(velocity[i]->data<float>()));
-      velocity_out_list.push_back(
-          velocity_out[i]->mutable_data<float>(ctx.GetPlace()));
-      lrs.push_back(const_cast<float*>(learning_rate[i]->data<float>()));
-      param_sizes.push_back(param[i]->numel());
-
-      PADDLE_ENFORCE_EQ(
-          param_list[i],
-          param_out_list[i],
-          platform::errors::InvalidArgument(
-              "Input(Param) and Output(ParamOut) must be the same Tensors."));
-      PADDLE_ENFORCE_EQ(velocity_list[i],
-                        velocity_out_list[i],
-                        platform::errors::InvalidArgument(
-                            "Input(Velocity) and Output(VelocityOut) must be "
-                            "the same Tensors."));
-      if (multi_precision) {
-        master_param_list.push_back(
-            const_cast<float*>(master_param[i]->data<float>()));
-        master_param_out_list.push_back(
-            master_param_out[i]->mutable_data<float>(ctx.GetPlace()));
-        PADDLE_ENFORCE_EQ(master_param_list[i],
-                          master_param_out_list[i],
-                          platform::errors::InvalidArgument(
-                              "Input(MasterParam) and Output(MasterParamOut) "
-                              "must be the same Tensors."));
-      } else {
-        master_param_list.push_back(nullptr);
-        master_param_out_list.push_back(nullptr);
-      }
-    }
-
-    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
-    int r = lars_momentum(dev_ctx.x_context(),
-                          param_list,
-                          grad_list,
-                          velocity_list,
-                          lrs,
-                          master_param_list,
-                          param_out_list,
-                          velocity_out_list,
-                          master_param_out_list,
-                          weight_decay_arr,
-                          param_sizes,
-                          mu,
-                          lars_coeff,
-                          epsilon,
-                          rescale_grad);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "lars_momentum");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-PD_REGISTER_STRUCT_KERNEL(lars_momentum,
-                          XPU,
-                          ALL_LAYOUT,
-                          ops::LarsMomentumOpXPUKernel,
-                          float,
-                          plat::float16) {}
-#endif
diff --git a/paddle/phi/kernels/xpu/lars_momentum_kernel.cc b/paddle/phi/kernels/xpu/lars_momentum_kernel.cc
new file mode 100644
index 00000000000000..8821a27f6137f4
--- /dev/null
+++ b/paddle/phi/kernels/xpu/lars_momentum_kernel.cc
@@ -0,0 +1,113 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/lars_momentum_kernel.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void LarsMomentumKernel( + const Context& dev_ctx, + const std::vector& param, + const std::vector& velocity, + const std::vector& learning_rate, + const std::vector& grad, + const paddle::optional>& master_param, + const std::vector& weight_decay_arr, + float mu, + float lars_coeff, + float epsilon, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out) { + using XPUType = typename XPUTypeTrait::Type; + std::vector param_list; + std::vector grad_list; + std::vector param_out_list; + std::vector velocity_list; + std::vector velocity_out_list; + std::vector lrs; + std::vector param_sizes; + + std::vector master_param_list; + std::vector master_param_out_list; + int op_num = param.size(); + for (int i = 0; i < op_num; ++i) { + param_list.push_back( + reinterpret_cast(const_cast((param[i]->data())))); + grad_list.push_back( + reinterpret_cast(const_cast(grad[i]->data()))); + param_out_list.push_back( + reinterpret_cast(dev_ctx.template Alloc(param_out[i]))); + velocity_list.push_back(const_cast(velocity[i]->data())); + velocity_out_list.push_back(dev_ctx.template Alloc(velocity_out[i])); + lrs.push_back(const_cast(learning_rate[i]->data())); + param_sizes.push_back(param[i]->numel()); + + PADDLE_ENFORCE_EQ( + param_list[i], + param_out_list[i], + phi::errors::InvalidArgument( + "Input(Param) and Output(ParamOut) must be the same Tensors.")); + PADDLE_ENFORCE_EQ(velocity_list[i], + velocity_out_list[i], + phi::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); + if (multi_precision) { + master_param_list.push_back( + const_cast(master_param.get()[i]->data())); + master_param_out_list.push_back( + dev_ctx.template Alloc(master_param_out[i])); + PADDLE_ENFORCE_EQ(master_param_list[i], + master_param_out_list[i], + phi::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); + } else { + master_param_list.push_back(nullptr); + master_param_out_list.push_back(nullptr); + } + } + + int r = lars_momentum(dev_ctx.x_context(), + param_list, + grad_list, + velocity_list, + lrs, + master_param_list, + param_out_list, + velocity_out_list, + master_param_out_list, + weight_decay_arr, + param_sizes, + mu, + lars_coeff, + epsilon, + rescale_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "lars_momentum"); +} +} // namespace phi + +PD_REGISTER_KERNEL(lars_momentum, + XPU, + ALL_LAYOUT, + phi::LarsMomentumKernel, + float, + phi::dtype::float16) {} From 2f8034d1a1a2e11cef174d02e158ef365771c633 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 29 Aug 2023 18:34:30 +0800 Subject: [PATCH 2/2] Empty-Commit;test=kunlun;