[Fluid] move lars_momentum_op InferShape to phi (#56749)
* move to phi

* fix

* fix type
gouzil authored Sep 5, 2023
1 parent 99ae88f commit 52a0a67
Showing 3 changed files with 110 additions and 112 deletions.
120 changes: 8 additions & 112 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {
@@ -22,117 +24,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(
ctx->HasInputs("Velocity"), "Input", "Velocity", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"),
"Input",
"LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(
ctx->HasOutputs("ParamOut"), "Output", "ParamOut", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"),
"Output",
"VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));

auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");

PADDLE_ENFORCE_EQ(
param_dim.size(),
grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(),
grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(),
velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(),
velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(),
grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(),
grad_dim.size()));

if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"),
"Input",
"MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"),
"Output",
"MasterParamOut",
"LarsMomentumMultiPrecision");
}
for (auto& lr_dim : lr_dims) {
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dim)));
}

for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be phi::DenseTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i],
grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i],
velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}

protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -226,11 +117,16 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(lars_momentum,
LarsMomentumInferShapeFunctor,
PD_INFER_META(phi::LarsMomentumInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
lars_momentum,
ops::LarsMomentumOp,
ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LarsMomentumOpVarTypeInference);
ops::LarsMomentumOpVarTypeInference,
LarsMomentumInferShapeFunctor);
86 changes: 86 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -2577,6 +2577,92 @@ void LambInferMeta(const MetaTensor& param,
}
}

void LarsMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& grad,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const std::vector<float>& lars_weight_decay,
float mu,
float lars_coeff,
float epsilon,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out) {
std::vector<DDim> lr_dims = GetMetaTensorsDim(learning_rate);
std::vector<DDim> grad_dim = GetMetaTensorsDim(grad);
std::vector<DDim> param_dim = GetMetaTensorsDim(param);
std::vector<DDim> velocity_dim = GetMetaTensorsDim(velocity);

PADDLE_ENFORCE_EQ(
param_dim.size(),
grad_dim.size(),
phi::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(),
grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(),
velocity_dim.size(),
phi::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(),
velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decay.size(),
grad_dim.size(),
phi::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decay.size(),
grad_dim.size()));

for (auto& lr_dim : lr_dims) {
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
phi::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dim)));
}

for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(
param_dim[i],
grad_dim[i],
phi::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i],
grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
velocity_dim[i],
phi::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i],
velocity_dim[i]));
}

for (size_t i = 0; i < param_out.size(); i++) {
param_out[i]->set_dims(param_dim[i]);
velocity_out[i]->set_dims(param_dim[i]);
if (master_param != nullptr) {
master_param_out[i]->set_dims(param_dim[i]);
}
}
}

void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
16 changes: 16 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -459,6 +459,22 @@ void LambInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);

void LarsMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& grad,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const std::vector<float>& lars_weight_decay,
float mu,
float lars_coeff,
float epsilon,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out);

void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,