[PIR] Replace OpResult usage in Operation::result(s) (PaddlePaddle#61053)
huangjiyi authored and eee4017 committed Jan 30, 2024
1 parent 5de5aea commit 2efef93
Showing 38 changed files with 312 additions and 365 deletions.
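The whole patch follows one pattern: pir::OpResult is a derived handle of pir::Value, so interfaces that only need to hand a value around can expose the base type, and existing call sites keep compiling through the implicit derived-to-base conversion. A minimal, self-contained sketch of that relationship (hypothetical stand-in classes, not the real PIR headers):

#include <cstdint>

// Hypothetical stand-ins for pir::Value / pir::OpResult, only to show why
// Operation::result() can return the base type without breaking callers.
class Value {
 public:
  Value() = default;
  explicit Value(std::intptr_t impl) : impl_(impl) {}
  explicit operator bool() const { return impl_ != 0; }

 private:
  std::intptr_t impl_ = 0;
};

class OpResult : public Value {  // an OpResult is-a Value
 public:
  using Value::Value;
};

// Before the change, callers received OpResult; afterwards the base handle is
// enough, and no cast back to OpResult is needed for ordinary uses.
Value ResultOfSomeOp() {
  OpResult r(0x1);
  return r;  // implicit derived-to-base conversion
}

int main() { return ResultOfSomeOp() ? 0 : 1; }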
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -134,7 +134,7 @@ class IR_API GenerateShapeOp

void VerifySig() {}

- pir::OpResult out() { return result(0); }
+ pir::Value out() { return result(0); }

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);

84 changes: 41 additions & 43 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -418,11 +418,10 @@ void OpTranscriber::InsertSliceOperationForInput(
}
}

- pir::OpResult OpTranscriber::GetAttributeAsInput(
-     pir::IrContext* ctx,
-     pir::Block* block,
-     const OpDesc& op_desc,
-     const OpInputInfo& input_info) {
+ pir::Value OpTranscriber::GetAttributeAsInput(pir::IrContext* ctx,
+                                               pir::Block* block,
+                                               const OpDesc& op_desc,
+                                               const OpInputInfo& input_info) {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();

@@ -485,7 +484,7 @@ std::vector<pir::Value> OpTranscriber::GenerateOperationInput(
<< legacy_input_name;

std::vector<std::string> legacy_input_vars;
- // return empty OpResult if this arg is optional and not shown in OpDesc
+ // return empty Value if this arg is optional and not shown in OpDesc
if (op_desc.HasInput(legacy_input_name, true)) {
legacy_input_vars = op_desc.Input(legacy_input_name, true);
}
@@ -785,7 +784,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx,
VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << arg_name << " " << idx_in_op
<< " " << idx_in_vec;
- pir::OpResult value = operation->result(idx_in_op);
+ pir::Value value = operation->result(idx_in_op);
bool generated_by_vector = value.type().isa<pir::VectorType>();

param_map->PushValue(
@@ -1037,12 +1036,12 @@ struct AssignValueOpTranscriber : public OpTranscriber {
// So we generate an input by `full` with same type of output `DropoutState` of
// OpDesc And we still should be aware that `DropoutState` is an optional output
// in static graph.
- pir::OpResult TranslateDropOutStateIn(pir::IrContext* ctx,
-                                       TranslationContext* param_map,
-                                       const OpDesc& op_desc,
-                                       const std::string& normalized_op_name,
-                                       const OpInputInfo& input_info,
-                                       pir::Block* block) {
+ pir::Value TranslateDropOutStateIn(pir::IrContext* ctx,
+                                    TranslationContext* param_map,
+                                    const OpDesc& op_desc,
+                                    const std::string& normalized_op_name,
+                                    const OpInputInfo& input_info,
+                                    pir::Block* block) {
const std::string legacy_output_name = "DropoutState";
std::vector<std::string> legacy_output_vars;
if (op_desc.HasOutput(legacy_output_name)) {
@@ -1051,7 +1050,7 @@ pir::OpResult TranslateDropOutStateIn(pir::IrContext* ctx,

if (legacy_output_vars.empty()) {
VLOG(3) << "[input translating] not find output variable: DropoutState";
- return pir::OpResult(nullptr);
+ return pir::Value(nullptr);
}

// `DropoutState` is a tensor
@@ -1416,7 +1415,7 @@ struct TrilAndTriuGradOpTranscriber : public OpTranscriber {
};

using ValueInfo =
- std::tuple<std::vector<int64_t>, dialect::DenseTensorType, pir::OpResult>;
+ std::tuple<std::vector<int64_t>, dialect::DenseTensorType, pir::Value>;

ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc,
const std::vector<std::string>& names,
@@ -1434,7 +1433,7 @@ ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc,
name);
const auto& defining_info = param_map->at(name);

- pir::OpResult value = defining_info.value.dyn_cast<pir::OpResult>();
+ pir::Value value = defining_info.value;
IR_ENFORCE(
value, "Expected op[%s]'s input %s is not null", op_desc.Type(), name);
const pir::Type& type = value.type();
@@ -1529,7 +1528,7 @@ struct MulOpTranscriber : public OpTranscriber {
});
dialect::ReshapeOp reshape_op_x =
builder.Build<dialect::ReshapeOp>(x_value, x_new_shape);
- pir::OpResult x_new = reshape_op_x.out();
+ pir::Value x_new = reshape_op_x.out();
VLOG(6) << "[" << op_desc.Type() << "] x_shape change from "
<< x_tensor_type.dims() << " to " << common::make_ddim(x_new_shape);

@@ -1547,7 +1546,7 @@

dialect::ReshapeOp reshape_op_y =
builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
- pir::OpResult y_new = reshape_op_y.out();
+ pir::Value y_new = reshape_op_y.out();
VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
<< y_tensor_type.dims() << " to " << common::make_ddim(y_new_shape);

@@ -1566,7 +1565,7 @@
op_desc, op_desc.Output("Out"), param_map, "Out");

const dialect::DenseTensorType& out_tensor_type = std::get<1>(out_info);
- pir::OpResult& out_value = std::get<2>(out_info);
+ pir::Value& out_value = std::get<2>(out_info);

const auto& output_vars = op_desc.Output("Out");
const auto& output_name = output_vars[0];
@@ -1594,7 +1593,7 @@ struct MulOpTranscriber : public OpTranscriber {
pir::Builder builder(ctx, operation->GetParent());
dialect::ReshapeOp reshape_op_out =
builder.Build<dialect::ReshapeOp>(out_value, out_new_shape);
- pir::OpResult out_new = reshape_op_out.out().dyn_cast<pir::OpResult>();
+ pir::Value out_new = reshape_op_out.out();
VLOG(6) << "[" << op_desc.Type() << "] out_shape change from "
<< out_tensor_type.dims() << " to "
<< common::make_ddim(out_new_shape);
@@ -1674,7 +1673,7 @@ struct MulGradOpTranscriber : public OpTranscriber {

const dialect::DenseTensorType& out_grad_tensor_type =
std::get<1>(out_grad_info);
- pir::OpResult& out_grad_value = std::get<2>(out_grad_info);
+ pir::Value& out_grad_value = std::get<2>(out_grad_info);

pir::Builder builder(ctx, block);

@@ -1692,7 +1691,7 @@
});
dialect::ReshapeOp reshape_op_x =
builder.Build<dialect::ReshapeOp>(x_value, x_new_shape);
- pir::OpResult x_new = reshape_op_x.out();
+ pir::Value x_new = reshape_op_x.out();
VLOG(6) << "[" << op_desc.Type() << "] x_shape change from "
<< x_tensor_type.dims() << " to " << common::make_ddim(x_new_shape);

@@ -1710,7 +1709,7 @@

dialect::ReshapeOp reshape_op_y =
builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
- pir::OpResult y_new = reshape_op_y.out();
+ pir::Value y_new = reshape_op_y.out();
VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
<< y_tensor_type.dims() << " to " << common::make_ddim(y_new_shape);

@@ -1719,7 +1718,7 @@

dialect::ReshapeOp reshape_op_out_grad =
builder.Build<dialect::ReshapeOp>(out_grad_value, out_grad_new_shape);
- pir::OpResult out_grad_new = reshape_op_out_grad.out();
+ pir::Value out_grad_new = reshape_op_out_grad.out();
VLOG(6) << "[" << op_desc.Type() << "] out_grad_shape change from "
<< out_grad_tensor_type.dims() << " to "
<< common::make_ddim(out_grad_new_shape);
@@ -1770,7 +1769,7 @@ struct MulGradOpTranscriber : public OpTranscriber {
std::vector<int64_t> shape = var_desc->GetShape();
DenseTensorTypeStorage::Dim dim = common::make_ddim(shape);

- pir::OpResult value_res = operation->result(idx_in_op);
+ pir::Value value_res = operation->result(idx_in_op);
auto reshape_op = builder.Build<dialect::ReshapeOp>(value_res, shape);

IR_ENFORCE(value_res,
@@ -2160,13 +2159,12 @@ struct SelectOutputOpTranscriber : public OpTranscriber {
}
};

- pir::OpResult TranslateNumClassesForOneHot(
-     pir::IrContext* ctx,
-     TranslationContext* param_map,
-     const OpDesc& op_desc,
-     const std::string& normalized_op_name,
-     const OpInputInfo& input_info,
-     pir::Block* block) {
+ pir::Value TranslateNumClassesForOneHot(pir::IrContext* ctx,
+                                         TranslationContext* param_map,
+                                         const OpDesc& op_desc,
+                                         const std::string& normalized_op_name,
+                                         const OpInputInfo& input_info,
+                                         pir::Block* block) {
const std::string legacy_attr_name = "depth";
const std::string legacy_tensor_name = "depth_tensor";
std::vector<std::string> legacy_vars;
@@ -2182,7 +2180,7 @@ pir::OpResult TranslateNumClassesForOneHot(
"%s should be existed in one_hot_v2 as input depth_tensor.",
legacy_vars[0]);
auto defining_info = param_map->at(legacy_vars[0]);
- return defining_info.value.dyn_cast<pir::OpResult>();
+ return defining_info.value;
}

auto& attribute_translator = AttributeTranslator::instance();
@@ -2300,7 +2298,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
ctx, param_map, block, x_defining_info, x_name);
x_defining_info = param_map->at(x_name);
}
- pir::OpResult x_value = x_defining_info.value.dyn_cast<pir::OpResult>();
+ pir::Value x_value = x_defining_info.value;
IR_ENFORCE(x_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
Expand Down Expand Up @@ -2331,7 +2329,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
ctx, param_map, block, y_defining_info, y_name);
y_defining_info = param_map->at(y_name);
}
- pir::OpResult y_value = y_defining_info.value.dyn_cast<pir::OpResult>();
+ pir::Value y_value = y_defining_info.value;
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
Expand Down Expand Up @@ -2362,7 +2360,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
append_size);

pir::Builder builder(ctx, block);
- pir::OpResult y_new;
+ pir::Value y_new;
if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
std::vector<int64_t> y_new_shape(y_shape);
for (int i = 0; i < append_size; i++) {
@@ -2451,7 +2449,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
op_desc.Type(),
y_name);
auto y_defining_info = param_map->at(y_name);
- pir::OpResult y_value = y_defining_info.value.dyn_cast<pir::OpResult>();
+ pir::Value y_value = y_defining_info.value;
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
@@ -2465,7 +2463,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
dialect::DenseTensorType y_tensor_type =
y_type.dyn_cast<dialect::DenseTensorType>();

- pir::OpResult value = operation->result(idx_in_op);
+ pir::Value value = operation->result(idx_in_op);

// if y_grad' shape is same with y, we don't need a reshape
pir::Type y_grad_type = value.type();
@@ -2489,10 +2487,10 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
};

struct SetValueOpTranscriber : public OpTranscriber {
- pir::OpResult GetAttributeAsInput(pir::IrContext* ctx,
-                                   pir::Block* block,
-                                   const OpDesc& op_desc,
-                                   const OpInputInfo& input_info) override {
+ pir::Value GetAttributeAsInput(pir::IrContext* ctx,
+                                pir::Block* block,
+                                const OpDesc& op_desc,
+                                const OpInputInfo& input_info) override {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();

@@ -3062,7 +3060,7 @@ struct LegacyMatmulOpTranscriber : public OpTranscriber {
<< idx_in_op << " " << idx_in_vec;

pir::Builder builder(ctx, operation->GetParent());
- pir::OpResult value = operation->result(idx_in_op);
+ pir::Value value = operation->result(idx_in_op);
auto scale_op = builder.Build<dialect::ScaleOp>(value, alpha);
param_map->PushValue(output_vars[0],
VariableDefiningInfo(scale_op.out(), false, -1));
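Most of the translator edits above reduce to the same idiom: the translation context already stores a pir::Value, so typing the local as the base class removes the dyn_cast<pir::OpResult>() round-trip. A rough sketch of that simplification with hypothetical stand-in types (the real code uses pir::Value and VariableDefiningInfo):

// Hypothetical stand-ins; not the actual translator types.
struct Value {
  int id = 0;
};
struct OpResult : Value {};

struct DefiningInfo {
  Value value;  // what the translation context records per variable
};

Value Lookup(const DefiningInfo& info) {
  // before: OpResult v = info.value.dyn_cast<OpResult>(); return v;
  // after:  the stored base handle is used directly
  return info.value;
}

int main() {
  DefiningInfo info{Value{42}};
  return Lookup(info).id == 42 ? 0 : 1;
}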
8 changes: 4 additions & 4 deletions paddle/fluid/ir_adaptor/translator/op_translator.h
@@ -85,10 +85,10 @@ struct OpTranscriber {
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc);
- virtual pir::OpResult GetAttributeAsInput(pir::IrContext* ctx,
-                                           pir::Block* block,
-                                           const OpDesc& op_desc,
-                                           const OpInputInfo& input_info);
+ virtual pir::Value GetAttributeAsInput(pir::IrContext* ctx,
+                                        pir::Block* block,
+                                        const OpDesc& op_desc,
+                                        const OpInputInfo& input_info);

virtual void RecordOpResultMapping(pir::IrContext* ctx,
TranslationContext* param_map,
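The header and the overrides in op_translator.cc have to change together: by-value return types of virtual functions are not covariant in C++, so once the base GetAttributeAsInput returns pir::Value, every override (for example the one in SetValueOpTranscriber above) must spell the same return type. A small sketch with hypothetical types:

// Hypothetical stand-ins; only pointer/reference returns may be covariant,
// so a by-value return of an override must match the base exactly.
struct Value {};
struct OpResult : Value {};

struct TranscriberBase {
  virtual Value GetAttributeAsInput() { return Value{}; }
  virtual ~TranscriberBase() = default;
};

struct SetValueLikeTranscriber : TranscriberBase {
  // OpResult GetAttributeAsInput() override;  // would not compile
  Value GetAttributeAsInput() override { return OpResult{}; }  // converts to base
};

int main() {
  SetValueLikeTranscriber t;
  TranscriberBase& base = t;
  (void)base.GetAttributeAsInput();
  return 0;
}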
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/utils.cc
@@ -69,7 +69,7 @@ pir::Operation* InsertSliceOperationForTarget(
{src_vec_type[defining_info.idx_in_vector]},
op_info);
block->push_back(operation);
- pir::OpResult target_op_result = operation->result(0);
+ pir::Value target_op_result = operation->result(0);
param_map->PushValue(arg_name, VariableDefiningInfo(target_op_result));
return operation;
}
34 changes: 12 additions & 22 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -138,16 +138,16 @@
}}"""

OPTIONAL_OPRESULT_OUTPUT_TEMPLATE = """
- paddle::optional<pir::OpResult> optional_{name};
+ paddle::optional<pir::Value> optional_{name};
if (!IsEmptyValue({op_name}_op.result({index}))) {{
- optional_{name} = paddle::make_optional<pir::OpResult>({op_name}_op.result({index}));
+ optional_{name} = paddle::make_optional<pir::Value>({op_name}_op.result({index}));
}}"""

OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE = """
- paddle::optional<std::vector<pir::OpResult>> optional_{name};
+ paddle::optional<std::vector<pir::Value>> optional_{name};
if (!IsEmptyValue({op_name}_op.result({index}))) {{
auto optional_{name}_slice_op = ApiBuilder::Instance().GetBuilder()->Build<pir::SplitOp>({op_name}_op.result({index}));
- optional_{name} = paddle::make_optional<std::vector<pir::OpResult>>(optional_{name}_slice_op.outputs());
+ optional_{name} = paddle::make_optional<std::vector<pir::Value>>(optional_{name}_slice_op.outputs());
}}"""

SET_NULL_TYPE_TEMPLATE = """
@@ -171,26 +171,16 @@
VECTOR_TYPE = 'pir::VectorType'
INTARRAY_ATTRIBUTE = "paddle::dialect::IntArrayAttribute"

- INPUT_TYPE_MAP = {
+ VALUE_TYPE_MAP = {
'paddle::dialect::DenseTensorType': 'pir::Value',
'paddle::dialect::SelectedRowsType': 'pir::Value',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::Value>',
}
- OPTIONAL_INPUT_TYPE_MAP = {
+ OPTIONAL_VALUE_TYPE_MAP = {
'paddle::dialect::DenseTensorType': 'paddle::optional<pir::Value>',
'paddle::dialect::SelectedRowsType': 'paddle::optional<pir::Value>',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'paddle::optional<std::vector<pir::Value>>',
}
- OUTPUT_TYPE_MAP = {
-     'paddle::dialect::DenseTensorType': 'pir::OpResult',
-     'paddle::dialect::SelectedRowsType': 'pir::OpResult',
-     'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::OpResult>',
- }
- OPTIONAL_OUTPUT_TYPE_MAP = {
-     'paddle::dialect::DenseTensorType': 'paddle::optional<pir::OpResult>',
-     'paddle::dialect::SelectedRowsType': 'paddle::optional<pir::OpResult>',
-     'pir::VectorType<paddle::dialect::DenseTensorType>': 'paddle::optional<std::vector<pir::OpResult>>',
- }


def get_op_class_name(op_name):
@@ -282,9 +272,9 @@ def _gen_api_inputs(self, op_info):
ret = []
for name, type, optional in zip(name_list, type_list, optional_list):
if optional == 'true':
- ret.append(f'const {OPTIONAL_INPUT_TYPE_MAP[type]}& {name}')
+ ret.append(f'const {OPTIONAL_VALUE_TYPE_MAP[type]}& {name}')
else:
- ret.append(f'const {INPUT_TYPE_MAP[type]}& {name}')
+ ret.append(f'const {VALUE_TYPE_MAP[type]}& {name}')
return ', '.join(ret)

def _gen_api_attrs(
@@ -347,17 +337,17 @@ def _gen_ret_type(self, op_info):
if intermediate == 'true':
continue
if self._is_optional_output(op_info, name):
- ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type])
+ ret.append(OPTIONAL_VALUE_TYPE_MAP[type])
else:
- ret.append(OUTPUT_TYPE_MAP[type])
+ ret.append(VALUE_TYPE_MAP[type])
return 'std::tuple<{}>'.format(', '.join(ret))
elif output_num == 1:
index = intermediate_list.index('false')
name = name_list[index]
if self._is_optional_output(op_info, name):
- return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]]
+ return OPTIONAL_VALUE_TYPE_MAP[type_list[index]]
else:
- return OUTPUT_TYPE_MAP[type_list[index]]
+ return VALUE_TYPE_MAP[type_list[index]]
elif output_num == 0:
return 'void'

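Folding OUTPUT_TYPE_MAP and OPTIONAL_OUTPUT_TYPE_MAP into the single VALUE_TYPE_MAP means the generated C++ APIs spell inputs and outputs with the same pir::Value-based types. A hypothetical illustration of a generated signature under that assumption (stand-in Value type, not actual generated code):

#include <tuple>
#include <vector>

struct Value {};  // stand-in for pir::Value

// before: std::tuple<pir::OpResult, std::vector<pir::OpResult>> some_api(...);
// after, for a two-output op, something like:
std::tuple<Value, std::vector<Value>> some_api(const Value& x) {
  (void)x;  // a real generated API would build the op from x
  return {Value{}, std::vector<Value>(1)};
}

int main() {
  auto out = some_api(Value{});
  return std::get<1>(out).size() == 1 ? 0 : 1;
}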
(diffs for the remaining changed files not shown)
