|
34 | 34 | from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP |
35 | 35 | from torchtitan.logging import logger |
36 | 36 | from torchtitan.parallelisms.parallel_dims import ParallelDims |
37 | | -from torchtitan.parallelisms.utils import check_if_feature_in_pytorch |
38 | 37 |
|
39 | 38 |
|
40 | 39 | def parallelize_llama( |
@@ -80,31 +79,22 @@ def parallelize_llama( |
80 | 79 | if ( |
81 | 80 | parallel_dims.dp_shard_enabled |
82 | 81 | ): # apply FSDP or HSDP, potentially with Context Parallel |
83 | | - try: |
84 | | - dp_mesh = ( |
85 | | - world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"] |
86 | | - ) |
87 | | - except IndexError: |
88 | | - # note: this is a workaround of the above logic for old pytorch version |
89 | | - # where https://github.com/pytorch/pytorch/pull/138945 is not included |
90 | | - # throw a warning to encourage users to upgrade to a newer pytorch version |
91 | | - check_if_feature_in_pytorch( |
92 | | - "DeviceMesh flattening over 3D+ meshes", |
93 | | - "https://github.com/pytorch/pytorch/pull/138945", |
94 | | - "2.6.0.dev20241030", |
95 | | - ) |
96 | | - # TODO: remove this workaround once PyTorch 2.6 is released |
97 | | - dp_mesh_dim_names = ( |
98 | | - ("dp_replicate", "dp_shard") |
99 | | - if parallel_dims.dp_replicate_enabled |
100 | | - else ("dp",) |
101 | | - ) |
102 | | - # note that mesh can only be flattened from the finest-grained mesh dimensions |
103 | | - dp_mesh = ( |
104 | | - world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp") |
105 | | - if parallel_dims.cp_enabled |
106 | | - else world_mesh[dp_mesh_dim_names] |
107 | | - ) |
| 82 | + |
| 83 | + # TODO: instead of flattening the mesh twice, we could've done it in a better way: |
| 84 | + # dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"] |
| 85 | + # However, this leads to an error in `DeviceMesh.__getitem__` which I believe is |
| 86 | + # a bug in DeviceMesh. We should fix it and then use the above line. |
| 87 | + dp_mesh_dim_names = ( |
| 88 | + ("dp_replicate", "dp_shard") |
| 89 | + if parallel_dims.dp_replicate_enabled |
| 90 | + else ("dp",) |
| 91 | + ) |
| 92 | + # note that mesh can only be flattened from the finest-grained mesh dimensions |
| 93 | + dp_mesh = ( |
| 94 | + world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp") |
| 95 | + if parallel_dims.cp_enabled |
| 96 | + else world_mesh[dp_mesh_dim_names] |
| 97 | + ) |
108 | 98 |
|
109 | 99 | apply_fsdp( |
110 | 100 | model, |
|
0 commit comments