mix2dist_pass supports sharding randomly sampled data (#67589)
* mix2dist_pass supports sharding randomly sampled data

* add unit test checking the upstream full_int_array op and its result
jeff41404 authored Aug 21, 2024
1 parent 8a129e2 commit fb57ee7
Showing 2 changed files with 116 additions and 2 deletions.
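
For context, a minimal usage sketch of what this commit enables (mirroring the unit test added below; the tensor name and shape are illustrative): randomly sampled data created by pd_op.randint or pd_op.gaussian can now be passed to dist.shard_tensor, and apply_mix2dist_pass propagates the distributed attributes through the sampling op and its full_int_array shape input.

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.mix_to_dist_pass import (
    apply_mix2dist_pass,
)

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    main_program = paddle.base.Program()
    with paddle.base.program_guard(main_program):
        mesh = dist.ProcessMesh([0, 1], dim_names=['dp'])
        # Randomly sampled input (lowered to pd_op.gaussian), sharded along dim 1.
        noise = paddle.randn([4, 8])
        dist_noise = dist.shard_tensor(noise, mesh, [dist.Shard(1)])
    # Rewrite the mixed program into a distributed program in place.
    apply_mix2dist_pass(main_program)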
31 changes: 29 additions & 2 deletions python/paddle/distributed/auto_parallel/static/mix_to_dist_pass.py
@@ -35,7 +35,7 @@ def verify_dist_block(block):
raise RuntimeError("Block still contain shard_tensor_op.")
if op.dist_attr is None:
raise RuntimeError(
f"The op {op} does not hase OperatorDistAttr after Mix2Dist Pass."
f"The op {op} does not have OperatorDistAttr after Mix2Dist Pass."
)
for result in op.results():
if not result.initialized():
@@ -67,7 +67,34 @@ def apply_mix2dist_pass(program):
shard_operand_value.set_type(shard_result_value.type())
shard_operand_value.stop_gradient = shard_result_value.stop_gradient
shard_operand_value.persistable = shard_result_value.persistable

elif (
prev_op.name() == "pd_op.randint"
or prev_op.name() == "pd_op.gaussian"
):
mesh = shard_result_value.dist_attr().process_mesh
# input
shape_value = prev_op.operand_source(0)
dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
mesh, [-1 for _ in range(len(shape_value.shape))], {}
)
shape_value.update_dist_attr(dist_attr)
# op
prev_op.dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
mesh, [dist_attr], [shard_result_value.dist_attr()]
)
)
# deal with full_int_array op
prev_prev_op = shape_value.get_defining_op()
prev_prev_op.dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
mesh, [], [dist_attr]
)
)
# output
shard_operand_value.set_type(shard_result_value.type())
shard_operand_value.stop_gradient = shard_result_value.stop_gradient
shard_operand_value.persistable = shard_result_value.persistable
else:
dist_attr = shard_result_value.dist_attr()
if not is_replicated(dist_attr):
87 changes: 87 additions & 0 deletions test/auto_parallel/pir/test_static_pir_program.py
@@ -16,6 +16,9 @@

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.mix_to_dist_pass import (
apply_mix2dist_pass,
)

BATCH_SIZE = 2
SEQ_LEN = 4
@@ -141,6 +144,90 @@ def test_build_with_shard_tensor(self):
self.assertEqual(dist_w1.dist_attr().process_mesh, mesh)
self.assertEqual(dist_w1.dist_attr().dims_mapping, [-1, 0])

def test_build_with_apply_mix2dist_pass(self):
paddle.enable_static()
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['dp'])
input1 = paddle.randint(low=0, high=1000, shape=[8, 4])
output1 = dist.shard_tensor(input1, mesh, [dist.Shard(0)])

input2 = paddle.randn([4, 8])
output2 = dist.shard_tensor(input2, mesh, [dist.Shard(1)])

self.assertTrue(input1.is_dense_tensor_type())
self.assertTrue(input2.is_dense_tensor_type())

self.assertTrue(main_program.num_ops() == 6)

self.assertFalse(input1.use_empty())
self.assertFalse(input2.use_empty())

self.assertTrue(output1.use_empty())
self.assertTrue(output2.use_empty())

self.assertFalse(input1.get_defining_op().has_attr("op_dist_attr"))
self.assertFalse(input2.get_defining_op().has_attr("op_dist_attr"))

# check dist type
self.assertTrue(output1.is_dist_dense_tensor_type())
self.assertTrue(output2.is_dist_dense_tensor_type())

# run apply_mix2dist_pass
apply_mix2dist_pass(main_program)

# after apply_mix2dist_pass, the program changed
self.assertTrue(main_program.num_ops() == 4)

self.assertTrue(input1.is_dist_dense_tensor_type())
self.assertTrue(input2.is_dist_dense_tensor_type())

self.assertTrue(input1.get_defining_op().has_attr("op_dist_attr"))
self.assertTrue(input2.get_defining_op().has_attr("op_dist_attr"))

# check op result dist_attr
input1_op_dist_attr = input1.get_defining_op().dist_attr
tensor_dist_attr = input1_op_dist_attr.result(0).as_tensor_dist_attr()
self.assertEqual(tensor_dist_attr.process_mesh, mesh)
self.assertEqual(tensor_dist_attr.dims_mapping, [0, -1])

input2_op_dist_attr = input2.get_defining_op().dist_attr
tensor_dist_attr = input2_op_dist_attr.result(0).as_tensor_dist_attr()
self.assertEqual(tensor_dist_attr.process_mesh, mesh)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, 0])

# check value dist_attr
self.assertEqual(input1.dist_attr().process_mesh, mesh)
self.assertEqual(input1.dist_attr().dims_mapping, [0, -1])

self.assertEqual(input2.dist_attr().process_mesh, mesh)
self.assertEqual(input2.dist_attr().dims_mapping, [-1, 0])

# check full_int_array op result dist_attr
input1_shape = input1.get_defining_op().operand_source(0)
input1_shape_op_dist_attr = input1_shape.get_defining_op().dist_attr
tensor_dist_attr = input1_shape_op_dist_attr.result(
0
).as_tensor_dist_attr()
self.assertEqual(tensor_dist_attr.process_mesh, mesh)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1])

input2_shape = input2.get_defining_op().operand_source(0)
input2_shape_op_dist_attr = input2_shape.get_defining_op().dist_attr
tensor_dist_attr = input2_shape_op_dist_attr.result(
0
).as_tensor_dist_attr()
self.assertEqual(tensor_dist_attr.process_mesh, mesh)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1])

# check shape value dist_attr
self.assertEqual(input1_shape.dist_attr().process_mesh, mesh)
self.assertEqual(input1_shape.dist_attr().dims_mapping, [-1])

self.assertEqual(input2_shape.dist_attr().process_mesh, mesh)
self.assertEqual(input2_shape.dist_attr().dims_mapping, [-1])


if __name__ == "__main__":
unittest.main()
