[Infer Symbolic Shape No.161,162][BUAA] norm, p_norm #67136

Merged · 8 commits · Aug 13, 2024
@@ -1158,12 +1158,32 @@ bool MeanAllOpInferSymbolicShape(
// return true;
// }

// bool NormOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// {
// // pass
// return true;
// }
bool NormOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
auto x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_shape = x_shape_or_data.shape();

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});

int axis = op->attribute<pir::Int32Attribute>("axis").data();
bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();

if (!is_test) {
@gongshaotian (Contributor) commented on Aug 8, 2024:

🎉 Congratulations, you have run into an undocumented parameter. It is not in the API docs, but it is declared in the yaml file; it is generally used to distinguish training from inference, and is_test being enabled means the op is running inference. I will share this with everyone at next week's meeting.

@MufanColin (Contributor, Author) replied:

Got it.

if (axis < 0) axis += x_shape.size();

auto norm_shape = x_shape;
Reviewer (Contributor) commented:

Why make a copy here?

@MufanColin (Contributor, Author) replied on Aug 9, 2024:

I looked into it afterwards; it seems to be because x_shape is marked const, so a copy is needed before it can be modified. Hongqing (the mentor) said this can stay as it is and does not need to be changed.

norm_shape[axis] = symbol::DimExpr(1);
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(norm_shape)});
}

return true;
}
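For reference, the shape rule implemented above can be restated outside the pir/symbol machinery: result(0) keeps the input shape, and the auxiliary result(1), which is only set when is_test is false, equals the input shape with the normalized axis collapsed to 1. The sketch below is illustrative only and is not part of the PR; the helper name NormAuxShape is made up for this example.

// Illustrative sketch only, not part of the PR: the Norm shape rule
// restated on plain integer shapes instead of symbol::DimExpr.
#include <cassert>
#include <cstdint>
#include <vector>

// Shape of the auxiliary norm output: same as the input, except the
// normalized axis is reduced to 1.
std::vector<int64_t> NormAuxShape(std::vector<int64_t> x_shape, int axis) {
  if (axis < 0) axis += static_cast<int>(x_shape.size());
  assert(axis >= 0 && axis < static_cast<int>(x_shape.size()));
  x_shape[axis] = 1;
  return x_shape;
}

int main() {
  // x of shape [2, 3, 4], axis = 1: result(0) keeps [2, 3, 4],
  // result(1) becomes [2, 1, 4].
  assert((NormAuxShape({2, 3, 4}, 1) == std::vector<int64_t>{2, 1, 4}));
  return 0;
}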

bool NonzeroOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
@@ -1198,12 +1218,70 @@ bool NumelOpInferSymbolicShape(pir::Operation *op,

return true;
}
// bool P_NormOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool PNormOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_shape = x_shape_or_data.shape();
int x_rank = x_shape.size();

int axis = op->attribute<pir::Int32Attribute>("axis").data();
bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data();
bool asvector = op->attribute<pir::BoolAttribute>("asvector").data();

if (axis < 0) {
axis += x_rank;
}

bool axis_valid = (axis >= 0) && (axis < x_rank);

PADDLE_ENFORCE_EQ(
axis_valid,
true,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is the rank of "
"Input(X). "
"But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_shape));

std::vector<symbol::DimExpr> out_shape;

if (asvector) {
if (keepdim) {
for (int i = 0; i < x_rank; ++i) {
out_shape.emplace_back(symbol::DimExpr(1));
}
} else {
out_shape = {};
}
} else {
if (keepdim) {
for (int i = 0; i < x_rank; ++i) {
if (i == axis) {
out_shape.emplace_back(symbol::DimExpr(1));
} else {
out_shape.emplace_back(x_shape[i]);
}
}
} else {
for (int i = 0; i < x_rank; ++i) {
if (i != axis) {
out_shape.emplace_back(x_shape[i]);
}
}
}
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

return true;
}
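Again for reference only: the p_norm output-shape rule above can be stated compactly without the symbolic-shape infrastructure. asvector treats the input as one flattened vector, so every dimension is reduced; otherwise only axis is reduced, and keepdim decides whether reduced dimensions survive as 1. The helper below is a hypothetical illustration, not code from the PR.

// Illustrative sketch only, not part of the PR: the PNorm output-shape
// rule on plain integer shapes.
#include <cstdint>
#include <vector>

std::vector<int64_t> PNormOutShape(const std::vector<int64_t>& x_shape,
                                   int axis, bool keepdim, bool asvector) {
  const int rank = static_cast<int>(x_shape.size());
  if (axis < 0) axis += rank;
  std::vector<int64_t> out;
  if (asvector) {
    if (keepdim) out.assign(rank, 1);  // every dim reduced, kept as 1
    // otherwise the result is a scalar: empty shape
  } else {
    for (int i = 0; i < rank; ++i) {
      if (i == axis) {
        if (keepdim) out.push_back(1);  // reduced axis kept as 1
      } else {
        out.push_back(x_shape[i]);
      }
    }
  }
  return out;
}

// For x_shape = {2, 3, 4}, axis = 1:
//   keepdim=false, asvector=false -> {2, 4}
//   keepdim=true,  asvector=false -> {2, 1, 4}
//   keepdim=true,  asvector=true  -> {1, 1, 1}
//   keepdim=false, asvector=true  -> {}  (scalar)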

// bool PartialSumOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
@@ -79,10 +79,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3DWithIndex)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PNorm)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d)
2 changes: 2 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -3367,6 +3367,7 @@
kernel :
func : norm
backward : norm_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : npu_identity
args : (Tensor x, int format = -1)
@@ -3428,6 +3429,7 @@
kernel :
func : p_norm
backward : p_norm_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : pad
args : (Tensor x, int[] paddings, Scalar pad_value)