Skip to content

Commit

Permalink
add slice
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbopd committed Jan 5, 2024
1 parent 75fe3ae commit 0252315
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 156 deletions.
3 changes: 3 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"

namespace cinn {
namespace hlir {
Expand Down Expand Up @@ -146,6 +147,8 @@ utils::Attribute CompatibleInfo::ConvertAttribute(
} else if (src_attr.isa<paddle::dialect::DataTypeAttribute>()) {
auto dtype = src_attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
dst_attr = phi::DataTypeToString(dtype);
} else if (src_attr.isa<::pir::shape::SymbolAttribute>()) {
auto dst_attr = src_attr.dyn_cast<::pir::shape::SymbolAttribute>().data();
} else if (src_attr.isa<::pir::ArrayAttribute>()) {
auto attr_vec = src_attr.dyn_cast<::pir::ArrayAttribute>().AsVector();
if (attr_vec.size() > 0) {
Expand Down
253 changes: 133 additions & 120 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"
Expand All @@ -25,113 +26,113 @@ bool InferSymbolicShapeInterface::InferSymbolicShape(
}
} // namespace paddle::dialect

namespace paddle::dialect {

namespace {

bool InferSymbolicShapeAllEqualUnary(
bool SameOperandsAndResultShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
shape_analysis->value_id_to_shapeordata_[res_id] =
shape_analysis->value_id_to_shapeordata_[operand_source_id];
return true;
}

bool InferSymbolicShapeAllEqualBinary(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];

op->set_attribute("symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
operand_shape_or_data));
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
shape_analysis->value_id_to_shapeordata_[res_id] =
shape_analysis->value_id_to_shapeordata_[operand_source_id];
shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data;
return true;
}

} // namespace

namespace paddle::dialect {
bool AbsOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool Abs_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool DataOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
symbol::ShapeOrDataDimExprs sss;
auto attributes = op->attributes();
pir::Attribute attr = attributes["shape"];
std::vector<int64_t> dims =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();

std::vector<symbol::DimExpr> sym_dims;
for (auto dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName());
dim_expr = symbolic_dim_expr;
} else {
symbol::DimExpr numeric_dim_expr(dim);
dim_expr = numeric_dim_expr;
}
sym_dims.push_back(dim_expr);
}

symbol::ShapeOrDataDimExprs shape_data{sym_dims};
op->set_attribute(
"sym_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), sss));
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;

// auto attributes = op->attributes();
// pir::Attribute attr = attributes["shape"];
// const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
return true;
}

bool CastOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool Cast_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool ExpOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool Exp_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool SubtractOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool Subtract_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
return SameOperandsAndResultShape(op, shape_analysis);
}

bool ShapeOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);

std::vector<int64_t> dims =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];

std::vector<symbol::DimExpr> shapes;
for (int64_t dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
dim_expr = res_dim_expr;
} else {
symbol::DimExpr res_dim_expr(dim);
dim_expr = res_dim_expr;
}
shapes.push_back(dim_expr);
}
symbol::ShapeOrDataDimExprs extend_shape_or_data =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(
operand_shape_or_data);

symbol::ShapeOrDataDimExprs shape_data{shapes};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data;
op->set_attribute("symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
extend_shape_or_data));
return true;
}

Expand All @@ -143,27 +144,53 @@ bool ShapeSrOpInferSymbolicShape(
bool StackOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];

symbol::ShapeOrDataDimExprs shape_data;
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id];
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
std::vector<symbol::DimExpr> out_dims;
if (operand_shape_or_data.data().has_value()) {
out_dims = operand_shape_or_data.data().value();
}
// else : pir::VectorType x =
// operand_source.type().dyn_cast<pir::VectorType>();
// TODO(zhangbopd): else branch is not implemented yet.

symbol::ShapeOrDataDimExprs shape_data{out_dims};
if (operand_shape_or_data.data().has_value()) {
shape_data =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data);
}

op->set_attribute(
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
return true;
}

bool ReshapeOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source_1 = op->operand_source(1);
std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
pir::Value operand_source_shape = op->operand_source(1);

symbol::ShapeOrDataDimExprs shape_data;
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source_shape];

std::vector<symbol::DimExpr> out_dims;
if (operand_shape_or_data.data().has_value()) {
out_dims = operand_shape_or_data.data().value();
}

shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
symbol::ShapeOrDataDimExprs shape_data{out_dims};
op->set_attribute(
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

pir::OpResult res0 = op->result(0);
pir::OpResult res1 = op->result(1);
shape_analysis->value_to_shape_or_data_[res0] = shape_data;
shape_analysis->value_to_shape_or_data_[res1] =
shape_analysis->value_to_shape_or_data_[operand_source_shape];
return true;
}

Expand All @@ -174,81 +201,67 @@ bool Reshape_OpInferSymbolicShape(

bool FullIntArrayOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
for (auto &res : op->results()) {
std::string value_id = pir::GetValueId(&res);
std::vector<int64_t> dims =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());

std::vector<symbol::DimExpr> shapes;
for (int64_t dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
dim_expr = res_dim_expr;
} else {
symbol::DimExpr res_dim_expr(dim);
dim_expr = res_dim_expr;
}
shapes.push_back(dim_expr);
}
auto attributes = op->attributes();
pir::Attribute attr = attributes["value"];
const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();

std::vector<symbol::DimExpr> data;
for (auto item : vec) {
int64_t i = item.dyn_cast<pir::Int64Attribute>().data();
data.push_back(symbol::DimExpr(i));
}

auto attributes = op->attributes();
pir::Attribute attr = attributes["value"];
const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
symbol::ShapeOrDataDimExprs shape_data =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(data);

for (auto item : vec) {
int64_t i = item.dyn_cast<pir::Int64Attribute>().data();
shapes.push_back(symbol::DimExpr(i));
}
op->set_attribute(
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

// for (auto &item : shapes) {
// VLOG(0) << symbol::ToString(item);
// }
pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
return true;
}

symbol::ShapeOrDataDimExprs shape_data{shapes};
shape_analysis->value_id_to_shapeordata_[value_id] = shape_data;
return true;
}
bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
// TODO(zhangbopd): Not implemented yet.
return true;
}

} // namespace paddle::dialect
namespace cinn::dialect {

bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
// TODO(zhangbopd): Not implemented yet, different from the one in paddle
// dialect.
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];
pir::AttributeMap attributes = op->attributes();

std::vector<int64_t> dims =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());

std::vector<symbol::DimExpr> shapes;
for (int64_t dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
dim_expr = res_dim_expr;
} else {
symbol::DimExpr res_dim_expr(dim);
dim_expr = res_dim_expr;
}
shapes.push_back(dim_expr);
}
std::vector<pir::Attribute> attr_starts =
attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();

// pir::AttributeMap attributes = op->attributes();
int64_t start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();

// auto attr_starts =
// attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
// auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();
std::vector<symbol::DimExpr> out_dims;
if (operand_shape_or_data.data().has_value()) {
out_dims.push_back(operand_shape_or_data.data().value()[start]);
}

// auto attr_ends =
// attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
// auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();
symbol::ShapeOrDataDimExprs shape_data{out_dims};
if (operand_shape_or_data.data().has_value()) {
shape_data =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data);
}
op->set_attribute(
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

symbol::ShapeOrDataDimExprs shape_data{shapes};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ bool Reshape_OpInferSymbolicShape(
bool FullIntArrayOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
Loading

0 comments on commit 0252315

Please sign in to comment.