
Commit 2ca34a7
[PIR] Support wrap_type_interface for AllocatedDenseTensorType, AllocatedSelectedRowsType and AllocatedDenseTensorArrayType (#62451)

* refine code

* fix
zhangbo9674 authored Mar 6, 2024
1 parent 7bfde24 commit 2ca34a7
Showing 10 changed files with 93 additions and 441 deletions.
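The core of the change: the kernel dialect's "allocated" types now implement pir::WrapTypeInterface by exposing a prim_type() accessor that returns the wrapped DenseTensorType / SelectedRowsType / DenseTensorArrayType, so downstream code can unwrap them generically instead of rebuilding the inner type field by field — which is exactly the repeated branch deleted from the generator and the ops below. The following standalone C++ sketch models that pattern with stand-in class names (nothing here is Paddle's real API); it only illustrates why a single unwrap step can replace the per-type branches.

```cpp
// Standalone model of the wrap-type pattern; all names are illustrative
// stand-ins, not Paddle's real classes.
#include <iostream>
#include <string>
#include <vector>

// Stand-in for pir::Type: the root of a small polymorphic type hierarchy.
struct Type {
  virtual ~Type() = default;
  virtual std::string name() const = 0;
};

// Stand-in for DenseTensorType: the "prim" type that gets wrapped.
struct DenseTensorLikeType : Type {
  std::string dtype;
  std::vector<int64_t> dims;
  DenseTensorLikeType(std::string dtype_, std::vector<int64_t> dims_)
      : dtype(std::move(dtype_)), dims(std::move(dims_)) {}
  std::string name() const override { return "dense_tensor"; }
};

// Stand-in for pir::WrapTypeInterface: a wrapper type exposes the wrapped type.
struct WrapTypeInterfaceLike {
  virtual ~WrapTypeInterfaceLike() = default;
  virtual const Type& prim_type() const = 0;
};

// Stand-in for AllocatedDenseTensorType: a place plus the wrapped type,
// rather than a copy of every field of the wrapped type.
struct AllocatedDenseTensorLikeType : Type, WrapTypeInterfaceLike {
  std::string place;
  DenseTensorLikeType wrapped;
  AllocatedDenseTensorLikeType(std::string place_, DenseTensorLikeType wrapped_)
      : place(std::move(place_)), wrapped(std::move(wrapped_)) {}
  std::string name() const override { return "allocated_dense_tensor"; }
  const Type& prim_type() const override { return wrapped; }
};

// Generic consumer: one unwrap step through the interface replaces a
// dedicated "else if (isa<Allocated...>)" branch per wrapper type.
const Type& UnwrapIfWrapped(const Type& t) {
  if (auto* wrapped = dynamic_cast<const WrapTypeInterfaceLike*>(&t)) {
    return wrapped->prim_type();
  }
  return t;
}

int main() {
  DenseTensorLikeType dense("float32", {2, 3});
  AllocatedDenseTensorLikeType allocated("GPU:0", dense);

  std::cout << UnwrapIfWrapped(dense).name() << "\n";  // dense_tensor
  std::cout << allocated.name() << " -> " << UnwrapIfWrapped(allocated).name()
            << "\n";  // allocated_dense_tensor -> dense_tensor
  return 0;
}
```

Keeping only the place on a thin wrapper and delegating everything else to the wrapped type means new consumers only need to know about the interface, not about each Allocated* variant.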
12 changes: 12 additions & 0 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc
@@ -17,6 +17,10 @@
namespace paddle {
namespace dialect {

pir::Type AllocatedDenseTensorType::prim_type() {
return storage()->dense_tensor_type_;
}

const phi::Place& AllocatedDenseTensorType::place() const {
return storage()->place_;
}
@@ -41,6 +45,10 @@ size_t AllocatedDenseTensorType::offset() const {
return storage()->dense_tensor_type_.offset();
}

pir::Type AllocatedSelectedRowsType::prim_type() {
return storage()->selected_rows_type_;
}

const phi::Place& AllocatedSelectedRowsType::place() const {
return storage()->place_;
}
@@ -65,6 +73,10 @@ size_t AllocatedSelectedRowsType::offset() const {
return storage()->selected_rows_type_.offset();
}

pir::Type AllocatedDenseTensorArrayType::prim_type() {
return storage()->dense_tensor_array_type_;
}

const phi::Place& AllocatedDenseTensorArrayType::place() const {
return storage()->place_;
}
15 changes: 12 additions & 3 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_type.h
@@ -24,7 +24,8 @@ namespace dialect {
class AllocatedDenseTensorType
: public pir::Type::TypeBase<AllocatedDenseTensorType,
pir::Type,
AllocatedDenseTensorTypeStorage> {
AllocatedDenseTensorTypeStorage,
pir::WrapTypeInterface> {
public:
using Base::Base;

@@ -49,6 +50,8 @@ class AllocatedDenseTensorType
ctx, place, dense_tensor_type);
}

pir::Type prim_type();

const phi::Place &place() const;

pir::Type dtype() const;
@@ -65,7 +68,8 @@ class AllocatedSelectedRowsType
class AllocatedSelectedRowsType
: public pir::Type::TypeBase<AllocatedSelectedRowsType,
pir::Type,
AllocatedSelectedRowsTypeStorage> {
AllocatedSelectedRowsTypeStorage,
pir::WrapTypeInterface> {
public:
using Base::Base;

@@ -90,6 +94,8 @@ class AllocatedSelectedRowsType
ctx, place, type);
}

pir::Type prim_type();

const phi::Place &place() const;

pir::Type dtype() const;
@@ -106,7 +112,8 @@ class AllocatedDenseTensorArrayType
class AllocatedDenseTensorArrayType
: public pir::Type::TypeBase<AllocatedDenseTensorArrayType,
pir::Type,
AllocatedDenseTensorArrayTypeStorage> {
AllocatedDenseTensorArrayTypeStorage,
pir::WrapTypeInterface> {
public:
using Base::Base;

@@ -129,6 +136,8 @@ class AllocatedDenseTensorArrayType
ctx, place, type);
}

pir::Type prim_type();

const phi::Place &place() const;

const pir::Type &dtype() const;
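With prim_type() declared above, code that receives an AllocatedDenseTensorType can read the wrapped type directly instead of reconstructing it from its individual accessors — the reconstruction pattern being deleted in op_infermeta_gen.py, control_flow_op.cc and manual_onednn_op.cc below. A hedged before/after fragment (not a complete compilable unit; it reuses only names visible in this diff and assumes prim_type() returns the stored DenseTensorType, as kernel_type.cc suggests):

```cpp
// Before this commit: rebuild the inner type field by field.
auto allocated_x =
    x_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
auto x_before = paddle::dialect::DenseTensorType::get(
    pir::IrContext::Instance(),
    allocated_x.dtype(),
    allocated_x.dims(),
    allocated_x.data_layout(),
    allocated_x.lod(),
    allocated_x.offset());

// After this commit (assumed usage): read the wrapped type through the
// new accessor and cast it back to the plain dialect type.
auto x_after =
    allocated_x.prim_type().dyn_cast<paddle::dialect::DenseTensorType>();
```

Generic passes can go one step further and query pir::WrapTypeInterface itself, so they never have to name the Allocated* types at all (see the sketch after the generator diff below).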
39 changes: 0 additions & 39 deletions paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py
@@ -44,15 +44,6 @@
{type} {name};
if ({name}_.type().isa<{type}>()) {{
{name} = {name}_.type().dyn_cast<{type}>(); (void){name};
}} else if ({name}_.type().isa<{allocated_type}>()) {{
{allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>();
{name} = {type}::get(pir::IrContext::Instance(),
allocated_{name}.dtype(),
allocated_{name}.dims(),
allocated_{name}.data_layout(),
allocated_{name}.lod(),
allocated_{name}.offset());
(void){name};
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}"));
}}
@@ -158,20 +149,11 @@ def GenBuildOutputsPart2(
paddle::dialect::IrMetaTensor meta_{name};
paddle::dialect::IrTensor ir_tensor_{name};
if ({name}_.impl() != nullptr) {{
VLOG(4) << "Builder construction dense_{name}";
{type} {name};
if ({name}_.type().isa<{type}>()) {{
{name} = {name}_.type().dyn_cast<{type}>();
}} else if ({name}_.type().isa<{allocated_type}>()) {{
{allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>();
{name} = {type}::get(pir::IrContext::Instance(),
allocated_{name}.dtype(),
allocated_{name}.dims(),
allocated_{name}.data_layout(),
allocated_{name}.lod(),
allocated_{name}.offset());
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}"));
}}
@@ -195,13 +177,6 @@
{name}_type.data_layout(),
{name}_type.lod(),
{name}_type.offset()));
}} else if({name}[i].isa<paddle::dialect::AllocatedDenseTensorType>()){{
auto {name}_type = {name}[i].dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()),
{name}_type.dims(),
{name}_type.data_layout(),
{name}_type.lod(),
{name}_type.offset()));
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType"));
}}
@@ -228,13 +203,6 @@
{name}_type.data_layout(),
{name}_type.lod(),
{name}_type.offset()));
}} else if({name}[i].isa<paddle::dialect::AllocatedDenseTensorType>()){{
auto {name}_type = {name}[i].dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()),
{name}_type.dims(),
{name}_type.data_layout(),
{name}_type.lod(),
{name}_type.offset()));
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType"));
}}
@@ -273,13 +241,6 @@
{name}_size = 1;
}}
{name} = std::vector<int64_t>({name}_size, -1);
}} else if ({name}_.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {{
common::DDim {name}_dim = {name}_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>().dims();
size_t {name}_size = common::product({name}_dim);
if (common::contain_unknown_dim({name}_dim)) {{
{name}_size = 1;
}}
{name} = std::vector<int64_t>({name}_size, -1);
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType or AllocatedDenseTensorType"));
}}\n"""
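After the deletions above, the generated InferMeta code keeps only the plain DenseTensorType branch. The working assumption (not shown in this excerpt) is that kernel-dialect types are unwrapped generically once they implement pir::WrapTypeInterface, so the generator no longer has to emit a reconstruction branch for every Allocated* type. A minimal sketch of such a generic unwrap step, assuming pir::Type::dyn_cast accepts type interfaces; the helper name is hypothetical:

```cpp
// Hypothetical helper; where Paddle actually performs this unwrap is not
// visible in this diff excerpt.
pir::Type UnwrapKernelType(pir::Type type) {
  if (auto wrap_type = type.dyn_cast<pir::WrapTypeInterface>()) {
    // prim_type() returns the wrapped DenseTensorType / SelectedRowsType /
    // DenseTensorArrayType, as implemented in kernel_type.cc above.
    return wrap_type.prim_type();
  }
  return type;
}
```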
15 changes: 2 additions & 13 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -575,14 +575,6 @@ void WhileOp::VerifySig() {
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input, it should be a "
"bool DenseTensorType."));
} else if (auto cond_type =
operand_type(0).dyn_cast<AllocatedDenseTensorType>()) {
PADDLE_ENFORCE_EQ(
cond_type.dtype().isa<pir::BoolType>(),
true,
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input, it should be a "
"bool DenseTensorType."));
} else {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Currently, the while op cond input only support bool dense_tensor "
@@ -803,8 +795,7 @@ void HasElementsOp::VerifySig() {

// Verify outputs:
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
IR_ENFORCE((*this)->result_type(0).isa<DenseTensorType>() ||
(*this)->result_type(0).isa<AllocatedDenseTensorType>(),
IR_ENFORCE((*this)->result_type(0).isa<DenseTensorType>(),
"The type of cf.has_elements' output is not correct.");
}

@@ -874,8 +865,7 @@ void AssertOp::VerifySig() {
(*this)->operand(1).type().dyn_cast<pir::VectorType>()) {
for (size_t i = 0; i < vec_type.size(); ++i) {
IR_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>() ||
vec_type[i].isa<paddle::dialect::SelectedRowsType>() ||
vec_type[i].isa<AllocatedDenseTensorType>(),
vec_type[i].isa<paddle::dialect::SelectedRowsType>(),
"Type validation failed for the 1th input.");
}
} else {
Expand All @@ -885,7 +875,6 @@ void AssertOp::VerifySig() {
->operand(1)
.type()
.isa<paddle::dialect::SelectedRowsType>(),
(*this)->operand(1).type().isa<AllocatedDenseTensorType>(),
"Type validation failed for the 1th input.");
}
}
9 changes: 0 additions & 9 deletions paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc
@@ -255,15 +255,6 @@ std::vector<pir::Type> ExpandOp::InferMeta(
paddle::dialect::DenseTensorType x;
if (x_.type().isa<paddle::dialect::DenseTensorType>()) {
x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
} else if (x_.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
paddle::dialect::AllocatedDenseTensorType allocated_x =
x_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(),
allocated_x.dtype(),
allocated_x.dims(),
allocated_x.data_layout(),
allocated_x.lod(),
allocated_x.offset());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support paddle::dialect::DenseTensorType or "
(Diffs for the remaining 5 changed files are not shown here.)
