Skip to content

Commit

Permalink
[PIR] normalize the use of value.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Sep 15, 2023
1 parent ae188d1 commit e1f44ab
Show file tree
Hide file tree
Showing 19 changed files with 163 additions and 171 deletions.
24 changes: 12 additions & 12 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ inline pir::Operation* InsertCombineOperationForTarget(
std::string combine_op_name(pir::CombineOp::name());
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);

std::vector<pir::OpResult> src_values;
std::vector<pir::Value> src_values;
std::vector<pir::Type> types_in_vec;
for (const auto& arg_name : args) {
auto defining_info = param_map->at(arg_name);
Expand Down Expand Up @@ -299,7 +299,7 @@ pir::OpResult OpTranscriber::GetAttributeAsInput(
return defining_op->result(0);
}

std::vector<pir::OpResult> OpTranscriber::GenerateOperationInput(
std::vector<pir::Value> OpTranscriber::GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand All @@ -314,7 +314,7 @@ std::vector<pir::OpResult> OpTranscriber::GenerateOperationInput(

VLOG(10) << "[op:" << op_desc.Type() << "][input] start";

std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;

for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
Expand Down Expand Up @@ -779,7 +779,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {

VLOG(10) << "[op assign_value] attribute translation done";

std::vector<pir::OpResult> op_inputs = {};
std::vector<pir::Value> op_inputs = {};

OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types;
Expand Down Expand Up @@ -904,7 +904,7 @@ struct FeedOpTranscriber : public OpTranscriber {
return attribute_map;
}

std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down Expand Up @@ -942,7 +942,7 @@ struct DataOpTranscriber : public FeedOpTranscriber {
};

struct SplitOpTranscriber : public OpTranscriber {
std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand All @@ -953,7 +953,7 @@ struct SplitOpTranscriber : public OpTranscriber {

VLOG(10) << "[op:split][input] start";

std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;
// process first input
auto x_input_vars = op_desc.Input("X");
IR_ENFORCE(x_input_vars.size() == 1, "x input of split MUST be a tensor");
Expand Down Expand Up @@ -1085,7 +1085,7 @@ struct ShadowOutputOpTranscriber : public OpTranscriber {
pir::Program* program) override {
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());

std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;
auto legacy_input_vars = op_desc.Input("x", true);

auto defining_info = (*param_map)[legacy_input_vars[0]];
Expand Down Expand Up @@ -1163,7 +1163,7 @@ struct FillConstant2FullTranscriber : public OpTranscriber {
return op_info;
}

std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down Expand Up @@ -1245,14 +1245,14 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber {
return op_info;
}

std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
pir::Program* program) override {
std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;
if (op_desc.HasInput("ShapeTensor", true) &&
op_desc.Input("ShapeTensor", true).size() > 0) {
auto shape_tensor_vars = op_desc.Input("ShapeTensor", true);
Expand Down Expand Up @@ -1409,7 +1409,7 @@ struct ReduceOpTranscriber : public OpTranscriber {
};

struct ElementwiseTranscriber : public OpTranscriber {
std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct OpTranscriber {

public:
virtual pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc);
virtual std::vector<pir::OpResult> GenerateOperationInput(
virtual std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down
13 changes: 6 additions & 7 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ void HandleForSpecialOp(
HandleForIfOp(place, op_item, block, ctx, map_op_pair, map_value_pair);
return;
}
std::vector<pir::OpResult> vec_inputs;
std::vector<pir::Value> vec_inputs;
std::vector<pir::Type> op_output_types;
if (op_item->name() == "builtin.combine") {
// Copy op inputs
Expand Down Expand Up @@ -754,8 +754,7 @@ void HandleForSpecialOp(

if (new_in.type().isa<pir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<pir::VectorType>().data();
auto index = op_item->attributes()
.at("index")
auto index = op_item->attribute("index")
.dyn_cast<pir::Int32Attribute>()
.data();
op_output_types.push_back(vec_types[index]);
Expand Down Expand Up @@ -899,7 +898,7 @@ std::vector<pir::Type> BuildOpOutputType(pir::Operation* op_item,
return op_output_types;
}

std::vector<pir::OpResult> BuildOpInputList(
std::vector<pir::Value> BuildOpInputList(
pir::Operation* op_item,
const std::string& kernel_fn_str,
const phi::KernelKey& kernel_key,
Expand All @@ -913,7 +912,7 @@ std::vector<pir::OpResult> BuildOpInputList(
return {};
}

std::vector<pir::OpResult> vec_inputs;
std::vector<pir::Value> vec_inputs;

for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
Expand Down Expand Up @@ -981,7 +980,7 @@ std::vector<pir::OpResult> BuildOpInputList(
auto pre_define_op = cur_in.GetDefiningOp();

if (pre_define_op->name() == "builtin.combine") {
std::vector<pir::OpResult> inner_inputs;
std::vector<pir::Value> inner_inputs;
std::vector<pir::Type> types_in_vec;
bool is_trans = false;
for (size_t j = 0; j < pre_define_op->num_operands(); ++j) {
Expand Down Expand Up @@ -1155,7 +1154,7 @@ std::string GetKernelFnStr(const OpYamlInfoParser* op_info_parser,
pir::Operation* BuildPhiKernelOp(
const std::string& kernel_fn_str,
const phi::KernelKey& kernel_key,
const std::vector<pir::OpResult>& vec_inputs,
const std::vector<pir::Value>& vec_inputs,
const std::vector<pir::Type>& op_output_types,
pir::Operation* op_item,
pir::Block* block,
Expand Down
8 changes: 1 addition & 7 deletions paddle/pir/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}

// OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {}

uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
uint32_t max_inline_index =
pir::detail::OpResultImpl::GetMaxInlineResultIndex();
return index <= max_inline_index ? index : max_inline_index;
}
OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {}

} // namespace pir
8 changes: 2 additions & 6 deletions paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,14 @@ class OpResultImpl;
///
class IR_API OpResult : public Value {
public:
OpResult() = default;
OpResult(std::nullptr_t ptr = nullptr) : Value(ptr){}; // NOLINT
Operation *owner() const;
uint32_t GetResultIndex() const;
bool operator==(const OpResult &other) const;
// OpResult(const detail::OpResultImpl *impl); // NOLINT

// This func will remove in next pr.
OpResult(const detail::ValueImpl *impl) : Value(impl) {} // NOLINT

private:
friend Operation;
static uint32_t GetValidInlineIndex(uint32_t index);
OpResult(const detail::OpResultImpl *impl); // NOLINT
// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
Expand Down
10 changes: 7 additions & 3 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

namespace pir {
Operation *Operation::Create(OperationArgument &&argument) {
return Create(argument.inputs,
std::vector<Value> inputs;
for (auto op_result : argument.inputs) {
inputs.emplace_back(op_result);
}
return Create(inputs,
argument.attributes,
argument.output_types,
argument.info,
Expand All @@ -38,7 +42,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, operand.
Operation *Operation::Create(const std::vector<pir::OpResult> &inputs,
Operation *Operation::Create(const std::vector<Value> &inputs,
const AttributeMap &attributes,
const std::vector<Type> &output_types,
pir::OpInfo op_info,
Expand Down Expand Up @@ -89,7 +93,7 @@ Operation *Operation::Create(const std::vector<pir::OpResult> &inputs,
IR_THROW("The address of OpOperandImpl must be divisible by 8.");
}
for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
new (base_ptr) detail::OpOperandImpl(inputs[idx], op);
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3.4. Construct BlockOperands.
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IR_API alignas(8) Operation final {
/// NOTE: Similar to new and delete, the destroy() and the create() need to be
/// used in conjunction.
///
static Operation *Create(const std::vector<pir::OpResult> &inputs,
static Operation *Create(const std::vector<pir::Value> &inputs,
const AttributeMap &attributes,
const std::vector<pir::Type> &output_types,
pir::OpInfo op_info,
Expand Down
28 changes: 27 additions & 1 deletion paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <initializer_list>
#include <memory>
#include "paddle/pir/core/attribute.h"
#include "paddle/pir/core/op_info.h"
Expand Down Expand Up @@ -56,12 +57,29 @@ struct OperationArgument {
num_regions(num_regions),
successors(successors) {}

/// Add Operand.
// Will be deleted in the next pr.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }

void AddInput(Value input) {
inputs.emplace_back(input.dyn_cast<OpResult>());
}

// Will be deleted in the next pr.
template <class InputIt>
void AddOperands(InputIt first, InputIt last);

template <class InputIt>
void AddInputs(InputIt first, InputIt last);

void AddInputs(std::initializer_list<Value> value_list) {
AddInputs(std::begin(value_list), std::end(value_list));
}

template <class ValueContainer>
void AddInputs(const ValueContainer& value_container) {
AddInputs(std::begin(value_container), std::end(value_container));
}

/// Add Output.
void AddOutput(Type type) { output_types.emplace_back(type); }

Expand All @@ -87,6 +105,14 @@ void OperationArgument::AddOperands(InputIt first, InputIt last) {
inputs.emplace_back(*first++);
}
}

template <class InputIt>
void OperationArgument::AddInputs(InputIt first, InputIt last) {
while (first != last) {
AddInput(*first++);
}
}

template <class InputIt>
void OperationArgument::AddOutputs(InputIt first, InputIt last) {
while (first != last) {
Expand Down
14 changes: 4 additions & 10 deletions paddle/pir/core/parser/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ Operation* IrParser::ParseOperation() {

OpInfo opinfo = ParseOpInfo();

std::vector<OpResult> inputs = ParseOprandList();
std::vector<Value> inputs = ParseOprandList();

pir::AttributeMap attributeMap = ParseAttributeMap();

Expand Down Expand Up @@ -269,14 +269,14 @@ OpInfo IrParser::ParseOpInfo() {

// OprandList := ValueList
// ValueList := ValueId(,ValueId)*
std::vector<OpResult> IrParser::ParseOprandList() {
std::vector<Value> IrParser::ParseOprandList() {
ConsumeAToken("(");
std::vector<OpResult> inputs{};
std::vector<Value> inputs{};
Token ind_token = ConsumeToken();
while (ind_token.val_ != ")") {
std::string t = "";
if (ind_token.token_type_ == NULL_) {
inputs.push_back(GetNullValue());
inputs.emplace_back();
} else {
t = ind_token.val_;
inputs.push_back(opresultmap[t]);
Expand Down Expand Up @@ -327,12 +327,6 @@ std::vector<Type> IrParser::ParseTypeList() {
return type_vector;
}

OpResult IrParser::GetNullValue() {
Value* v = new Value{nullptr};
OpResult* opresult = static_cast<OpResult*>(v);
return *opresult;
}

Attribute Attribute::Parse(std::istream& is, IrContext* ctx) {
IrParser parser(ctx, is);
return parser.ParseAttribute();
Expand Down
4 changes: 1 addition & 3 deletions paddle/pir/core/parser/ir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ class IrParser {

std::vector<std::string> ParseOpResultList();

std::vector<OpResult> ParseOprandList();
std::vector<Value> ParseOprandList();

AttributeMap ParseAttributeMap();

std::vector<Type> ParseTypeList();

OpResult GetNullValue();

Type ParseType();

Attribute ParseAttribute();
Expand Down
Loading

0 comments on commit e1f44ab

Please sign in to comment.