[Fluid] move lars_momentum_op InferShape to phi #56749

Merged: 4 commits, Sep 5, 2023
120 changes: 8 additions & 112 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {
@@ -22,117 +24,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(
ctx->HasInputs("Velocity"), "Input", "Velocity", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"),
"Input",
"LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(
ctx->HasOutputs("ParamOut"), "Output", "ParamOut", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"),
"Output",
"VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));

auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");

PADDLE_ENFORCE_EQ(
param_dim.size(),
grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(),
grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(),
velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(),
velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(),
grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(),
grad_dim.size()));

if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"),
"Input",
"MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"),
"Output",
"MasterParamOut",
"LarsMomentumMultiPrecision");
}
for (auto& lr_dim : lr_dims) {
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dim)));
}

for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be phi::DenseTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i],
grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i],
velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}

protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -226,11 +117,16 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(lars_momentum,
LarsMomentumInferShapeFunctor,
PD_INFER_META(phi::LarsMomentumInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
lars_momentum,
ops::LarsMomentumOp,
ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LarsMomentumOpVarTypeInference);
ops::LarsMomentumOpVarTypeInference,
LarsMomentumInferShapeFunctor);
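
For reviewers unfamiliar with this migration pattern: the hand-written InferShape override above is replaced by a functor generated from the phi InferMeta function. Below is a minimal sketch of the general wiring, assuming a hypothetical op my_op and phi::MyInferMeta; all names in the sketch are placeholders, not code from this PR.

// Sketch only (assumed pattern, not part of this diff): how an operator's
// static-graph shape inference is delegated to a phi InferMeta function.
// my_op, MyOp, MyOpMaker and phi::MyInferMeta are hypothetical placeholders.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"

namespace ops = paddle::operators;

// Generates MyOpInferShapeFunctor, which adapts the InferShapeContext to the
// MetaTensor interface and forwards the call to phi::MyInferMeta.
DECLARE_INFER_SHAPE_FUNCTOR(my_op,
                            MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyInferMeta));

// Registering the functor alongside the op removes the need for an
// InferShape override on the OperatorWithKernel subclass.
REGISTER_OPERATOR(my_op,
                  ops::MyOp,
                  ops::MyOpMaker,
                  MyOpInferShapeFunctor);
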
86 changes: 86 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -2566,6 +2566,92 @@ void LambInferMeta(const MetaTensor& param,
}
}

void LarsMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& grad,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const std::vector<float>& lars_weight_decay,
float mu,
float lars_coeff,
float epsilon,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out) {
std::vector<DDim> lr_dims = GetMetaTensorsDim(learning_rate);
std::vector<DDim> grad_dim = GetMetaTensorsDim(grad);
std::vector<DDim> param_dim = GetMetaTensorsDim(param);
std::vector<DDim> velocity_dim = GetMetaTensorsDim(velocity);

PADDLE_ENFORCE_EQ(
param_dim.size(),
grad_dim.size(),
phi::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(),
grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(),
velocity_dim.size(),
phi::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(),
velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decay.size(),
grad_dim.size(),
phi::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decay.size(),
grad_dim.size()));

for (auto& lr_dim : lr_dims) {
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
phi::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dim)));
}

for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(
param_dim[i],
grad_dim[i],
phi::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i],
grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
velocity_dim[i],
phi::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i],
velocity_dim[i]));
}

for (size_t i = 0; i < param_out.size(); i++) {
param_out[i]->set_dims(param_dim[i]);
velocity_out[i]->set_dims(param_dim[i]);
if (master_param != nullptr) {
master_param_out[i]->set_dims(param_dim[i]);
}
}
}

void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
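
As context for the new function above: a phi InferMeta routine receives MetaTensor views of the inputs, validates shapes with PADDLE_ENFORCE_* checks, and writes output shapes through MetaTensor::set_dims, so the same code can serve shape inference in both static and dynamic graphs. A minimal sketch of that contract follows, using a hypothetical MyScaleInferMeta that is not part of this PR.

// Sketch only: the basic InferMeta contract that LarsMomentumInferMeta follows.
// MyScaleInferMeta is a hypothetical example, not code from this diff.
void MyScaleInferMeta(const MetaTensor& x,
                      const MetaTensor& scale,
                      MetaTensor* out) {
  // The scale input must hold exactly one element.
  PADDLE_ENFORCE_EQ(phi::product(scale.dims()),
                    1,
                    phi::errors::InvalidArgument(
                        "Scale should be a scalar, but its dims are [%s].",
                        scale.dims()));
  // The output keeps the shape and dtype of the main input.
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
}
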
16 changes: 16 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -459,6 +459,22 @@ void LambInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);

void LarsMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& grad,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const std::vector<float>& lars_weight_decay,
float mu,
float lars_coeff,
float epsilon,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out);

void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,