Skip to content

Commit

Permalink
Change shape_opt_pass and unit test && Standard naming Part 3 (Paddle…
Browse files Browse the repository at this point in the history
…Paddle#57961)

* change shape_opt_pass and unit test
  • Loading branch information
zhangbopd authored Oct 10, 2023
1 parent 17d7383 commit b1c9d44
Show file tree
Hide file tree
Showing 14 changed files with 989 additions and 880 deletions.
7 changes: 7 additions & 0 deletions paddle/pir/core/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/pir/core/type.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/dialect.h"
#include "paddle/pir/core/type_base.h"

Expand All @@ -24,4 +25,10 @@ TypeId Type::type_id() { return storage_->abstract_type().type_id(); }
const AbstractType &Type::abstract_type() { return storage_->abstract_type(); }

Dialect &Type::dialect() const { return storage_->abstract_type().dialect(); }

bool Type::IsIntOrIndex() const {
  // True for the index type and any of the built-in integer types
  // (signed and unsigned), regardless of bit width.
  if (isa<IndexType>()) return true;
  return isa<Int8Type>() || isa<UInt8Type>() || isa<Int16Type>() ||
         isa<Int32Type>() || isa<Int64Type>();
}

} // namespace pir
8 changes: 7 additions & 1 deletion paddle/pir/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class IR_API Type {
using TypeBase = detail::StorageHelperBase<ConcreteType,
BaseType,
StorageType,
pir::TypeManager,
TypeManager,
TraitOrInterface...>;

using Storage = TypeStorage;
Expand Down Expand Up @@ -115,6 +115,12 @@ class IR_API Type {
return pir::cast<U>(*this);
}

///
/// \brief Return true if this is an integer (any signedness) or an index
/// type.
///
bool IsIntOrIndex() const;

protected:
const Storage *storage_{nullptr};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace pir {

class Pass;

// Apply some shape-related optimization.
IR_API std::unique_ptr<Pass> CreateShapeOptimizationPass();

} // namespace pir
220 changes: 191 additions & 29 deletions paddle/pir/dialect/shape/transforms/shape_optimization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,127 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/dialect/shape/transforms/shape_optimization.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"

#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/shape/utils/shape_utils.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/pass/pass_registry.h"

namespace pir {
namespace {
using PassPipelineRunner =
std::function<bool(pir::PassManager&, pir::ModuleOp)>;

// Materializes the shape of `value` in the IR: builds one TensorDimOp per
// dimension and ties them to the value with a TieShapeOp.
// Non-dense-tensor values and rank-0 tensors are skipped (treated as success).
// Local names use snake_case for consistency with the rest of this file.
bool InsertTieShapeOnValue(pir::Value value,
                           pir::Builder& builder) {  // NOLINT
  auto ty = value.type().dyn_cast<paddle::dialect::DenseTensorType>();

  // Nothing to tie for non-tensor or scalar (rank-0) values.
  if (!ty || ty.dims().size() == 0) return true;

  std::vector<pir::Value> dim_sizes;
  for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) {
    auto dim_op = builder.Build<pir::dialect::TensorDimOp>(value, dim);
    dim_sizes.push_back(dim_op.out());
  }
  builder.Build<pir::dialect::TieShapeOp>(value, dim_sizes);
  return true;
}

bool InsertTieShapeOnRegion(pir::Region* region);

// Inserts tie-shape ops for every result of `op`, recursing into any nested
// regions first. Returns false as soon as any insertion fails.
bool InsertTieShapeOnOperation(pir::Operation* op,
                               pir::Builder& builder) {  // NOLINT
  // TODO(zhangbo63): skip more specialized Ops.
  if (op->isa<pir::dialect::TieShapeOp>() || op->isa<pir::dialect::FuncOp>())
    return true;

  // Handle nested regions before the op's own results.
  const size_t num_regions = op->num_regions();
  for (size_t idx = 0; idx < num_regions; ++idx) {
    if (!InsertTieShapeOnRegion(&(op->region(idx)))) return false;
  }

  // New dim/tie-shape ops go right after the defining op.
  builder.SetInsertionPointAfter(op);
  for (pir::OpResult result : op->results()) {
    if (!InsertTieShapeOnValue(result, builder)) return false;
  }
  return true;
}

// Inserts tie-shape ops for every op currently in `block`.
bool InsertTieShapeOnBlock(pir::Block* block) {
  pir::Builder builder(pir::IrContext::Instance(), block, block->begin());
  // TODO(liujinnan): mapping block arguments

  // Snapshot the ops first: InsertTieShapeOnOperation inserts new ops into
  // this block, and those must not be visited by this walk.
  std::vector<pir::Operation*> original_ops;
  for (pir::Operation* op : *block) original_ops.push_back(op);

  for (pir::Operation* op : original_ops) {
    if (!InsertTieShapeOnOperation(op, builder)) return false;
  }
  return true;
}

// A region succeeds only if every one of its blocks succeeds; stops at the
// first failing block.
bool InsertTieShapeOnRegion(pir::Region* region) {
  bool ok = true;
  for (pir::Block* block : *region) {
    // Short-circuits: once a block fails, the rest are not processed.
    ok = ok && InsertTieShapeOnBlock(block);
    if (!ok) break;
  }
  return ok;
}

// Makes shape computation explicit in the IR by inserting tie-shape ops
// throughout the module's top-level region.
bool MaterializeShapeComputation(pir::ModuleOp m) {
  if (!InsertTieShapeOnRegion(&(m->region(0)))) return false;
  // TODO(liujinnan): add rewriter pattern for reifyInferShape.
  return true;
}

// Returns true if `type` could possibly be a shape tensor type, i.e.:
//  - a rank-1, statically shaped dense tensor,
//  - whose element type is int or index,
//  - with fewer than 32 elements (supposing the highest possible rank is
//    smaller than 32).
bool IsCandidateShapeTensorType(Type type) {
  auto tensor_type = type.dyn_cast<DenseTensorType>();
  auto shaped_type = tensor_type.dyn_cast<ShapedTypeInterface>();

  // Bug fix: the original tested `tensor_type` twice and never null-checked
  // `shaped_type` before calling GetRank()/HasStaticShape() on it.
  return (tensor_type && shaped_type && shaped_type.GetRank() == 1 &&
          shaped_type.HasStaticShape() &&
          shaped_type.GetElementType().IsIntOrIndex() &&
          shaped_type.GetShape()[0] < 32);
}

// Walks a module and builds a symbolic representation of its shape
// computation: every int/index value, shape tensor and ranked dense tensor
// gets SymbolicDim(s), and per-op constraints tie equal dims together via
// SymbolicDimMgr.
class ShapeComputationIRAnalysis {
 public:
  using func = std::function<bool(Operation* op)>;
  explicit ShapeComputationIRAnalysis(ModuleOp m,
                                      SymbolicDimMgr& mgr);  // NOLINT
  // Runs the two analysis phases (build symbols, then apply constraints).
  // Returns false if the analysis was already run or any phase fails.
  bool Run();

 private:
  // Traversal helpers: apply `fn` to every operation reachable from the
  // given region/block/operation; stop early on the first failure.
  bool RunOnRegion(Region* region, func fn);
  bool RunOnBlock(Block* block, func fn);
  bool RunOnOperation(Operation* op, func fn);

  // Phase 1: allocate SymbolicDims for an op's results / a value.
  bool BuildShapeOnOperation(Operation* op);
  bool BuildShapeOnValue(Value value);

  // Phase 2: record equality/non-negativity constraints implied by an op.
  bool ApplyOpConstraint(Operation* op);
  bool ApplyIndexOpConstraint(Operation* op);
  bool ApplyTieShapeOpConstraint(Operation* op);

  // Guards Run() against being executed more than once.
  bool initialized_ = false;
  ModuleOp m_;
  SymbolicDimMgr& mgr_;

  // Maps an int/index-typed SSA value to its symbolic dim.
  std::unordered_map<Value, SymbolicDim> value_to_sym_dim_;

  // shape tensor is the 1D ranked tensor with int/index dtype.
  std::unordered_map<Value, std::vector<SymbolicDim>> shape_tensor_to_sym_dims_;

  // Maps a ranked dense tensor value to one symbolic dim per dimension.
  std::unordered_map<Value, std::vector<SymbolicDim>> dense_tensor_to_sym_dims_;
};

// Returns true if the type could possibly be a shape tensor type.
// A shape tensor type is:
//  - a rank-1, static-shaped tensor type
//  - whose element type is int or index
//  - with fewer than 32 elements, supposing that the
//    highest possible rank is smaller than 32.

ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m,
SymbolicDimMgr& mgr)
Expand All @@ -25,16 +142,16 @@ bool ShapeComputationIRAnalysis::Run() {
// Make sure only run once.
if (initialized_) return false;
initialized_ = true;
auto buildShapeFunc =
auto build_shape_func =
std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation,
this,
std::placeholders::_1);
if (!RunOnRegion(&(m_->region(0)), buildShapeFunc)) return false;
auto applyOpConstraintFunc =
if (!RunOnRegion(&(m_->region(0)), build_shape_func)) return false;
auto apply_op_constraint_func =
std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint,
this,
std::placeholders::_1);
if (!RunOnRegion(&(m_->region(0)), applyOpConstraintFunc)) return false;
if (!RunOnRegion(&(m_->region(0)), apply_op_constraint_func)) return false;
return true;
}

