Skip to content

Commit

Permalink
[Infer Symbolic Shape No.225][BUAA] shuffle_channel (#67463)
Browse files Browse the repository at this point in the history
* Finished shuffle_channel op

* Fixed some errors

* Fixed shuffle_channel op

* Fixed wrong file error
  • Loading branch information
MufanColin authored Aug 19, 2024
1 parent af0f692 commit 7718432
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2193,12 +2193,23 @@ bool ShapeSrOpInferSymbolicShape(
// return true;
// }

// bool ShuffleChannelOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool ShuffleChannelOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_dims = x_shape_or_data.shape();

PADDLE_ENFORCE_EQ(
x_dims.size(),
4,
common::errors::InvalidArgument("The layout of input is NCHW."));

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs(symbol::TensorShapeOrDataDimExprs(x_dims)));

return true;
}

bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShardIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShardIndex)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleChannel)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleChannel)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Split)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SplitWithNum)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4271,7 +4271,7 @@
kernel :
func : shuffle_channel
backward : shuffle_channel_grad
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : sigmoid
args : (Tensor x)
Expand Down

0 comments on commit 7718432

Please sign in to comment.