-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUAA】【Infer Symbolic Shape No.3,46,107】Add as_strided, fold and yolo_loss #67037
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
2fad015
a
Fripping 7f66ddc
update cinn api
Fripping 243febb
Delete paddle/phi/ops/yaml/.ipynb_checkpoints directory
Fripping 31c0cd3
Delete paddle/fluid/pir/dialect/operator/interface/infer_symbolic_sha…
Fripping 84c12b8
update cinn api
Fripping ca6bd5d
update -1 sign
Fripping 5a28eea
Delete paddle/fluid/pir/dialect/operator/interface/infer_symbolic_sha…
Fripping 1718327
Merge branch 'develop' into cin7
Fripping e3a61ec
Update unary_infer_sym.cc
Fripping File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -388,6 +388,30 @@ bool Assign_OpInferSymbolicShape( | |
return AssignOpInferSymbolicShape(op, infer_context); | ||
} | ||
|
||
bool AsStridedOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const std::vector<int> &shape = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "dims"); | ||
|
||
int rank = shape.size(); | ||
std::vector<symbol::DimExpr> out_shape; | ||
for (int i = 0; i < rank; ++i) { | ||
symbol::DimExpr out_unknown = infer_context->GetNextSymName(); | ||
if (shape[i] == -1) { | ||
out_shape.push_back(out_unknown); | ||
} else { | ||
out_shape.push_back(symbol::DimExpr(shape[i])); | ||
} | ||
} | ||
Comment on lines
+398
to
+405
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. symbol::DimExpr out_unknown = infer_context->GetNextSymName();最好放在if分支里,虽然这些新符号都没用到但是每个循环里符号计数都会增加 |
||
|
||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), | ||
symbol::ShapeOrDataDimExprs{ | ||
symbol::TensorShapeOrDataDimExprs(out_shape)}); | ||
|
||
return true; | ||
} | ||
|
||
bool BipartiteMatchOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &dist_mat_shape_or_data = | ||
|
@@ -1060,6 +1084,83 @@ bool Flatten_OpInferSymbolicShape( | |
return FlattenOpInferSymbolicShape(op, infer_context); | ||
} | ||
|
||
bool FoldOpInferSymbolicShape(pir::Operation *op, | ||
pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &x_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); | ||
|
||
std::vector<symbol::DimExpr> out_shape; | ||
out_shape.push_back(x_shape[0]); | ||
|
||
const std::vector<int> &output_sizes = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "output_sizes"); | ||
PADDLE_ENFORCE_EQ( | ||
output_sizes.size(), | ||
2, | ||
common::errors::InvalidArgument( | ||
"It is expected output_size equals to 2, but got size %d", | ||
output_sizes.size())); | ||
infer_context->AddGreatThanOneCstr(output_sizes[0]); | ||
infer_context->AddGreatThanOneCstr(output_sizes[1]); | ||
|
||
const std::vector<int> &kernel_sizes = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "kernel_sizes"); | ||
const std::vector<int> &dilations = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "dilations"); | ||
const std::vector<int> &strides = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "strides"); | ||
const std::vector<int> &paddings = | ||
paddle::dialect::details::GetVectorAttr<int>(op, "paddings"); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
kernel_sizes.size(), | ||
2, | ||
common::errors::InvalidArgument( | ||
"It is expected kernel_size equals to 2, but got size %d", | ||
kernel_sizes.size())); | ||
PADDLE_ENFORCE_EQ( | ||
strides.size(), | ||
2, | ||
common::errors::InvalidArgument( | ||
"It is expected strides_size equals to 2, but got size %d", | ||
strides.size())); | ||
PADDLE_ENFORCE_EQ( | ||
paddings.size(), | ||
4, | ||
common::errors::InvalidArgument( | ||
"It is expected paddings_size equals to 4, but got size %d", | ||
paddings.size())); | ||
PADDLE_ENFORCE_EQ( | ||
dilations.size(), | ||
2, | ||
common::errors::InvalidArgument( | ||
"It is expected dilations_size equals to 2, but got size %d", | ||
dilations.size())); | ||
|
||
int blocks_height = (output_sizes[0] + 2 * paddings[0] - | ||
(dilations[0] * (kernel_sizes[0] - 1) + 1)) / | ||
strides[0] + | ||
1; | ||
int blocks_width = (output_sizes[1] + 2 * paddings[1] - | ||
(dilations[1] * (kernel_sizes[1] - 1) + 1)) / | ||
strides[1] + | ||
1; | ||
|
||
infer_context->AddEqualCstr((blocks_height * blocks_width), x_shape[2]); | ||
|
||
out_shape.push_back(x_shape[1] / (kernel_sizes[0] * kernel_sizes[1])); | ||
|
||
out_shape.push_back(symbol::DimExpr(output_sizes[0])); | ||
out_shape.push_back(symbol::DimExpr(output_sizes[1])); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 少了DimExpr的约束,参考infermeta再修改一下吧 |
||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), | ||
symbol::ShapeOrDataDimExprs{ | ||
symbol::TensorShapeOrDataDimExprs(out_shape)}); | ||
|
||
return true; | ||
} | ||
|
||
bool IdentityLossOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &input_shape = | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改一下AsStride的InferMeta接口