From e3d4fdd4b01f47f5d7df28d11e08f545febcbdf8 Mon Sep 17 00:00:00 2001 From: jackalcooper Date: Tue, 14 Jun 2022 03:46:45 +0000 Subject: [PATCH 1/3] use names in trait static func --- oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp | 17 ++++++++++------- oneflow/ir/lib/OneFlow/Passes.cpp | 6 +++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp b/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp index f5bedf762c1..0321527a6d3 100644 --- a/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp +++ b/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp @@ -51,7 +51,8 @@ OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef operands, const auto attr_dict = operands.front().cast(); auto attrs = NamedAttrList(attr_dict); const auto tensor = support::DenseElementsAttrToTensor( - attr_dict.get("value"), attr_dict.get("device_tag"), attr_dict.get("device_name")); + attr_dict.get("value"), attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()), + attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); const auto result = f(tensor).GetPtrOrThrow(); attrs.set("value", support::TensorToDenseElementsAttr(result, ctx)); attrs.set("op_name", GenNewVariableOpName(ctx)); @@ -67,12 +68,14 @@ OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef operands, auto rhs_attr_dict = operands.back().cast(); auto attrs = NamedAttrList(lhs_attr_dict); - const auto lhs_tensor = support::DenseElementsAttrToTensor(lhs_attr_dict.get("value"), - lhs_attr_dict.get("device_tag"), - lhs_attr_dict.get("device_name")); - const auto rhs_tensor = support::DenseElementsAttrToTensor(rhs_attr_dict.get("value"), - rhs_attr_dict.get("device_tag"), - rhs_attr_dict.get("device_name")); + const auto lhs_tensor = support::DenseElementsAttrToTensor( + lhs_attr_dict.get("value"), + lhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()), + lhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); + const auto rhs_tensor = support::DenseElementsAttrToTensor( + rhs_attr_dict.get("value"), + rhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()), + rhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); const auto result = f(lhs_tensor, rhs_tensor).GetPtrOrThrow(); diff --git a/oneflow/ir/lib/OneFlow/Passes.cpp b/oneflow/ir/lib/OneFlow/Passes.cpp index f76d9370109..8dae2046b99 100644 --- a/oneflow/ir/lib/OneFlow/Passes.cpp +++ b/oneflow/ir/lib/OneFlow/Passes.cpp @@ -349,9 +349,9 @@ ::llvm::SmallVector<::mlir::Value, 4> CreateConv2dAndErasePad(::mlir::PatternRew NamedAttrList GetUserOpCommonAttrs(MLIRContext* ctx, const std::string& op_name) { NamedAttrList attrs; - attrs.set("op_name", StringAttr::get(ctx, op_name)); - attrs.set("device_tag", StringAttr::get(ctx, "cpu")); - attrs.set("device_name", + attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), StringAttr::get(ctx, op_name)); + attrs.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), StringAttr::get(ctx, "cpu")); + attrs.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range(ArrayRef({"@0:0"}), [&](StringRef v) -> Attribute { return StringAttr::get(ctx, v); From 78b0ea99ce9036773f770fa80f40d9ad3fd60296 Mon Sep 17 00:00:00 2001 From: jackalcooper Date: Tue, 14 Jun 2022 03:52:03 +0000 Subject: [PATCH 2/3] more changes on op name attr --- oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp | 4 ++-- oneflow/ir/lib/OneFlow/Passes.cpp | 6 ++++-- .../oneflow-translate/lib/OneFlow/Importer.cpp | 17 ++++++++++++++--- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp b/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp index 0321527a6d3..c3d491cf597 100644 --- a/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp +++ b/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp @@ -55,7 +55,7 @@ OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef operands, attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); const auto result = f(tensor).GetPtrOrThrow(); attrs.set("value", support::TensorToDenseElementsAttr(result, ctx)); - attrs.set("op_name", GenNewVariableOpName(ctx)); + attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), GenNewVariableOpName(ctx)); return attrs.getDictionary(ctx); } @@ -80,7 +80,7 @@ OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef operands, const auto result = f(lhs_tensor, rhs_tensor).GetPtrOrThrow(); attrs.set("value", support::TensorToDenseElementsAttr(result, ctx)); - attrs.set("op_name", GenNewVariableOpName(ctx)); + attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), GenNewVariableOpName(ctx)); return attrs.getDictionary(ctx); } diff --git a/oneflow/ir/lib/OneFlow/Passes.cpp b/oneflow/ir/lib/OneFlow/Passes.cpp index 8dae2046b99..612e0a79a9a 100644 --- a/oneflow/ir/lib/OneFlow/Passes.cpp +++ b/oneflow/ir/lib/OneFlow/Passes.cpp @@ -569,7 +569,8 @@ llvm::SmallVector getInputOperandTransposeOp(NCHWCompatible op, PatternRewriter& rewriter) { std::string transpose_name = OpTrait::IsOpConfCompatible::getOpName(op).str() + "_transpose_input_" + std::to_string(num_transposed_operand); - transpose_attributes.set(llvm::StringRef("op_name"), rewriter.getStringAttr(transpose_name)); + transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible::getOpNameAttr()), + rewriter.getStringAttr(transpose_name)); SmallVector input_operands; input_operands.push_back(val); auto res = rewriter @@ -583,7 +584,8 @@ TransposeOp getResultTransposeOp(NCHWCompatible op, Value val, NamedAttrList tra int num_transposed_result, PatternRewriter& rewriter) { std::string transpose_name = OpTrait::IsOpConfCompatible::getOpName(op).str() + "_transpose_output_" + std::to_string(num_transposed_result); - transpose_attributes.set(llvm::StringRef("op_name"), rewriter.getStringAttr(transpose_name)); + transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible::getOpNameAttr()), + rewriter.getStringAttr(transpose_name)); SmallVector operands; operands.push_back(val); TransposeOp transpose_op = rewriter.create(op.getLoc(), val.getType(), diff --git a/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp b/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp index 5386629fd00..84afff290e4 100644 --- a/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp +++ b/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp @@ -946,7 +946,10 @@ LogicalResult ConvertVariableOpConf(VariableOp op, ::oneflow::OperatorConf* op_c // all operands are ctrl_inputs for (const auto& operand : op->getOperands()) { op_conf->add_ctrl_in_op_name( - operand.getDefiningOp()->getAttrOfType("op_name").getValue().str()); + operand.getDefiningOp() + ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) + .getValue() + .str()); } if (auto floatInit = op.float_initializer()) { var_op_conf->mutable_initializer()->mutable_constant_conf()->set_value( @@ -1002,7 +1005,11 @@ LogicalResult ConvertInputOpConf(InputOp op, ::oneflow::OperatorConf* op_conf) { // operand 0 is block argument, others are ctrl_inputs for (size_t i = 1; i < op->getNumOperands(); ++i) { op_conf->add_ctrl_in_op_name( - op->getOperand(i).getDefiningOp()->getAttrOfType("op_name").getValue().str()); + op->getOperand(i) + .getDefiningOp() + ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) + .getValue() + .str()); } return success(); @@ -1054,7 +1061,11 @@ LogicalResult ConvertOutputOpConf(OutputOp op, ::oneflow::OperatorConf* op_conf) output_op_conf->set_in(output_lbn); for (size_t i = 1; i < op->getNumOperands(); ++i) { op_conf->add_ctrl_in_op_name( - op->getOperand(i).getDefiningOp()->getAttrOfType("op_name").getValue().str()); + op->getOperand(i) + .getDefiningOp() + ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) + .getValue() + .str()); } return success(); } From ee2ffa08662fb8731082fece5b7849e06c7dcda0 Mon Sep 17 00:00:00 2001 From: jackalcooper Date: Tue, 14 Jun 2022 04:01:55 +0000 Subject: [PATCH 3/3] use wrapped func --- .../lib/OneFlow/Importer.cpp | 27 +++++-------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp b/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp index 84afff290e4..97814d09633 100644 --- a/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp +++ b/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp @@ -492,10 +492,7 @@ LogicalResult ConvertCtrlInputs(Operation* op, ::oneflow::OperatorConf& op_conf) if (auto ctrl_ins = GetCtrlIntputOperands(op)) { for (auto ctrl_in : ctrl_ins.getValue()) { op_conf.add_ctrl_in_op_name( - ctrl_in.getDefiningOp() - ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) - .getValue() - .str()); + OpTrait::IsOpConfCompatible::getOpName(ctrl_in.getDefiningOp()).str()); } } return success(); @@ -675,9 +672,8 @@ llvm::Optional GetOutputLbn(OpResult result) { auto size = std::get<1>(name_size_tuple); if ((size_sum + size) > result_number) { const uint32_t bn_i = result_number - size_sum; - return def_op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) - .str() - + "/" + name + "_" + std::to_string(bn_i); + return OpTrait::IsOpConfCompatible::getOpName(def_op).str() + "/" + name + "_" + + std::to_string(bn_i); } size_sum += size; } @@ -946,10 +942,7 @@ LogicalResult ConvertVariableOpConf(VariableOp op, ::oneflow::OperatorConf* op_c // all operands are ctrl_inputs for (const auto& operand : op->getOperands()) { op_conf->add_ctrl_in_op_name( - operand.getDefiningOp() - ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) - .getValue() - .str()); + OpTrait::IsOpConfCompatible::getOpName(operand.getDefiningOp()).str()); } if (auto floatInit = op.float_initializer()) { var_op_conf->mutable_initializer()->mutable_constant_conf()->set_value( @@ -1005,11 +998,7 @@ LogicalResult ConvertInputOpConf(InputOp op, ::oneflow::OperatorConf* op_conf) { // operand 0 is block argument, others are ctrl_inputs for (size_t i = 1; i < op->getNumOperands(); ++i) { op_conf->add_ctrl_in_op_name( - op->getOperand(i) - .getDefiningOp() - ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) - .getValue() - .str()); + OpTrait::IsOpConfCompatible::getOpName(op->getOperand(i).getDefiningOp()).str()); } return success(); @@ -1061,11 +1050,7 @@ LogicalResult ConvertOutputOpConf(OutputOp op, ::oneflow::OperatorConf* op_conf) output_op_conf->set_in(output_lbn); for (size_t i = 1; i < op->getNumOperands(); ++i) { op_conf->add_ctrl_in_op_name( - op->getOperand(i) - .getDefiningOp() - ->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) - .getValue() - .str()); + OpTrait::IsOpConfCompatible::getOpName(op->getOperand(i).getDefiningOp()).str()); } return success(); }