-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUAA】【Infer Symbolic Shape No.3,46,107】Add as_strided, fold and yolo_loss #67037
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
ab86a17
to
d515eeb
Compare
@@ -1594,6 +1594,83 @@ bool Where_OpInferSymbolicShape(pir::Operation *op, | |||
return WhereOpInferSymbolicShape(op, infer_context); | |||
} | |||
|
|||
bool YoloLossOpInferSymbolicShape( | |||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | |||
const auto &dim_x = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名:x_shape
infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape(); | ||
std::vector<int> anchors_mask = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "anchor_mask"); | ||
int mask_num = static_cast<int>(anchors_mask.size()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非必要cast
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); | ||
const auto &dim_gtlabel = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape(); | ||
std::vector<int> anchors_mask = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const std::vector &
phi::errors::InvalidArgument("Input(GTBox) should be a 3-D tensor, but " | ||
"received gtbox dimension size(%s)", | ||
dim_gtbox.size())); | ||
/*PADDLE_ENFORCE_EQ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉注释
@@ -31,6 +31,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Asinh) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Asinh_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Assign) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Assign_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsStrided) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个op看接口描述应该不是same类型,我再确认一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已确认AsStrided不是same类型,麻烦参考kernel实现编写符号推导接口。顺便把infermeta接口也修改一下
@@ -994,6 +994,34 @@ bool Flatten_OpInferSymbolicShape( | |||
return FlattenOpInferSymbolicShape(op, infer_context); | |||
} | |||
|
|||
bool FoldOpInferSymbolicShape(pir::Operation *op, | |||
pir::InferSymbolicShapeContext *infer_context) { | |||
const auto &in_dims = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,x_shape
std::vector<int> kernel_sizes = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "kernel_sizes"); | ||
out_dims.push_back(in_dims[1] / | ||
(symbol::DimExpr(kernel_sizes[0] * kernel_sizes[1]))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不需要使用symbol::DimExpr() 构造 Dimexpr(),DimExpr是支持这些运算的
|
||
std::vector<int> output_sizes = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "output_sizes"); | ||
infer_context->AddGreatThanOneCstr(symbol::DimExpr(output_sizes[0])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉 symbol::DimExpr()
std::vector<int> output_sizes = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "output_sizes"); | ||
infer_context->AddGreatThanOneCstr(symbol::DimExpr(output_sizes[0])); | ||
infer_context->AddGreatThanOneCstr(symbol::DimExpr(output_sizes[1])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
|
||
out_dims.push_back(symbol::DimExpr(output_sizes[0])); | ||
out_dims.push_back(symbol::DimExpr(output_sizes[1])); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
少了DimExpr的约束,参考infermeta再修改一下吧
coverage exit-9: AsStridedOP: FoldOp: YoloLossOp: |
int rank = shape.size(); | ||
std::vector<symbol::DimExpr> out_shape; | ||
for (int i = 0; i < rank; ++i) { | ||
out_shape.push_back(symbol::DimExpr(shape[i])); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
确认一下这个算子是否允许输入维度为-1吧,如果允许的话需要特判上新符号
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
…pe/.ipynb_checkpoints directory
for (int i = 0; i < rank; ++i) { | ||
symbol::DimExpr out_unknown = infer_context->GetNextSymName(); | ||
if (shape[i] == -1) { | ||
out_shape.push_back(out_unknown); | ||
} else { | ||
out_shape.push_back(symbol::DimExpr(shape[i])); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
symbol::DimExpr out_unknown = infer_context->GetNextSymName();最好放在if分支里,虽然这些新符号都没用到但是每个循环里符号计数都会增加
@@ -388,6 +388,30 @@ bool Assign_OpInferSymbolicShape( | |||
return AssignOpInferSymbolicShape(op, infer_context); | |||
} | |||
|
|||
bool AsStridedOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改一下AsStride的InferMeta接口
在这个pr下贴上 上述两个遗留问题的PR |
上述两个遗留问题的解决PR:#67801 |
PR Category
CINN
PR Types
Others
Description
添加 as_strided, yolo_loss 和 fold 算子符号推导接口实现。
as_strided两个遗留问题的解决PR:#67801