Skip to content

Commit

Permalink
[PIR] rectify yield input from opresult to value. (PaddlePaddle#57635)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Sep 23, 2023
1 parent a68d9a9 commit 37cc5a6
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions paddle/pir/dialect/control_flow/ir/cf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace pir {

void YieldOp::Build(Builder &builder,
OperationArgument &argument,
std::vector<OpResult> &&inputs) {
argument.AddInputs(inputs.begin(), inputs.end());
const std::vector<Value> &inputs) {
argument.AddInputs(inputs);
}
} // namespace pir

Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/control_flow/ir/cf_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class IR_API YieldOp : public Op<YieldOp> {

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
std::vector<OpResult> &&inputs);
const std::vector<Value> &Value);
void Verify() {}
};
} // namespace pir
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/new_executor/standalone_executor_new_ir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,15 @@ TEST(StandaloneExecutor, if_op) {

auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::OpResult>{full_op_1.out()});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_1.out()});

pir::Block* false_block = if_op.false_block();

builder.SetInsertionPointToStart(false_block);

auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::OpResult>{full_op_2.out()});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_2.out()});

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

Expand Down
4 changes: 2 additions & 2 deletions test/cpp/pir/cinn/group_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ std::shared_ptr<::pir::Program> BuildGroupProgram() {
builder.SetInsertionPointToEnd(block1);
auto full_op_x = builder.Build<paddle::dialect::FullOp>(
shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace());
builder.Build<::pir::YieldOp>(std::vector<::pir::OpResult>{full_op_x.out()});
builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{full_op_x.out()});

builder.SetInsertionPointToEnd(program->block());
auto group_op2 = builder.Build<cinn::dialect::GroupOp>(
Expand All @@ -67,7 +67,7 @@ std::shared_ptr<::pir::Program> BuildGroupProgram() {
auto relu_op_x = builder.Build<paddle::dialect::ReluOp>(tan_op_x->result(0));
auto tan_op_y = builder.Build<paddle::dialect::TanOp>(relu_op_x->result(0));
auto relu_op_y = builder.Build<paddle::dialect::ReluOp>(tan_op_y->result(0));
builder.Build<::pir::YieldOp>(std::vector<::pir::OpResult>{relu_op_y.out()});
builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()});
return program;
}

Expand Down
4 changes: 2 additions & 2 deletions test/cpp/pir/control_flow_dialect/if_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ TEST(if_op_test, base) {

auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::OpResult>{full_op_1.out()});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_1.out()});

pir::Block* false_block = if_op.false_block();

builder.SetInsertionPointToStart(false_block);

auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::OpResult>{full_op_2.out()});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_2.out()});

std::stringstream ss;
program.Print(ss);
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,15 @@ TEST(kernel_dialect, cond_op_test) {

auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::OpResult>{full_op_1.out()});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_1.out()});

pir::Block* false_block = if_op.false_block();

builder.SetInsertionPointToStart(false_block);

auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::OpResult>{full_op_2.out()});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_2.out()});

program.Print(std::cout);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
Expand Down

0 comments on commit 37cc5a6

Please sign in to comment.