
Commit 5d2998d

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into reduce_as
zeroRains committed May 8, 2024
2 parents cffb57c + 8c438a3 commit 5d2998d
Showing 303 changed files with 8,226 additions and 5,511 deletions.
5 changes: 1 addition & 4 deletions .editorconfig
@@ -15,15 +15,12 @@ insert_final_newline = true
 [*.{c,cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps}]
 indent_size = 2
 
-[*.{py,java,r}]
+[*.{py,pyi,java,r,toml}]
 indent_size = 4
 
 [Dockerfile.*]
 indent_size = 4
 
-[.flake8]
-indent_size = 4
-
 [*.go]
 indent_style = tab
 indent_size = 4
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -68,6 +68,7 @@ option(WITH_PIP_CUDA_LIBRARIES
"Paddle uses the CUDA library provided by NVIDIA" OFF)
option(WITH_NIGHTLY_BUILD
"Compile nightly paddle whl package of the develop branch" OFF)
option(WITH_CPP_TEST "Compile PaddlePaddle skip cpp test" ON)
find_package(Git REQUIRED)

# config GIT_URL with github mirrors to speed up dependent repos clone
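A usage note, not part of the diff: assuming standard CMake option semantics, the new WITH_CPP_TEST switch defaults to ON and can be flipped at configure time, e.g. cmake -DWITH_CPP_TEST=OFF; the option name comes from the hunk above, the invocation itself is illustrative.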
24 changes: 12 additions & 12 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -116,14 +116,14 @@ void GroupOp::Print(pir::IrPrinter& printer) {
 }
 
 bool GroupOp::InferSymbolicShape(
-    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
-  ::pir::InferSymExprForBlock(*block(), shape_analysis);
+    ::pir::InferSymbolicShapeContext* infer_context) {
+  ::pir::InferSymExprForBlock(*block(), infer_context);
 
   for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
     auto inner_yield_value = block()->back().operand_source(rst_idx);
     const auto& shape =
-        shape_analysis->GetShapeOrDataForValue(inner_yield_value);
-    shape_analysis->SetShapeOrDataForValue(result(rst_idx), shape);
+        infer_context->GetShapeOrDataForValue(inner_yield_value);
+    infer_context->SetShapeOrDataForValue(result(rst_idx), shape);
   }
 
   if (VLOG_IS_ON(4)) {
@@ -204,16 +204,16 @@ void YieldStoreOp::Build(pir::Builder& builder,
 void YieldStoreOp::VerifySig() {}
 
 bool YieldStoreOp::InferSymbolicShape(
-    pir::ShapeConstraintIRAnalysis* shape_analysis) {
-  shape_analysis->SetShapeOrDataForValue(
-      result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0)));
+    pir::InferSymbolicShapeContext* infer_context) {
+  infer_context->SetShapeOrDataForValue(
+      result(0), infer_context->GetShapeOrDataForValue(operand_source(0)));
   return true;
 }
 
 bool ConcatOp::InferSymbolicShape(
-    pir::ShapeConstraintIRAnalysis* shape_analysis) {
+    pir::InferSymbolicShapeContext* infer_context) {
   VLOG(4) << "Infer symbolic shape for cinn_op.concat";
-  return ConcatOpInferSymbolicShape(this->operation(), shape_analysis);
+  return ConcatOpInferSymbolicShape(this->operation(), infer_context);
 }
 
 void ConcatOp::Build(pir::Builder& builder,  // NOLINT
@@ -476,7 +476,7 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings(
 }
 
 bool GenerateShapeOp::InferSymbolicShape(
-    pir::ShapeConstraintIRAnalysis* shape_analysis) {
+    pir::InferSymbolicShapeContext* infer_context) {
   const auto attr_dim_exprs = [&] {
     std::vector<symbol::DimExpr> dim_exprs{};
     pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
@@ -505,7 +505,7 @@ bool GenerateShapeOp::InferSymbolicShape(
   }();
   auto DimExprs4InputDim =
       [&](int input_idx) -> const symbol::ShapeOrDataDimExprs& {
-    return shape_analysis->GetShapeOrDataForValue(
+    return infer_context->GetShapeOrDataForValue(
         this->operand_source(input_idx));
   };
   auto DimExprs4SymbolName =
@@ -527,7 +527,7 @@ bool GenerateShapeOp::InferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_or_data_dim_exprs{
       symbol::TensorShapeOrDataDimExprs(shape, substituted_dim_exprs)};
 
-  shape_analysis->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);
+  infer_context->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);
 
   return true;
 }
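The changes in this file are mechanical: each InferSymbolicShape implementation now takes a pir::InferSymbolicShapeContext* instead of the full pir::ShapeConstraintIRAnalysis*, and reads/writes shapes through it. A minimal sketch of the new convention for a hypothetical pass-through op (the op name is invented; the two context calls are the ones used in the hunks above):

// Sketch only: "PassThroughOp" is a hypothetical op used for illustration.
bool PassThroughOp::InferSymbolicShape(
    pir::InferSymbolicShapeContext* infer_context) {
  // Look up the symbolic shape recorded for the first input ...
  const auto& in_shape =
      infer_context->GetShapeOrDataForValue(operand_source(0));
  // ... and record the same shape for the first result.
  infer_context->SetShapeOrDataForValue(result(0), in_shape);
  return true;
}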
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -53,7 +53,7 @@ class IR_API GroupOp
   pir::Block *block() const;
   std::vector<pir::Operation *> GetOperators() const;
 
-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 
   void VerifySig();
   void Print(pir::IrPrinter &printer);  // NOLINT
@@ -102,7 +102,7 @@ class IR_API YieldStoreOp

   void VerifySig();
 
-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 };
 
 class IR_API ConcatOp
@@ -123,7 +123,7 @@ class IR_API ConcatOp

   void VerifySig() const {}
 
-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 };
 
 class IR_API SplitOp : public pir::Op<SplitOp> {
@@ -177,7 +177,7 @@ class IR_API GenerateShapeOp

   pir::Value out() { return result(0); }
 
-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 
   static pir::Attribute ConvertSymbolBindingsToAttribute(
       pir::Builder &builder, const SymbolBindings &symbol_bindings);  // NOLINT
13 changes: 6 additions & 7 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -52,6 +52,7 @@
 COMMON_DECLARE_bool(print_ir);
 COMMON_DECLARE_bool(disable_dyshape_in_train);
 COMMON_DECLARE_bool(enable_cinn_accuracy_check);
+COMMON_DECLARE_bool(enable_fuse_parallel_matmul_pass);
 PD_DECLARE_bool(group_schedule_tiling_first);
 
 namespace cinn::dialect::ir {
@@ -84,7 +85,9 @@ void ApplyPdToCinnPass(
     const std::function<std::shared_ptr<::pir::PassManager>()>&
         CreatePassManager) {
   std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
-  pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
+  if (FLAGS_enable_fuse_parallel_matmul_pass) {
+    pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
+  }
   pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
   pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
   pass_manager->Run(program);
@@ -220,14 +223,10 @@ void ApplyCinnPass(::pir::Program* program,
   ApplyPdToCinnPass(program, CreatePassManager);
   ApplyCinnPreprocessPass(program, CreatePassManager);
   ApplyBuildGroupOpPass(program, CreatePassManager);
-  LOG(INFO) << "====[pir-to-py-code group-ops begin]===" << std::endl
-            << PirToPyCodeConverter().Convert(*program);
-  LOG(INFO) << "====[pir-to-py-code group-ops end]===";
+  PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program);
   ApplyGroupOpPass(program, CreatePassManager);
   ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
-  LOG(INFO) << "====[pir-to-py-code fusion-ops begin]===" << std::endl
-            << PirToPyCodeConverter().Convert(*program);
-  LOG(INFO) << "====[pir-to-py-code fusion-ops end]===";
+  PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program);
   LOG(INFO) << "FusionOp count before lowering : *****[ "
             << GetOpCount<cinn::dialect::FusionOp>(program->module_op())
             << " ]*****";
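Two behavior changes land in this file, both flag-gated: the FuseParallelMatmul pass now runs only when FLAGS_enable_fuse_parallel_matmul_pass is set, and the unconditional pir-to-py-code LOG(INFO) dumps are replaced by PirToPyCodeConverter().SaveIfFlagEnabled(...), which presumably persists the "group_op_programs" and "fusion_op_programs" snapshots only when its controlling flag is enabled; the flag it consults is not shown in this diff.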
@@ -45,11 +45,9 @@ class AddYieldStoreInFusionOpPattern
       auto orignal_base = op->operand_source(i);
       op->operand(i).set_source(store_op.result(0));
 
-      if (shape_analysis.HasShapeOrDataForValue(orignal_base)) {
-        shape_analysis.SetShapeOrDataForValue(
-            store_op.result(0),
-            shape_analysis.GetShapeOrDataForValue(orignal_base));
-      }
+      shape_analysis.SetShapeOrDataForValue(
+          store_op.result(0),
+          shape_analysis.GetShapeOrDataForValue(orignal_base));
     }
 
     return true;
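This hunk is the first of a pattern repeated in the files below: HasShapeOrDataForValue guards (and the matching CHECKs) are removed, so GetShapeOrDataForValue is now expected to succeed for every value queried — presumably because the reworked shape analysis guarantees an entry, or a usable default, for each value it is asked about.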
@@ -144,7 +144,10 @@ class BlockDimExprsAsserter {
       PADDLE_THROW(phi::errors::Unimplemented(
           op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
     } else {
-      bool infer_result = interface.InferSymbolicShape(shape_analysis.get());
+      // TODO(Hongqing-work): delete this after the shape analysis reconstruct
+      // is done.
+      bool infer_result = interface.InferSymbolicShape(
+          shape_analysis->GetInferSymbolicShapeContext());
       PADDLE_ENFORCE_EQ(infer_result,
                         true,
                         ::common::errors::PreconditionNotMet(
@@ -182,11 +182,9 @@ ::pir::GroupOpsVec CloneOps(
       pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
 
   for (size_t i = 0; i < op->num_results(); ++i) {
-    if (shape_analysis.HasShapeOrDataForValue(op->result(i))) {
-      shape_analysis.SetShapeOrDataForValue(
-          new_op->result(i),
-          shape_analysis.GetShapeOrDataForValue(op->result(i)));
-    }
+    shape_analysis.SetShapeOrDataForValue(
+        new_op->result(i),
+        shape_analysis.GetShapeOrDataForValue(op->result(i)));
   }
 
   vec_new_op_list.push_back(new_op);
@@ -357,11 +355,9 @@ class CinnGroupClusterPattern
     // update ir mapping
     for (size_t i = 0; i < output_values.size(); ++i) {
       ir_mapping.Add(output_values[i], new_group_op->result(i));
-      if (shape_analysis.HasShapeOrDataForValue(output_values[i])) {
-        shape_analysis.SetShapeOrDataForValue(
-            new_group_op->result(i),
-            shape_analysis.GetShapeOrDataForValue(output_values[i]));
-      }
+      shape_analysis.SetShapeOrDataForValue(
+          new_group_op->result(i),
+          shape_analysis.GetShapeOrDataForValue(output_values[i]));
     }
     for (size_t i = 0; i < output_values.size(); ++i) {
       auto find_it = all_output_values.find(output_values[i]);
@@ -33,18 +33,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op,
   std::vector<int> shape = phi::vectorize<int>(
       output.type().dyn_cast<pir::DenseTensorType>().dims());
 
-  if (shape_analysis->HasShapeOrDataForValue(op->result(0))) {
-    const auto& shape_info =
-        shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
-    int temp_dim = -1;
-
-    for (size_t i = 0; i < shape_info.size(); ++i) {
-      if (shape_info[i].isa<int64_t>()) {
-        shape[i] = shape_info[i].Get<int64_t>();
-      } else {
-        shape[i] = temp_dim;
-        temp_dim = 1;
-      }
+  const auto& shape_info =
+      shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
+  int temp_dim = -1;
+
+  for (size_t i = 0; i < shape_info.size(); ++i) {
+    if (shape_info[i].isa<int64_t>()) {
+      shape[i] = shape_info[i].Get<int64_t>();
+    } else {
+      shape[i] = temp_dim;
+      temp_dim = 1;
     }
   }
   return shape;
@@ -53,15 +53,11 @@ bool RemoveOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
   if (has_dynamic_shape) {
     auto& shape_analysis =
         pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
-    if (shape_analysis.HasShapeOrDataForValue(input) &&
-        shape_analysis.HasShapeOrDataForValue(output)) {
-      auto input_sym_shape =
-          GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
-      auto output_sym_shape =
-          GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
-      return input_sym_shape == output_sym_shape;
-    }
-    return false;
+    auto input_sym_shape =
+        GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
+    auto output_sym_shape =
+        GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
+    return input_sym_shape == output_sym_shape;
   }
 
   return GetDims(input) == GetDims(output);
 };
@@ -214,7 +214,10 @@ void InferSymbolicShapeForSubgraph(
   auto infer_symbolic_shape_interface =
       op->dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
   if (infer_symbolic_shape_interface) {
-    infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis);
+    // TODO(Hongqing-work): delete this after the shape analysis reconstruct
+    // is done.
+    infer_symbolic_shape_interface.InferSymbolicShape(
+        shape_analysis->GetInferSymbolicShapeContext());
   } else {
     PADDLE_THROW(phi::errors::Unimplemented(
         op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
@@ -348,7 +351,6 @@ bool ReplaceShapeOpsToGenerateShape(
   auto ShapeOrDataDimExprs4Value =
       [&shape_analysis](
           pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
-    CHECK(shape_analysis->HasShapeOrDataForValue(value));
     return shape_analysis->GetShapeOrDataForValue(value);
   };
   std::optional<pir::Value> opt_generated_shape =
@@ -104,9 +104,6 @@ class DynamicToStaticConverter {
   }
 
   bool Convert() {
-    if (!IsSymbolFullyInfered()) {
-      return false;
-    }
     bool updated = false;
     VisitEachValue(fusion_op_, [&](pir::Value value) {
       updated |= UpdateValueShape(value);
@@ -116,16 +113,6 @@
   }
 
  private:
-  bool IsSymbolFullyInfered() {
-    bool is_infered = true;
-    VisitEachValue(fusion_op_, [&](pir::Value value) {
-      if (!shape_analysis_->HasShapeOrDataForValue(value)) {
-        is_infered = false;
-      }
-    });
-    return is_infered;
-  }
-
   DimExpr4SymbolName InitDimExpr4SymbolName() {
     const auto* map = GetGlobalDynamicToStaticDimMap();
     CHECK(map->has_value());
@@ -178,7 +165,6 @@ class DynamicToStaticConverter {

   bool UpdateValueShape(pir::Value value) {
     bool update = false;
-    CHECK(shape_analysis_->HasShapeOrDataForValue(value));
     const auto& origin_shape = GetOriginValueShape(value);
     const auto& target_shape = GetTargetValueShape(value);
     PADDLE_ENFORCE_EQ(
@@ -150,7 +150,6 @@ struct StaticDimToDynamicConverter {
         &pir::ShapeAnalysisManager::Instance().Get(
             this->fusion_op->GetParentProgram());
     ForEachValue([&](pir::Value value) {
-      CHECK(shape_analysis->HasShapeOrDataForValue(value));
       const auto& origin_shape = GetOriginValueShape(value);
       const auto& target_shape = GetTargetValueShape(
           shape_analysis->GetShapeOrDataForValue(value).shape());
@@ -369,26 +368,8 @@
       pir::Value value,
       int64_t constant,
       const std::string& symbol) {
-    if (shape_analysis->HasShapeOrDataForValue(value)) {
-      const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
-      return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
-    } else {
-      auto& dims = value.type().dyn_cast<::pir::DenseTensorType>().dims();
-      const auto& int_dims = ::common::vectorize<int>(dims);
-      std::vector<symbol::DimExpr> old{};
-      for (int dim : int_dims) {
-        old.emplace_back(static_cast<std::int64_t>(dim));
-      }
-      const auto& opt_exprs =
-          ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
-      if (opt_exprs.has_value()) {
-        return opt_exprs.value();
-      } else {
-        return symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(old)};
-      }
-    }
-    PADDLE_THROW(phi::errors::Fatal("Dead code"));
+    const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
+    return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
   }
 
   template <typename ConverterT>
@@ -101,20 +101,14 @@ void SimplifyDimExpr(pir::Operation* module_op) {

   VisitEachOp(module_op, [&](pir::Operation& op) {
     VisitEachValue(op, [&](pir::Value value) {
-      if (!shape_analysis->HasShapeOrDataForValue(value)) {
-        VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for "
-                   "value of the op:"
-                << op.name();
-      } else {
-        const symbol::ShapeOrDataDimExprs& shape_or_data =
-            shape_analysis->GetShapeOrDataForValue(value);
-        VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
-        symbol::ShapeOrDataDimExprs simplified_shape_or_data =
-            SimplifyShapeOrData(shape_or_data);
-        VLOG(8) << op.name()
-                << " simplified_shape_or_data: " << simplified_shape_or_data;
-        shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
-      }
+      const symbol::ShapeOrDataDimExprs& shape_or_data =
+          shape_analysis->GetShapeOrDataForValue(value);
+      VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
+      symbol::ShapeOrDataDimExprs simplified_shape_or_data =
+          SimplifyShapeOrData(shape_or_data);
+      VLOG(8) << op.name()
+              << " simplified_shape_or_data: " << simplified_shape_or_data;
+      shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
     });
     if (op.num_results() > 0) {
       pir::shape::SetShapeAttrForOp(
(Diff truncated; the remaining changed files are not shown.)
