Skip to content

Commit

Permalink
[PIR] Support translate IfOp (#57342)
Browse files Browse the repository at this point in the history
* add code

* refine code

* add code

* add code

* fix bug

* add code

* add code

* add code

* add code

* add code

* add code

* addd code

* refine code

* refine code

* fix conflict
  • Loading branch information
zhangbo9674 authored Sep 16, 2023
1 parent 7da4295 commit db901f9
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 90 deletions.
122 changes: 61 additions & 61 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions paddle/fluid/ir_adaptor/translator/op_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ struct OpTranscriber {
const OpDesc&,
const std::string&,
const OpInputInfo&,
pir::Program*)>;
pir::Block*)>;
using AttributeHandlerFn = std::function<pir::Attribute(
pir::IrContext*, const OpDesc&, const OpAttributeInfo&)>;

public:
virtual pir::Operation* operator()(pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
pir::Program* program);
pir::Block* block);

public:
virtual pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc);
Expand All @@ -72,7 +72,7 @@ struct OpTranscriber {
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
pir::Program* program);
pir::Block* block);
virtual std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
pir::IrContext* ctx,
const OpDesc& op_desc,
Expand All @@ -86,7 +86,7 @@ struct OpTranscriber {
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc);
virtual pir::OpResult GetAttributeAsInput(pir::IrContext* ctx,
pir::Program* program,
pir::Block* block,
const OpDesc& op_desc,
const OpInputInfo& input_info);

Expand All @@ -109,7 +109,7 @@ struct OpTranscriber {
TranslationContext* param_map,
const OpDesc& op_desc,
const OpInputInfoList& input_infos,
pir::Program* program);
pir::Block* block);
};

class OpTranslator {
Expand All @@ -119,7 +119,7 @@ class OpTranslator {
using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc;
using OpTranslateFn = std::function<pir::Operation*(
pir::IrContext*, TranslationContext*, const OpDesc&, pir::Program*)>;
pir::IrContext*, TranslationContext*, const OpDesc&, pir::Block*)>;

private:
OpTranslator(); // Disallow instantiation outside of the class.
Expand Down
Loading

0 comments on commit db901f9

Please sign in to comment.