Skip to content

Commit

Permalink
InferSymbolicShape of BroadcastShape Op
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Jan 2, 2024
1 parent 3a19245 commit 716a7f5
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 36 deletions.
9 changes: 9 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,13 @@ if(NOT CINN_ONLY)
cinn_op_dialect
op_dialect_vjp)

cinn_cc_library(
fully_insert_broadcast_pass
SRCS
fully_insert_broadcast_pass.cc
DEPS
pir
cinn_op_dialect
op_dialect_vjp)

endif()
117 changes: 89 additions & 28 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp,
paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op,
paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp,
paddle::dialect::IncrementOp, paddle::dialect::Increment_Op
paddle::dialect::IncrementOp, paddle::dialect::Increment_Op,
paddle::dialect::ShapeBroadcastOp
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
Expand All @@ -33,6 +34,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/phi/api/lib/data_type_set.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
Expand All @@ -46,7 +48,6 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/phi/api/lib/data_type_set.h"

namespace paddle {
namespace dialect {
Expand Down Expand Up @@ -2685,36 +2686,41 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
return expected_kernel_dtype;
}

void ShapeBroadcastOp::Build(pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, pir::Value y_) {
void ShapeBroadcastOp::Build(pir::Builder &builder,
pir::OperationArgument &argument,
pir::Value x_,
pir::Value y_) {
VLOG(4) << "Start build ShapeBroadcastOp";



VLOG(4) << "Builder construction inputs";
std::vector<pir::Value> argument_inputs = {x_, y_};
argument.AddInputs(argument_inputs);

VLOG(4) << "Builder construction attributes";

VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorType x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
paddle::dialect::DenseTensorType y = y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
paddle::dialect::DenseTensorType x =
x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
paddle::dialect::DenseTensorType y =
y_.type().dyn_cast<paddle::dialect::DenseTensorType>();

VLOG(4) << "Builder construction dense_x";
paddle::dialect::IrTensor ir_tensor_x(paddle::dialect::TransToPhiDataType(x.dtype()),
x.dims(),
x.data_layout(),
x.lod(),
x.offset());
paddle::dialect::IrTensor ir_tensor_x(
paddle::dialect::TransToPhiDataType(x.dtype()),
x.dims(),
x.data_layout(),
x.lod(),
x.offset());
VLOG(4) << "Builder construction meta_x";
paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);

VLOG(4) << "Builder construction dense_y";
paddle::dialect::IrTensor ir_tensor_y(paddle::dialect::TransToPhiDataType(y.dtype()),
y.dims(),
y.data_layout(),
y.lod(),
y.offset());
paddle::dialect::IrTensor ir_tensor_y(
paddle::dialect::TransToPhiDataType(y.dtype()),
y.dims(),
y.data_layout(),
y.lod(),
y.offset());
VLOG(4) << "Builder construction meta_y";
paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
paddle::dialect::IrTensor dense_out;
Expand All @@ -2723,18 +2729,23 @@ void ShapeBroadcastOp::Build(pir::Builder &builder, pir::OperationArgument &argu
phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);

std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), dense_out.lod(), dense_out.offset());
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
pir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_out.dtype()),
dense_out.dims(),
dense_out.layout(),
dense_out.lod(),
dense_out.offset());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);

}

namespace {

void ShapeBroadcastOpInferMeta(const phi::MetaTensor& x,
const phi::MetaTensor& y,
phi::MetaTensor* out) {
void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x,
const phi::MetaTensor &y,
phi::MetaTensor *out) {
PADDLE_ENFORCE_EQ(x.dims().size(), 1);
PADDLE_ENFORCE_EQ(y.dims().size(), 1);
out->set_dims({std::max<int64_t>(x.dims().at(0), y.dims().at(0))});
Expand All @@ -2750,23 +2761,73 @@ void ShapeBroadcastOpInferMeta(const phi::MetaTensor& x,
out->share_lod(x);
}

}
} // namespace

// Dispatches shape/dtype inference for pd_op.shape_broadcast to the
// free-function ShapeBroadcastOpInferMeta via the PD_INFER_META adaptor,
// which unpacks the MetaTensor arguments from `infer_meta`.
// NOTE(review): the diff rendering carried both the pre- and post-clang-format
// signature lines; this keeps the single formatted definition.
void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
  auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
  fn(infer_meta);
}


// Kernel dtype selection hook for pd_op.shape_broadcast.
//
// @param var_name              Name of the variable being queried (unused).
// @param tensor_dtype          Actual dtype of the tensor (unused).
// @param expected_kernel_dtype Dtype the kernel was selected with.
// @return `expected_kernel_dtype` unchanged — this op performs no
//         per-variable dtype-dependent dispatch.
// NOTE(review): deduplicated the old/new parameter lines left by the diff
// rendering; behavior is unchanged.
phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
    const std::string &var_name,
    const phi::DataType &tensor_dtype,
    const phi::DataType &expected_kernel_dtype) {
  VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";

  return expected_kernel_dtype;
}

