Fix check_infer_symbolic_pass (#61791)
* fix check_infer_symbolic_pass to support TensorList values

* fix some problems in simplify_dim_expr_pass

* modify according to review comments
JiaWenxuan authored Feb 20, 2024
1 parent 796a71f commit 1d82e2c
Showing 4 changed files with 73 additions and 42 deletions.
@@ -23,23 +23,29 @@
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"

namespace cinn {
namespace dialect {
namespace ir {

namespace {

std::string SprintShape(const std::vector<std::int64_t>& shape) {
  std::string str = "[";
  for (std::int64_t value : shape) {
    str += std::to_string(value);
    if (value != shape.back()) {
std::string SprintShape(const std::vector<std::vector<std::int64_t>>& shapes) {
  std::string str;
  for (int i = 0; i < shapes.size(); i++) {
    str += "[";
    for (int j = 0; j < shapes[i].size(); j++) {
      str += std::to_string(shapes[i][j]);
      if (j != shapes[i].size() - 1) {
        str += ", ";
      }
    }
    str += "]";
    if (i != shapes.size() - 1) {
      str += ", ";
    }
  }
  return str + "]";
  return str;
}

void PrintProgram(pir::ModuleOp m, const std::string& mgs) {
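
Note on the hunk above: the removed SprintShape took a single std::vector<std::int64_t> and printed one shape, while its replacement takes a list of shapes so that values backed by a TensorList can be printed as well. The rewrite also fixes the old separator logic, which skipped the comma whenever an interior dimension happened to equal the last one (value != shape.back()). A minimal standalone sketch of the new formatting, with no Paddle dependencies:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Same logic as the reworked SprintShape: format a list of shapes,
// e.g. {{2, 3}, {4}} -> "[2, 3], [4]".
std::string SprintShape(const std::vector<std::vector<std::int64_t>>& shapes) {
  std::string str;
  for (std::size_t i = 0; i < shapes.size(); i++) {
    str += "[";
    for (std::size_t j = 0; j < shapes[i].size(); j++) {
      str += std::to_string(shapes[i][j]);
      if (j != shapes[i].size() - 1) str += ", ";
    }
    str += "]";
    if (i != shapes.size() - 1) str += ", ";
  }
  return str;
}

int main() { std::cout << SprintShape({{2, 3}, {4}}) << std::endl; }
```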
@@ -51,23 +57,54 @@ void PrintProgram(pir::ModuleOp m, const std::string& mgs) {
          << print_stream.str();
}

std::vector<std::int64_t> GetStaticValueShape(pir::Value value) {
  const auto& dim = value.type().dyn_cast<::pir::DenseTensorType>().dims();
  return ::common::vectorize(dim);
std::vector<std::vector<std::int64_t>> GetStaticValueShape(pir::Value value) {
  std::vector<std::vector<std::int64_t>> static_shape;
  if (const pir::DenseTensorType& dense_tensor =
          value.type().dyn_cast<::pir::DenseTensorType>()) {
    static_shape.push_back(::common::vectorize(dense_tensor.dims()));
  } else if (const pir::VectorType vector_tensor =
                 value.type().dyn_cast<::pir::VectorType>()) {
    for (size_t i = 0; i < vector_tensor.size(); i++) {
      if (vector_tensor[i].isa<pir::DenseTensorType>()) {
        const pir::DenseTensorType& dense_tensor =
            vector_tensor[i].dyn_cast<::pir::DenseTensorType>();
        static_shape.push_back(::common::vectorize(dense_tensor.dims()));
      }
    }
  } else {
    IR_THROW("error: the value doesn't have DenseTensorType");
  }
  return static_shape;
}

std::optional<std::vector<std::int64_t>> GetDynamicValueShape(
std::vector<std::int64_t> GetShapeFromTensor(
    const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) {
  std::vector<std::int64_t> dynamic_shape;
  for (const auto& dim_expr_shape : tensor_shape_or_data.shape()) {
    CHECK(dim_expr_shape.Has<std::int64_t>());
    dynamic_shape.push_back(dim_expr_shape.Get<std::int64_t>());
  }
  return dynamic_shape;
}

std::vector<std::vector<std::int64_t>> GetDynamicValueShape(
    pir::Value value, const pir::ShapeConstraintIRAnalysis& shape_analysis) {
  std::vector<std::vector<std::int64_t>> dynamic_shapes;
  if (!shape_analysis.HasShapeOrDataForValue(value)) {
    return std::nullopt;
  }
  const auto& dim_expr_dynamic_shapes =
      shape_analysis.GetShapeOrDataForValue(value).shape();
  std::vector<std::int64_t> dynamic_shapes{};
  for (const auto& dim_expr_shape : dim_expr_dynamic_shapes) {
    CHECK(dim_expr_shape.Has<std::int64_t>());
    dynamic_shapes.push_back(dim_expr_shape.Get<std::int64_t>());
    return dynamic_shapes;
  }
  symbol::ShapeOrDataDimExprs shape_or_data =
      shape_analysis.GetShapeOrDataForValue(value);
  auto lambdas = symbol::Overloaded{
      [&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) {
        dynamic_shapes.push_back(GetShapeFromTensor(tensor_shape_or_data));
      },
      [&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) {
        for (const auto& tensor_shape_or_data : tensor_list) {
          dynamic_shapes.push_back(GetShapeFromTensor(tensor_shape_or_data));
        }
      }};
  std::visit(lambdas, shape_or_data.variant());
  return dynamic_shapes;
}
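
This hunk carries the core of the fix: GetStaticValueShape now returns one shape per tensor and also accepts pir::VectorType values, the dim-expr extraction is factored into GetShapeFromTensor, and GetDynamicValueShape dispatches on the value's shape-or-data with symbol::Overloaded and std::visit, so both TensorShapeOrDataDimExprs and TensorListShapeOrDataDimExprs are handled. The dispatch is the usual overloaded-lambda visitor over a variant; the sketch below is self-contained C++17 with std::variant and illustrative names (TensorShape, TensorListShape, ShapeOrData, CollectShapes) standing in for Paddle's symbol:: types, not the real API:

```cpp
#include <cstdint>
#include <variant>
#include <vector>

// Overloaded-lambda visitor helper (the same idea as symbol::Overloaded).
template <class... Ts>
struct Overloaded : Ts... {
  using Ts::operator()...;
};
template <class... Ts>
Overloaded(Ts...) -> Overloaded<Ts...>;

using TensorShape = std::vector<std::int64_t>;
using TensorListShape = std::vector<TensorShape>;
using ShapeOrData = std::variant<TensorShape, TensorListShape>;

// Collect one shape per tensor, whether the value holds a single tensor
// shape or a list of tensor shapes.
std::vector<TensorShape> CollectShapes(const ShapeOrData& shape_or_data) {
  std::vector<TensorShape> shapes;
  std::visit(Overloaded{
                 [&](const TensorShape& tensor) { shapes.push_back(tensor); },
                 [&](const TensorListShape& tensor_list) {
                   for (const auto& tensor : tensor_list) {
                     shapes.push_back(tensor);
                   }
                 }},
             shape_or_data);
  return shapes;
}

int main() {
  ShapeOrData single = TensorShape{2, 3};
  ShapeOrData list = TensorListShape{{2, 3}, {4}};
  // One shape from the single tensor, two from the list.
  return CollectShapes(single).size() + CollectShapes(list).size() == 3 ? 0 : 1;
}
```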

Expand All @@ -76,22 +113,16 @@ void CompareStaticAndDynamicValueShape(
    const pir::ShapeConstraintIRAnalysis& shape_analysis,
    int op_index,
    pir::ModuleOp module_op) {
  std::vector<std::int64_t> static_value_shape = GetStaticValueShape(value);
  std::optional<std::vector<std::int64_t>> opt_dynamic_value_shape =
  std::vector<std::vector<std::int64_t>> static_value_shape =
      GetStaticValueShape(value);
  std::vector<std::vector<std::int64_t>> dynamic_value_shape =
      GetDynamicValueShape(value, shape_analysis);
  if (opt_dynamic_value_shape.has_value()) {
    if (static_value_shape != opt_dynamic_value_shape.value()) {
      VLOG(4) << "CheckInferSymbolic failed, in the following program, the "
              << op_index
              << "th op : the shape is not equal\nthe static shape is: "
              << SprintShape(static_value_shape)
              << ", and the dynamic shape is: "
              << SprintShape(opt_dynamic_value_shape.value());
      PrintProgram(module_op, "CheckInferSymbolic");
    }
  } else {
  if (static_value_shape != dynamic_value_shape) {
    VLOG(4) << "CheckInferSymbolic failed, in the following program, the "
            << op_index << "th op infer symbolic failed";
            << op_index
            << "th op : the shape is not equal\nthe static shape is: "
            << SprintShape(static_value_shape) << ", and the dynamic shape is: "
            << SprintShape(dynamic_value_shape);
    PrintProgram(module_op, "CheckInferSymbolic");
  }
}
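
With both helpers returning std::vector<std::vector<std::int64_t>>, the check collapses to one (in)equality test: std::vector's operator== compares element by element, so the nested dim vectors are compared deeply. Also note that when the analysis records no shape-or-data for a value, GetDynamicValueShape now returns an empty list, so such values surface through the same mismatch log instead of the old separate "infer symbolic failed" message. A tiny sketch of the comparison semantics:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<std::vector<std::int64_t>> static_shape = {{2, 3}, {4}};
  std::vector<std::vector<std::int64_t>> dynamic_shape = {{2, 3}, {4}};
  std::vector<std::vector<std::int64_t>> mismatched = {{2, 3}, {5}};
  assert(static_shape == dynamic_shape);  // deep, element-wise comparison
  assert(static_shape != mismatched);     // any differing dim reports a mismatch
  return 0;
}
```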
@@ -64,12 +64,11 @@ symbol::TensorShapeOrDataDimExprs SimplifyTensorShapeOrData(
      SimplifyDimExpr(shape_or_data.shape());
  if (!shape_or_data.data().has_value()) {
    return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape);
  } else {
    std::vector<symbol::DimExpr> simplified_data =
        SimplifyDimExpr(shape_or_data.data().value());
    return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape,
                                                simplified_data);
  }
  std::vector<symbol::DimExpr> simplified_data =
      SimplifyDimExpr(shape_or_data.data().value());
  return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape,
                                              simplified_data);
}

symbol::ShapeOrDataDimExprs SimplifyShapeOrData(
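
The hunk above is a pure control-flow cleanup in simplify_dim_expr_pass: the redundant else after the early return is dropped, and the simplified shape (plus data, when present) is produced exactly as before.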
@@ -21,7 +21,7 @@ namespace cinn {
namespace dialect {
namespace ir {

// This is a helper pass for simplify the DimExpr after ShapeOptimizationPass
// This is a helper pass for simplifying the DimExpr after ShapeOptimizationPass
std::unique_ptr<::pir::Pass> CreateSimplifyDimExprPass();
} // namespace ir
} // namespace dialect
7 changes: 4 additions & 3 deletions test/ir/pir/cinn/symbolic/test_check_infer_symbolic.py
Original file line number Diff line number Diff line change
@@ -28,16 +28,17 @@ def apply_to_static(net, use_cinn):
    )


def exp_sub(x):
def exp_sub_concat(x):
    y = paddle.exp(x)
    z = y - x
    return z
    out = paddle.concat([z, x], 0)
    return out


class CheckInferSymbolicNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fn = exp_sub
        self.fn = exp_sub_concat

    def forward(self, x):
        out = self.fn(x)
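
The test now routes the subtraction result through paddle.concat, so the captured program contains a combined (vector-type) operand value; presumably this is what exercises the new TensorList handling in GetStaticValueShape and GetDynamicValueShape.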