Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghuancoder committed May 16, 2024
1 parent 094873d commit 9478637
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions test/auto_parallel/pir/test_ir_dist_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class TestBuildFakeProgram(unittest.TestCase):
def test_build_api(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
Expand All @@ -56,7 +57,8 @@ def test_build_api(self):
def test_build_replicated_program(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
Expand Down Expand Up @@ -123,7 +125,8 @@ def test_build_replicated_program(self):
def test_build_col_parallel_program(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
Expand Down Expand Up @@ -172,7 +175,8 @@ def test_build_col_parallel_program(self):
def test_build_row_parallel_program(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input',
Expand Down Expand Up @@ -224,7 +228,8 @@ def test_build_row_parallel_program(self):
def test_build_with_shard_tensor(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input',
Expand Down

0 comments on commit 9478637

Please sign in to comment.