|
3 | 3 | # This source code is licensed under the BSD license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
| 6 | +import functools |
6 | 7 | from contextlib import contextmanager |
7 | 8 |
|
8 | 9 | import numpy as np |
9 | 10 | import torch |
10 | | -from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh |
| 11 | +from torch.distributed.tensor import ( |
| 12 | + DTensor, |
| 13 | + Replicate, |
| 14 | + Shard, |
| 15 | + distribute_tensor, |
| 16 | + init_device_mesh, |
| 17 | +) |
11 | 18 | from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta |
12 | 19 | from torch.distributed.tensor._op_schema import ( |
13 | 20 | OpInfo, |
|
16 | 23 | OpStrategy, |
17 | 24 | OutputSharding, |
18 | 25 | OutputSpecType, |
| 26 | + RuntimeSchemaInfo, |
19 | 27 | TupleStrategy, |
20 | 28 | ) |
21 | 29 | from torch.testing._internal.common_utils import run_tests |
|
25 | 33 | ) |
26 | 34 |
|
27 | 35 | from autoparallel.dtensor_util import strategy_pool |
| 36 | +from autoparallel.dtensor_util.utils import batch_shard_strategy |
28 | 37 |
|
29 | 38 | aten = torch.ops.aten |
30 | 39 |
|
@@ -394,5 +403,105 @@ def test_implicit_registration(self): |
394 | 403 | self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
395 | 404 |
|
396 | 405 |
|
| 406 | +class DimShardingTest(DTensorTestBase): |
| 407 | + @with_comms |
| 408 | + def test_simple_batch_sharding(self): |
| 409 | +        # both input tensors are batched on dim 0 |
| 410 | + mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) |
| 411 | + test_op = torch.ops.mylib.numpy_sin.default |
| 412 | + |
| 413 | +        # 1. Strategy that shards the batch dim (dim 0) over a single device axis. |
| 414 | + shard_first_dim_strategy = functools.partial( |
| 415 | + batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0] |
| 416 | + ) |
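| | +        # input_shard_dim=[0, 0] marks dim 0 of both inputs as the batch dim to |
| | +        # shard; output_shard_dim=[0] marks dim 0 of the output as sharded the same way. |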
| 417 | + with op_strategy_context(test_op, shard_first_dim_strategy): |
| 418 | +            # dim 0 is the batch dim. Here the batch size of 16 is sharded over one device axis. |
| 419 | + input_x = torch.randn([16, 1, 4], device=self.device_type) |
| 420 | + input_y = torch.randn([16, 1, 4], device=self.device_type) |
| 421 | + # any sharding below should work |
| 422 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()]) |
| 423 | + input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)]) |
| 424 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
| 425 | + |
| 426 | +        # 2. Strategy that may shard the batch dim (dim 0) over multiple device axes. |
| 427 | + shard_first_dim_to_multiple_devices_strategy = functools.partial( |
| 428 | + batch_shard_strategy, |
| 429 | + input_shard_dim=[0, 0], |
| 430 | + output_shard_dim=[0], |
| 431 | + enable_shard_batch_dim_over_multiple_axis=True, |
| 432 | + ) |
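| | +        # enable_shard_batch_dim_over_multiple_axis=True additionally allows the |
| | +        # batch dim to be split across more than one mesh axis. |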
| 433 | + with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy): |
| 434 | +            # dim 0 is the batch dim. Here the batch size of 16 may be sharded over multiple device axes. |
| 435 | + input_x = torch.randn([16, 1, 4], device=self.device_type) |
| 436 | + input_y = torch.randn([16, 1, 4], device=self.device_type) |
| 437 | + # any sharding below should work |
| 438 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()]) |
| 439 | + input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)]) |
| 440 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
| 441 | + |
| 442 | + @with_comms |
| 443 | + def test_broadcast_batch_sharding(self): |
| 444 | +        # Not recommended: the user needs to make sure the op supports |
| 445 | +        # broadcast inputs first. If it does not, unsqueeze the inputs first so |
| 446 | +        # their dimensions match each other, then use the approach from the |
| 447 | +        # test_simple_batch_sharding test. |
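| | +        # For example (a hedged sketch, not executed in this test): an input of |
| | +        # shape [1, 4] could be lifted to the batched rank via unsqueeze(0), |
| | +        # giving [1, 1, 4], and expanded to [16, 1, 4] before reusing the |
| | +        # [0, 0] sharding from test_simple_batch_sharding. |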
| 448 | + mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) |
| 449 | + test_op = torch.ops.mylib.numpy_sin.default |
| 450 | + shard_on_first_dim_strategy = functools.partial( |
| 451 | + batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0] |
| 452 | + ) |
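| | +        # input_shard_dim=[None, 0]: input_x carries no batch dim (it is |
| | +        # broadcast), while dim 0 of input_y is the batch dim to shard. |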
| 453 | + with op_strategy_context(test_op, shard_on_first_dim_strategy): |
| 454 | + input_x = torch.randn([1, 4], device=self.device_type) |
| 455 | +            # input_y's size-16 dimension is the batch dim |
| 456 | + input_y = torch.randn([16, 1, 4], device=self.device_type) |
| 457 | +            # any sharding below should work as long as the tensor dim is shardable |
| 458 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()]) |
| 459 | + input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)]) |
| 460 | + |
| 461 | + output_dt = test_op(input_x_dt, input_y_dt) |
| 462 | + |
| 463 | + # split the batch dim to test correctness |
| 464 | +            input_y_chunks = torch.chunk(input_y, 4, dim=0) |
| 465 | + output = torch.cat( |
| 466 | +                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks] |
| 467 | + ) |
| 468 | + self.assertEqual(output_dt.full_tensor(), output) |
| 469 | + |
| 470 | +            # or we can test directly since the op supports broadcasting. |
| 471 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
| 472 | + |
| 473 | + @with_comms |
| 474 | + def test_simple_tuple_batch_sharding(self): |
| 475 | +        # all input tensors are batched on dim 0 |
| 476 | + mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) |
| 477 | + test_op = torch.ops.mylib.numpy_tuple_sin.default |
| 478 | + |
| 479 | +        # Strategy that shards the batch dim (dim 0), potentially over multiple device axes. |
| 480 | + shard_first_dim_to_multiple_devices_strategy = functools.partial( |
| 481 | + batch_shard_strategy, |
| 482 | + input_shard_dim=[0, 0, 0, 0, 0], # flatten input_y |
| 483 | + output_shard_dim=[0], |
| 484 | + enable_shard_batch_dim_over_multiple_axis=True, |
| 485 | + ) |
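| | +        # The five entries of input_shard_dim follow the flattened (pytree) args: |
| | +        # input_x, the three tensors inside input_y, and input_z, each batched on |
| | +        # dim 0; hence schema_info=RuntimeSchemaInfo(needs_pytree=True) below. |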
| 486 | + with op_strategy_context( |
| 487 | + test_op, |
| 488 | + shard_first_dim_to_multiple_devices_strategy, |
| 489 | + schema_info=RuntimeSchemaInfo(needs_pytree=True), |
| 490 | + ): |
| 491 | +            # dim 0 is the batch dim. Here the batch size of 16 may be sharded over one or more device axes. |
| 492 | + input_x = torch.randn([16, 8, 4], device=self.device_type) |
| 493 | + input_y = [ |
| 494 | + torch.randn([16, 8, 4], device=self.device_type) for _ in range(3) |
| 495 | + ] |
| 496 | + input_z = torch.randn([16, 8, 4], device=self.device_type) |
| 497 | +            # any sharding below should work as long as the tensor dim is shardable |
| 498 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)]) |
| 499 | + input_y_dt = [ |
| 500 | + distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y |
| 501 | + ] |
| 502 | + input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)]) |
| 503 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt) |
| 504 | + |
| 505 | + |
397 | 506 | if __name__ == "__main__": |
398 | 507 | run_tests() |