|
3 | 3 | # This source code is licensed under the BSD license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | | -from contextlib import contextmanager |
| 6 | +import functools |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | 9 | import torch |
10 | | -from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh |
| 10 | +from torch.distributed.tensor import ( |
| 11 | + DTensor, |
| 12 | + Replicate, |
| 13 | + Shard, |
| 14 | + distribute_tensor, |
| 15 | + init_device_mesh, |
| 16 | +) |
11 | 17 | from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta |
12 | 18 | from torch.distributed.tensor._op_schema import ( |
13 | 19 | OpInfo, |
|
16 | 22 | OpStrategy, |
17 | 23 | OutputSharding, |
18 | 24 | OutputSpecType, |
| 25 | + RuntimeSchemaInfo, |
19 | 26 | TupleStrategy, |
20 | 27 | ) |
21 | 28 | from torch.testing._internal.common_utils import run_tests |
|
25 | 32 | ) |
26 | 33 |
|
27 | 34 | from autoparallel.dtensor_util import strategy_pool |
| 35 | +from autoparallel.dtensor_util.utils import batch_shard_strategy, op_strategy_context |
28 | 36 |
|
29 | 37 | aten = torch.ops.aten |
30 | 38 |
|
@@ -94,31 +102,6 @@ def _fw_tuple(x, y, z): |
94 | 102 | ) |
95 | 103 |
|
96 | 104 |
|
97 | | -@contextmanager |
98 | | -def op_strategy_context(op_overload, strategy_func, schema_info=None): |
99 | | - """ |
100 | | - Context manager for setting and clearing op strategies in unit tests. |
101 | | - Args: |
102 | | - op_overload: The operator overload to set or clear the strategy for. |
103 | | - strategy_func: The strategy function to set for the operator overload. |
104 | | - schema_info: Optional schema information for the operator overload. |
105 | | - Yields: |
106 | | - None |
107 | | - """ |
108 | | - try: |
109 | | - # register the op strategy |
110 | | - strategy_pool.register_op_strategy(op_overload, schema_info=schema_info)( |
111 | | - strategy_func |
112 | | - ) |
113 | | - yield |
114 | | - finally: |
115 | | - # clear this op strategy cache |
116 | | - if op_overload in strategy_pool.op_strategy_funcs: |
117 | | - del strategy_pool.op_strategy_funcs[op_overload] |
118 | | - if op_overload in strategy_pool.op_to_schema_info: |
119 | | - del strategy_pool.op_to_schema_info[op_overload] |
120 | | - |
121 | | - |
122 | 105 | # Overwrite upstream `_op_dispatcher.sharding_propagator` with customized |
123 | 106 | # sharding_propagator. This is for testing purpose under eager mode and |
124 | 107 | # AutoParallel won't use the propagate function. The main changes are 1) Skip |
@@ -394,5 +377,105 @@ def test_implicit_registration(self): |
394 | 377 | self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
395 | 378 |
|
396 | 379 |
|
class DimShardingTest(DTensorTestBase):
    """Exercises `batch_shard_strategy`: op strategies that shard a designated
    batch dimension of an op's inputs/outputs across device-mesh axes."""

    @with_comms
    def test_simple_batch_sharding(self):
        # Both input tensors carry the batch on dim 0.
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_sin.default

        # 1. Strategy that shards batch dim 0 over a single device axis.
        shard_first_dim_strategy = functools.partial(
            batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
        )
        with op_strategy_context(test_op, shard_first_dim_strategy):
            # Dim 0 is the batch dim. Here we shard 16 over one device axis.
            input_x = torch.randn([16, 1, 4], device=self.device_type)
            input_y = torch.randn([16, 1, 4], device=self.device_type)
            # Any initial placement below should work.
            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

        # 2. Strategy that may shard batch dim 0 over multiple device axes.
        shard_first_dim_to_multiple_devices_strategy = functools.partial(
            batch_shard_strategy,
            input_shard_dim=[0, 0],
            output_shard_dim=[0],
            enable_shard_batch_dim_over_multiple_axis=True,
        )
        with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
            # Dim 0 is the batch dim. Here we potentially shard 16 over
            # multiple device axes.
            input_x = torch.randn([16, 1, 4], device=self.device_type)
            input_y = torch.randn([16, 1, 4], device=self.device_type)
            # Any initial placement below should work.
            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

    @with_comms
    def test_broadcast_batch_sharding(self):
        # Not recommended: the user must first make sure the op supports
        # broadcasting between its inputs. If it does not, unsqueeze the inputs
        # to matching ranks and follow the test_simple_batch_sharding example.
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_sin.default
        # input_x has no batch dim (None); input_y batches on dim 0.
        shard_on_first_dim_strategy = functools.partial(
            batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
        )
        with op_strategy_context(test_op, shard_on_first_dim_strategy):
            input_x = torch.randn([1, 4], device=self.device_type)
            # input_y's 16 sits on the batch dim.
            input_y = torch.randn([16, 1, 4], device=self.device_type)
            # Any placement below should work as long as the tensor dim is shardable.
            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])

            output_dt = test_op(input_x_dt, input_y_dt)

            # Split along the batch dim and compare against chunked eager results.
            input_y_chunks = torch.chunk(input_y, 4, dim=0)
            output = torch.cat(
                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
            )
            self.assertEqual(output_dt.full_tensor(), output)

            # Or we can test directly, since the op supports broadcast.
            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

    @with_comms
    def test_simple_tuple_batch_sharding(self):
        # All input tensors (including those inside the list) batch on dim 0.
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_tuple_sin.default

        # Strategy that may shard batch dim 0 over multiple device axes. The
        # five entries of input_shard_dim cover the pytree-flattened inputs:
        # input_x, the three tensors of input_y, and input_z.
        shard_first_dim_to_multiple_devices_strategy = functools.partial(
            batch_shard_strategy,
            input_shard_dim=[0, 0, 0, 0, 0],  # flatten input_y
            output_shard_dim=[0],
            enable_shard_batch_dim_over_multiple_axis=True,
        )
        with op_strategy_context(
            test_op,
            shard_first_dim_to_multiple_devices_strategy,
            schema_info=RuntimeSchemaInfo(needs_pytree=True),
        ):
            # Dim 0 is the batch dim. Here we shard 16 over device axes.
            input_x = torch.randn([16, 8, 4], device=self.device_type)
            input_y = [
                torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
            ]
            input_z = torch.randn([16, 8, 4], device=self.device_type)
            # Any placement below should work as long as the tensor dim is shardable.
            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
            input_y_dt = [
                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
            ]
            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)
# Entry point: run all test cases via torch's common_utils test runner.
if __name__ == "__main__":
    run_tests()
0 commit comments