namespace {

// Broadcasts one dimension pair, NumPy-style, over symbolic extents:
//   - static vs static: an extent of 1 yields the other side; otherwise the
//     extents must be equal (CHECK_EQ) and that common extent is returned;
//   - static vs symbolic: a static 1 yields the symbolic side, any other
//     static extent wins (the symbolic side is assumed to match it);
//   - symbolic vs symbolic: defer to a symbol::Broadcast expression.
symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
                                    const symbol::DimExpr &rhs) {
  if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
    const std::int64_t lhs_dim = lhs.dyn_cast<std::int64_t>();
    const std::int64_t rhs_dim = rhs.dyn_cast<std::int64_t>();
    // BUG FIX 1: the original branch had no return statement, so control
    // flowed off the end of a value-returning function (undefined behavior).
    // BUG FIX 2: a static extent of 1 is a legal broadcast against any other
    // static extent (the symbolic branches below already allow it), so only
    // enforce equality when neither side is 1.
    if (lhs_dim == 1) return rhs;
    if (rhs_dim == 1) return lhs;
    CHECK_EQ(lhs_dim, rhs_dim);
    return lhs;
  } else if (lhs.isa<std::int64_t>()) {
    return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
  } else if (rhs.isa<std::int64_t>()) {
    return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
  } else {
    return symbol::Broadcast<symbol::DimExpr>{
        symbol::List<symbol::DimExpr>{lhs, rhs}};
  }
}

}  // namespace

// Infers the symbolic shape-or-data of pd_op.shape_broadcast's single output
// and registers it in `shape_analysis`.
//
// Both operands are expected to originate from shape-producing ops, so their
// ShapeOrDataDimExprs entries must carry data() — the per-dimension extents
// as DimExprs. The output data is the element-wise broadcast of the two
// operands' data. Returns true on success; every failure mode raises through
// IR_ENFORCE instead of returning false.
bool ShapeBroadcastOp::InferSymbolicShape(
    pir::ShapeConstraintIRAnalysis *shape_analysis) {
  pir::Value x = operand_source(0);
  pir::Value y = operand_source(1);
  std::string x_id = pir::GetValueId(&x);
  std::string y_id = pir::GetValueId(&y);

  // Both operands must already have symbolic shape info registered by a
  // previously-visited producer op.
  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
             "x_id does not exist.");
  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
             "y_id does not exist.");
  const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
  const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
  IR_ENFORCE(x_data_shape.data().has_value(),
             "Value x comes from ShapeOp, it must have data");
  IR_ENFORCE(y_data_shape.data().has_value(),
             "Value y comes from ShapeOp, it must have data");
  const auto &x_data = x_data_shape.data().value();
  const auto &y_data = y_data_shape.data().value();
  // Differing ranks (implicit leading-1 padding) are not supported yet.
  IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");

  // Broadcast each dimension pair into the output's symbolic data.
  std::vector<symbol::DimExpr> output_data;
  for (std::size_t i = 0; i < x_data.size(); ++i) {
    output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
  }

  // Publish the broadcast result under the output value's id.
  pir::OpResult res = result(0);
  std::string res_id = pir::GetValueId(&res);
  symbol::ShapeOrDataDimExprs output_data_shape =
      symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
  shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
  return true;
}

} // namespace dialect
} // namespace paddle
Expand Down
24 changes: 16 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
Expand Down Expand Up @@ -515,16 +516,22 @@ class Increment_Op
const std::vector<std::vector<bool>> &stop_gradients);
};

class IR_API ShapeBroadcastOp : public pir::Op<ShapeBroadcastOp,paddle::dialect::InferMetaInterface,paddle::dialect::GetKernelTypeForVarInterface> {
class IR_API ShapeBroadcastOp
: public pir::Op<ShapeBroadcastOp,
paddle::dialect::InferSymbolicShapeInterface,
paddle::dialect::InferMetaInterface,
paddle::dialect::GetKernelTypeForVarInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.shape_broadcast"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Build(pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, pir::Value y_);

void VerifySig() {}
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value x_,
pir::Value y_);

void VerifySig() {}

pir::Value x() { return operand_source(0); }
pir::Value y() { return operand_source(1); }
Expand All @@ -533,10 +540,11 @@ class IR_API ShapeBroadcastOp : public pir::Op<ShapeBroadcastOp,paddle::dialect:
static void InferMeta(phi::InferMetaContext *infer_meta);

static phi::DataType GetKernelTypeForVar(
const std::string& var_name,
const phi::DataType& tensor_dtype,
const phi::DataType& expected_kernel_dtype);
const std::string &var_name,
const phi::DataType &tensor_dtype,
const phi::DataType &expected_kernel_dtype);

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
};

} // namespace dialect
Expand All @@ -560,4 +568,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)

0 comments on commit 716a7f5

Please sign in to comment.