Skip to content

Commit

Permalink
【Infer Symbolic Shape No.170】【BUAA】 Add read_file op (#67889)
Browse files Browse the repository at this point in the history
* Finished read_file op

* Resolved some suggested changes

* Resolved some suggested changes

* Removed unnecessary comment
  • Loading branch information
MufanColin authored Sep 20, 2024
1 parent 319500d commit 768e234
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,24 @@ bool RandintOpInferSymbolicShape(
}
}

// bool ReadFileOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool ReadFileOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
symbol::DimExpr unique_dim_sym = infer_context->GetNextSymName();

const std::vector<symbol::DimExpr> &out_shape = [&] {
std::vector<symbol::DimExpr> shape;
shape.emplace_back(symbol::DimExpr(1));
shape.emplace_back(unique_dim_sym);
return shape;
}();

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

return true;
}

// bool RecvV2OpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullIntArray)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gaussian)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randint)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Seed)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(RecvV2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3839,6 +3839,7 @@
param : [filename]
data_type : dtype
backend : place
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : real
args : (Tensor x)
Expand Down

0 comments on commit 768e234

Please sign in to comment.