Skip to content

Commit

Permalink
[AutoParallel-PIR] Remove program clone in Pass (PaddlePaddle#64137)
Browse files Browse the repository at this point in the history
* remove clone

* revise typos

* update unittest
  • Loading branch information
JZ-LIANG authored and co63oc committed May 10, 2024
1 parent 6776844 commit ae876f3
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,27 @@ void VerifyDenseBlock(pir::Block* block) {
}
}

std::shared_ptr<pir::Program> DistToDensePass(pir::Program* prog) {
void DistToDensePass(pir::Program* prog) {
if (FLAGS_print_ir) {
VLOG(0) << "IR before DistToDense Pass = " << *prog;
}

pir::IrMapping mapper;
auto new_prog = prog->Clone(mapper);
// auto new_prog = prog->Clone(mapper);

pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<DistDialect>();

ProcessDistBlock(new_prog->block());
VLOG(6) << "IR before VerifyDenseBlock Pass = " << *new_prog;
VerifyDenseBlock(new_prog->block());
ProcessDistBlock(prog->block());
VLOG(6) << "IR before VerifyDenseBlock Pass = " << *prog;
VerifyDenseBlock(prog->block());

if (FLAGS_print_ir) {
VLOG(0) << "IR after DistToDense Pass = " << *new_prog;
VLOG(0) << "IR after DistToDense Pass = " << *prog;
}

return new_prog;
// return prog;
}

} // namespace dialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
namespace paddle {
namespace dialect {

TEST_API std::shared_ptr<pir::Program> DistToDensePass(pir::Program* prog);
TEST_API void DistToDensePass(pir::Program* prog);

void ProcessDistBlock(pir::Block* block);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,26 @@ void VerifyDistBlock(pir::Block* block) {
}
}

std::shared_ptr<pir::Program> MixToDistPass(pir::Program* prog) {
void MixToDistPass(pir::Program* prog) {
if (FLAGS_print_ir) {
std::cout << "IR before MixToDist Pass = " << *prog << std::endl;
VLOG(0) << "IR before MixToDist Pass = " << *prog << std::endl;
}

pir::IrMapping mapper;
auto new_prog = prog->Clone(mapper);
// auto new_prog = prog->Clone(mapper);

pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<DistDialect>();

ProcessMixBlock(new_prog->block());
VerifyDistBlock(new_prog->block());
ProcessMixBlock(prog->block());
VerifyDistBlock(prog->block());

if (FLAGS_print_ir) {
std::cout << "IR after MixToDist Pass = " << *new_prog << std::endl;
VLOG(0) << "IR after MixToDist Pass = " << *prog << std::endl;
}

return new_prog;
// return prog;
}

} // namespace dialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace dialect {

// pir::Type ConvertOpTypeToKernelType(pir::Type op_type);

TEST_API std::shared_ptr<pir::Program> MixToDistPass(pir::Program* prog);
TEST_API void MixToDistPass(pir::Program* prog);

void ProcessMixBlock(pir::Block* block);

Expand Down
13 changes: 7 additions & 6 deletions python/paddle/distributed/auto_parallel/static/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,10 @@ def _parallel_pir(self, mode):
# Part 1: Complete program
# Step 1.1: Mix2Dense Pass
# TODO(JZ-LIANG) regulization pass with pass management.
dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
mix_fw_program
)
dist_program = mix_fw_program
self._fwd_main_progs[mode] = mix_fw_program.clone()
paddle.base.libpaddle.pir.apply_mix2dist_pass(dist_program)

# Step 1.2: pir backward
if mode == "train" and self._loss and self._optimizer:
loss = dist_program.get_output_value_by_name(self._loss_names[0])
Expand Down Expand Up @@ -699,9 +700,9 @@ def _parallel_pir(self, mode):

# TODO(JZ-LIANG) Step 4.4 Dist2Dense Pass
# NOTE All optimization pass that need dist_attr info should be called before Dist2Dense Pass.
dense_program = paddle.base.libpaddle.pir.apply_dist2dense_pass(
dist_program
)
dense_program = dist_program
dist_program = dist_program.clone()
paddle.base.libpaddle.pir.apply_dist2dense_pass(dense_program)

self._pir_dense_main_progs[mode] = dense_program
self._pir_dist_main_progs[mode] = dist_program
Expand Down
4 changes: 1 addition & 3 deletions test/auto_parallel/pir/test_learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def test_copy_between_mesh(self):
engine = dist_model._engine
engine._build("train")
dist_program = engine._fwd_main_progs["train"]
dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
dist_program
)
paddle.base.libpaddle.pir.apply_mix2dist_pass(dist_program)
loss = dist_program.get_output_value_by_name(engine._loss_names[0])
with paddle.static.program_guard(dist_program):
params_grads = paddle.autograd.ir_backward.append_backward(loss)
Expand Down
4 changes: 1 addition & 3 deletions test/auto_parallel/pir/test_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def test_to_static_program(self):
engine = dist_model._engine
engine._build("train")
dist_program = engine._fwd_main_progs["train"]
dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
dist_program
)
paddle.base.libpaddle.pir.apply_mix2dist_pass(dist_program)
loss = dist_program.get_output_value_by_name(engine._loss_names[0])
with paddle.static.program_guard(dist_program):
params_grads = paddle.autograd.ir_backward.append_backward(loss)
Expand Down
5 changes: 1 addition & 4 deletions test/auto_parallel/test_pir_mix2dist_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def test_build_api(self):
initializer=paddle.nn.initializer.Uniform(),
)

dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
main_program
)
print(dist_program)
paddle.base.libpaddle.pir.apply_mix2dist_pass(main_program)


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions test/cpp/pir/distributed/dist_dialect_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,8 @@ TEST(mix_to_dist_pass_test, base) {
(uint32_t)1);

// Apply Pass
std::shared_ptr<pir::Program> new_program =
paddle::dialect::MixToDistPass(&program);
pir::Block* new_block = new_program->block();
paddle::dialect::MixToDistPass(&program);
pir::Block* new_block = program.block();
EXPECT_EQ(2, static_cast<int>(new_block->num_ops()));
std::vector<pir::Operation*> ops;
for (auto& op : *new_block) {
Expand Down

0 comments on commit ae876f3

Please sign in to comment.