|
3 | 3 | # This source code is licensed under the BSD license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
| 6 | +import functools |
| 7 | + |
6 | 8 | import numpy as np |
7 | 9 | import torch |
8 | 10 | from torch.distributed.device_mesh import init_device_mesh |
9 | | -from torch.distributed.tensor import DTensor, Shard, distribute_tensor |
| 11 | +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor |
10 | 12 | from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta |
11 | 13 | from torch.distributed.tensor._op_schema import ( |
12 | 14 | OpInfo, |
|
15 | 17 | OpStrategy, |
16 | 18 | OutputSharding, |
17 | 19 | OutputSpecType, |
| 20 | + RuntimeSchemaInfo, |
18 | 21 | TupleStrategy, |
19 | 22 | ) |
20 | 23 | from torch.testing._internal.common_utils import run_tests |
|
23 | 26 | with_comms, |
24 | 27 | ) |
25 | 28 |
|
26 | | -from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies |
| 29 | +from autoparallel.dtensor_util import ( |
| 30 | + batch_shard_strategy, |
| 31 | + get_op_strategy, |
| 32 | + op_strategy_context, |
| 33 | + with_implicit_strategies, |
| 34 | +) |
27 | 35 |
|
28 | 36 | propagator = DTensor._op_dispatcher.sharding_propagator |
29 | 37 |
|
@@ -370,5 +378,105 @@ def test_implicit_registration(self): |
370 | 378 | self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
371 | 379 |
|
372 | 380 |
|
| 381 | +class DimShardingTest(DTensorTestBase): |
| 382 | + @with_comms |
| 383 | + def test_simple_batch_sharding(self): |
| 384 | + # both input tensors are batched on dim 0 |
| 385 | + mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) |
| 386 | + test_op = torch.ops.mylib.numpy_sin.default |
| 387 | + |
| 388 | + # 1. Strategy that shards the batch dim (dim 0) along a single device axis. |
| 389 | + shard_first_dim_strategy = functools.partial( |
| 390 | + batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0] |
| 391 | + ) |
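| | + # input_shard_dim has one entry per tensor input and marks dim 0 of both |
| | + # inputs as the batch dim to shard; output_shard_dim=[0] marks dim 0 of |
| | + # the single output the same way. |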
| 392 | + with op_strategy_context(test_op, shard_first_dim_strategy): |
| 393 | + # dim 0 is the batch dim. Here we shard 16 over one device axis |
| 394 | + input_x = torch.randn([16, 1, 4], device=self.device_type) |
| 395 | + input_y = torch.randn([16, 1, 4], device=self.device_type) |
| 396 | + # any sharding below should work |
| 397 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()]) |
| 398 | + input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)]) |
| 399 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
| 400 | + |
| 401 | + # 2. Strategy that may shard dim 0 across multiple device axes. |
| 402 | + shard_first_dim_to_multiple_devices_strategy = functools.partial( |
| 403 | + batch_shard_strategy, |
| 404 | + input_shard_dim=[0, 0], |
| 405 | + output_shard_dim=[0], |
| 406 | + enable_shard_batch_dim_over_multiple_axis=True, |
| 407 | + ) |
| 408 | + with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy): |
| 409 | + # dim 0 is the batch dim. Here we potentially shard 16 over multiple device axes |
| 410 | + input_x = torch.randn([16, 1, 4], device=self.device_type) |
| 411 | + input_y = torch.randn([16, 1, 4], device=self.device_type) |
| 412 | + # any sharding below should work |
| 413 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()]) |
| 414 | + input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)]) |
| 415 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
| 416 | + |
| 417 | + @with_comms |
| 418 | + def test_broadcast_batch_sharding(self): |
| 419 | + # Not recommended: users need to make sure the op supports broadcasting |
| 420 | + # between its inputs first. If it does not, unsqueeze the inputs so that |
| 421 | + # their dimensions match and use the pattern from the |
| 422 | + # test_simple_batch_sharding test (see the sketch below). |
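| | + # One possible sketch of that fallback for the shapes used below, assuming |
| | + # the op is elementwise over the batch: expand input_x ([1, 4]) to match |
| | + # input_y ([16, 1, 4]), e.g. |
| | + #   input_x = input_x.unsqueeze(0).expand(16, 1, 4) |
| | + # and then register the strategy with input_shard_dim=[0, 0] as in |
| | + # test_simple_batch_sharding. |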
| 423 | + mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) |
| 424 | + test_op = torch.ops.mylib.numpy_sin.default |
| 425 | + shard_on_first_dim_strategy = functools.partial( |
| 426 | + batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0] |
| 427 | + ) |
| 428 | + with op_strategy_context(test_op, shard_on_first_dim_strategy): |
| 429 | + input_x = torch.randn([1, 4], device=self.device_type) |
| 430 | + # input_y's size-16 dimension is the batch dim |
| 431 | + input_y = torch.randn([16, 1, 4], device=self.device_type) |
| 432 | + # any sharding below should work as long as the tensor dim is shardable |
| 433 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()]) |
| 434 | + input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)]) |
| 435 | + |
| 436 | + output_dt = test_op(input_x_dt, input_y_dt) |
| 437 | + |
| 438 | + # split the batch dim to test correctness |
| 439 | + input_y_chunks = torch.chunk(input_y, 4, dim=0) |
| 440 | + output = torch.cat( |
| 441 | + [test_op(input_x, input_y_part) for input_y_part in input_y_chunks] |
| 442 | + ) |
| 443 | + self.assertEqual(output_dt.full_tensor(), output) |
| 444 | + |
| 445 | + # or we can test directly since the op supports broadcasting. |
| 446 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt) |
| 447 | + |
| 448 | + @with_comms |
| 449 | + def test_simple_tuple_batch_sharding(self): |
| 450 | + # all input tensors are batched on dim 0 |
| 451 | + mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) |
| 452 | + test_op = torch.ops.mylib.numpy_tuple_sin.default |
| 453 | + |
| 454 | + # Strategy that may shard dim 0 across multiple device axes. |
| 455 | + shard_first_dim_to_multiple_devices_strategy = functools.partial( |
| 456 | + batch_shard_strategy, |
| 457 | + input_shard_dim=[0, 0, 0, 0, 0], # flatten input_y |
| 458 | + output_shard_dim=[0], |
| 459 | + enable_shard_batch_dim_over_multiple_axis=True, |
| 460 | + ) |
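| | + # input_shard_dim has five entries because the list input_y is flattened |
| | + # via pytree (hence needs_pytree below) into its three tensors, giving |
| | + # input_x, the three tensors of input_y, and input_z, all batched on dim 0. |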
| 461 | + with op_strategy_context( |
| 462 | + test_op, |
| 463 | + shard_first_dim_to_multiple_devices_strategy, |
| 464 | + schema_info=RuntimeSchemaInfo(needs_pytree=True), |
| 465 | + ): |
| 466 | + # dim 0 is the batch dim. Here we may shard 16 across multiple device axes |
| 467 | + input_x = torch.randn([16, 8, 4], device=self.device_type) |
| 468 | + input_y = [ |
| 469 | + torch.randn([16, 8, 4], device=self.device_type) for _ in range(3) |
| 470 | + ] |
| 471 | + input_z = torch.randn([16, 8, 4], device=self.device_type) |
| 472 | + # any sharding below should work as long as the tensor dim is shardable |
| 473 | + input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)]) |
| 474 | + input_y_dt = [ |
| 475 | + distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y |
| 476 | + ] |
| 477 | + input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)]) |
| 478 | + self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt) |
| 479 | + |
| 480 | + |
373 | 481 | if __name__ == "__main__": |
374 | 482 | run_tests() |