Expand Down Expand Up @@ -90,7 +207,7 @@ bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) {
op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(),
ArrayAttribute::get(m_->ir_context(), attrs));
}
rankedTensor2SymDims_[value] = std::move(symbols);
dense_tensor_to_sym_dims_[value] = std::move(symbols);
return true;
}
for (size_t i = 0; i < op->num_results(); ++i) {
Expand All @@ -101,15 +218,15 @@ bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) {

// Reconstructed from the scraped diff, which interleaved old (value2SymDim_,
// shapedTy) and new (value_to_sym_dim_, shaped_type) lines; this is the
// post-change version.
// Allocates symbolic dims for `value`:
//  - an int/index scalar gets a single fresh symbol;
//  - a candidate shape tensor gets one fresh symbol per element.
// Other value types are ignored (still returns true).
bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) {
  Type type = value.type();
  if (type.IsIntOrIndex()) {
    SymbolicDim sym = mgr_.NewSymbolicDim();
    value_to_sym_dim_[value] = sym;
  } else if (IsCandidateShapeTensorType(type)) {
    auto shaped_type = type.dyn_cast<ShapedTypeInterface>();
    std::vector<SymbolicDim> symbols;
    for (size_t i = 0, d = shaped_type.GetShape()[0]; i < d; ++i)
      symbols.push_back(mgr_.NewSymbolicDim());
    shape_tensor_to_sym_dims_[value] = std::move(symbols);
  }
  return true;
}
Expand All @@ -128,24 +245,24 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) {
if (op->num_results() == 0) return true;

Type type = op->result(0).type();
if (!IsIntOrIndex(type)) return true;

if (auto dimOp = op->dyn_cast<dialect::TensorDimOp>()) {
int64_t dimIndex = dimOp.index()
.dyn_cast<OpResult>()
.owner()
->attribute<Int64Attribute>("value")
.data();
value2SymDim_[dimOp.out()].UpdateKnownNonNegative(true);
if (!type.IsIntOrIndex()) return true;

if (auto dim_op = op->dyn_cast<dialect::TensorDimOp>()) {
int64_t dim_index = dim_op.index()
.dyn_cast<OpResult>()
.owner()
->attribute<Int64Attribute>("value")
.data();
value_to_sym_dim_[dim_op.out()].UpdateKnownNonNegative(true);
if (!mgr_.MapSymbolicDimEqual(
value2SymDim_[dimOp.out()],
rankedTensor2SymDims_[dimOp.source()][dimIndex])) {
value_to_sym_dim_[dim_op.out()],
dense_tensor_to_sym_dims_[dim_op.source()][dim_index])) {
return false;
}

} else if (auto constOp = op->dyn_cast<ConstantOp>()) {
int64_t val = constOp.value().dyn_cast<Int64Attribute>().data();
if (!mgr_.MapSymbolicDimEqual(value2SymDim_[op->result(0)],
} else if (auto const_op = op->dyn_cast<ConstantOp>()) {
int64_t val = const_op.value().dyn_cast<Int64Attribute>().data();
if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[op->result(0)],
mgr_.NewConstantSymbolicDim(val))) {
return false;
}
Expand All @@ -155,15 +272,60 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) {
}

// Reconstructed from the scraped diff, which interleaved old (tieShape,
// rankedTensor2SymDims_) and new (tie_shape, dense_tensor_to_sym_dims_)
// lines; this is the post-change version.
// For a TieShapeOp, equates each tied dim value's symbol with the
// corresponding per-dimension symbol of the source tensor, and marks that
// symbol as known non-negative (tensor dim sizes cannot be negative).
bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) {
  if (auto tie_shape = op->dyn_cast<dialect::TieShapeOp>()) {
    auto& value = dense_tensor_to_sym_dims_[op->operand_source(0)];
    for (size_t idx = 0; idx < tie_shape.dims().size(); ++idx) {
      if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[tie_shape.dims()[idx]],
                                    value[idx]))
        return false;
      mgr_.GetRootSymbolicDim(value[idx]).UpdateKnownNonNegative(true);
    }
  }
  return true;
}

