Skip to content

Commit

Permalink
add two more unit test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Sep 22, 2023
1 parent 8c27328 commit e9c5e27
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/auto_parallel/spmd_rules/test_reshape_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,44 @@ def test_reshape_infer_backward(self):
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)

# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (intput --> output)
# dims_mapping: [-1, 0, -1, -1, 1] --> [0, -1, -1, -1], [-1, 0, -1, -1, -1] (output --> input, output)
self.output_dist_tensor_spec.shape = [1, 72, 48, 4, 6]
self.output_dist_tensor_spec.set_dims_mapping([-1, 0, -1, -1, 1])
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['shape'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, -1, -1]
)

# shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24] (intput --> output)
# dims_mapping: [-1, 1, -1, -1, 0] --> [-1, -1, -1, 0], [-1, -1, -1, -1, 0] (output --> input, output)
self.output_dist_tensor_spec.shape = [3, 24, 6, 8, 24]
self.output_dist_tensor_spec.set_dims_mapping([-1, 1, -1, -1, 0])
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['shape'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, 0]
)


if __name__ == "__main__":
unittest.main()

0 comments on commit e9c5e27

Please sign in to comment.