// Loads symbolic-dim state for `m`, runs the shape-computation analysis, and
// saves the resulting state back. `runner` is reserved for running a
// canonicalizer pipeline (see TODO) and is currently unused.
bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) {
  // TODO(liujinnan): Do some Canonicalizer.
  pir::SymbolicDimMgr mgr(m);
  IR_ENFORCE(mgr.Load(),
             "SymbolicDimMgr Load failed in OptimizeShapeComputation.");
  ShapeComputationIRAnalysis analysis(m, mgr);
  if (!analysis.Run()) return false;
  IR_ENFORCE(mgr.Save(),
             "SymbolicDimMgr save failed in OptimizeShapeComputation.");
  return true;
}

// Pass that materializes shape computation in the IR and then runs the
// symbolic shape analysis over the module.
class ShapeOptimizationPass : public pir::Pass {
 public:
  ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {}

  void Run(pir::Operation* op) override {
    auto module_op = op->dyn_cast<pir::ModuleOp>();
    IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op.");
    // Step 1: insert tie-shape/tensor-dim ops.
    MaterializeShapeComputation(module_op);
    // runner is for Canonicalizer.
    // (Capture-less lambda: nothing from the pass object is needed here.)
    PassPipelineRunner runner = [](pir::PassManager& pm, pir::ModuleOp m) {
      return pm.Run(m.program());
    };
    // Step 2: run the analysis; on failure the IR is simply left as-is.
    if (!OptimizeShapeComputation(module_op, runner)) return;
  }

  // Only applicable on a module op that actually owns a region.
  bool CanApplyOn(pir::Operation* op) const override {
    return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
  }
};

} // namespace

// Factory for the shape_optimization pass; ownership transfers to the caller.
std::unique_ptr<Pass> CreateShapeOptimizationPass() {
  auto pass = std::make_unique<ShapeOptimizationPass>();
  return pass;
}

} // namespace pir

REGISTER_IR_PASS(shape_optimization, pir::ShapeOptimizationPass);
52 changes: 0 additions & 52 deletions paddle/pir/dialect/shape/transforms/shape_optimization.h

This file was deleted.

Loading

0 comments on commit b1c9d44

Please sign in to comment.