|
9 | 9 |
|
10 | 10 | # COPY of the code from torch.distributed._tensor._shards_wrapper - for package compat
|
11 | 11 |
|
| 12 | +import logging |
12 | 13 | from typing import Any, List, Tuple
|
13 | 14 |
|
14 | 15 | import torch
|
|
24 | 25 | WriteItemType,
|
25 | 26 | )
|
26 | 27 |
|
| 28 | +logger: logging.Logger = logging.getLogger(__name__) |
27 | 29 | aten = torch.ops.aten # pyre-ignore[5]
|
28 | 30 |
|
29 | 31 |
|
@@ -73,7 +75,7 @@ def __new__(
|
73 | 75 | cat_tensor_shape[1] += shard.size()[1]
|
74 | 76 |
|
75 | 77 | # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
|
76 |  | -        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding |
| 78 | + if len(local_shards) > 1 and local_shards[0].ndim == 1: # row-wise sharding |
77 | 79 | for shard in local_shards[1:]:
|
78 | 80 | cat_tensor_shape[0] += shard.size()[0]
|
79 | 81 |
|
@@ -119,6 +121,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
119 | 121 | aten.copy_.default: cls.handle_copy_,
|
120 | 122 | aten.zeros_like.default: cls.handle_zeros_like,
|
121 | 123 | aten.empty_like.default: cls.handle_empty_like,
|
| 124 | + aten.constant_pad_nd.default: cls.handle_constant_pad_nd, |
122 | 125 | }
|
123 | 126 |
|
124 | 127 | if func in dispatcher:
|
@@ -162,12 +165,14 @@ def handle_copy_(args, kwargs):
|
162 | 165 | # pyre-fixme[3]: Return type must be annotated.
|
163 | 166 | # pyre-fixme[2]: Parameter must be annotated.
|
164 | 167 | def handle_all_gather_into_tensor(args, kwargs):
|
165 |  | -        dim = args[0].local_sizes()[0][1] |
166 |  | -        cat_tensor = torch.cat( |
167 |  | -            [t.view(-1) for t in args[0].local_shards()], dim=0 |
168 |  | -        ).view(-1, dim) |
|  168 | +        local_shards = args[0].local_shards() |
|  169 | +        if len(local_shards) == 1: |
|  170 | +            result_tensor = local_shards[0] |
|  171 | +        else:  # 2D CW sharding: concat columns, 1D RW sharding: concat rows |
|  172 | +            result_tensor = torch.cat(local_shards, dim=-1) |
|  173 | +        logger.info(f"resulting tensor before all gather: {result_tensor}") |
169 | 174 | return torch.ops._c10d_functional.all_gather_into_tensor.default(
|
170 |  | -            cat_tensor, *args[1:], **kwargs |
| 175 | + result_tensor, *args[1:], **kwargs |
171 | 176 | )
|
172 | 177 |
|
173 | 178 | @staticmethod
|
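The rewritten `handle_all_gather_into_tensor` above replaces the flatten-and-reshape path with a single concat along the last dimension. A minimal sketch of why `dim=-1` covers both supported layouts, using plain tensors with made-up shapes in place of the wrapper's local shards:

```python
import torch

# 1D row-wise (RW) shards: for 1D tensors, dim=-1 is dim=0, so cat stacks the rows.
rw_shards = [torch.arange(0, 4), torch.arange(4, 8)]
assert torch.cat(rw_shards, dim=-1).shape == (8,)

# 2D column-wise (CW) shards: dim=-1 concatenates the column blocks side by side.
cw_shards = [torch.zeros(4, 3), torch.ones(4, 5)]
assert torch.cat(cw_shards, dim=-1).shape == (4, 8)
```

The removed version indexed `local_sizes()[0][1]`, which assumes 2D shards, so the 1D RW case presumably needed this change.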
@@ -279,6 +284,211 @@ def handle_new_empty(args, kwargs):
|
279 | 284 | self_ls.local_offsets(),
|
280 | 285 | )
|
281 | 286 |
|
| 287 | + @staticmethod |
| 288 | + # pyre-fixme[3]: Return type must be annotated. |
| 289 | + # pyre-fixme[2]: Parameter must be annotated. |
| 290 | + def handle_constant_pad_nd(args, kwargs): |
|  291 | +        """ |
|  292 | +        Apply constant padding to a LocalShardsWrapper. |
|  293 | + |
|  294 | +        The padding is based on the following ideas: |
|  295 | +        - The resulting wrapper represents the padded version of the logical tensor. |
|  296 | +        - Each shard is padded according to its sharding type and the dimension being padded. |
|  297 | +        - For instance, padding the left-most column only pads the first CW shard. |
|  298 | +        - Padding the top row applies to all CW shards. |
|  299 | +        """ |
| 300 | + self_lsw = args[0] |
| 301 | + pad_spec = args[1] |
| 302 | + pad_value = args[2] if len(args) > 2 else 0.0 |
| 303 | + logger.info( |
| 304 | + f"padding {self_lsw} with {pad_spec} and value: {pad_value}, current shards: {self_lsw.local_shards()} with offsets: {self_lsw.local_offsets()}. tensor storage metadata: {self_lsw.storage_metadata()}" |
| 305 | + ) |
| 306 | + |
| 307 | + if len(self_lsw.local_shards()) == 0: |
| 308 | + raise NotImplementedError( |
| 309 | + "Padding empty LocalShardsWrapper is not supported." |
| 310 | + ) |
| 311 | + |
| 312 | + local_shards = self_lsw.local_shards() |
| 313 | + |
| 314 | + if len(local_shards) == 1: |
| 315 | + padded_shard = torch.nn.functional.pad( |
| 316 | + local_shards[0], pad_spec, mode="constant", value=pad_value |
| 317 | + ) |
| 318 | + return LocalShardsWrapper([padded_shard], self_lsw.local_offsets()) |
| 319 | + |
| 320 | + padded_shards = list(local_shards) |
| 321 | + |
| 322 | + if local_shards[0].ndim == 2: |
| 323 | + # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom] |
| 324 | + if len(pad_spec) == 2: |
|  325 | +                # A length-2 pad spec pads only the last (column) dimension |
| 326 | + pad_spec = pad_spec + [0, 0] |
| 327 | + |
| 328 | + if len(pad_spec) != 4: |
| 329 | + raise ValueError( |
| 330 | + f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}" |
| 331 | + ) |
| 332 | + |
| 333 | + pad_left, pad_right, pad_top, pad_bottom = ( |
| 334 | + pad_spec[0], |
| 335 | + pad_spec[1], |
| 336 | + pad_spec[2], |
| 337 | + pad_spec[3], |
| 338 | + ) |
| 339 | + |
| 340 | + if pad_top > 0: |
| 341 | + padded_shards = [ |
| 342 | + torch.nn.functional.pad( |
| 343 | + shard, [0, 0, pad_top, 0], mode="constant", value=pad_value |
| 344 | + ) |
| 345 | + for shard in padded_shards |
| 346 | + ] |
| 347 | + if pad_bottom > 0: |
| 348 | + padded_shards = [ |
| 349 | + torch.nn.functional.pad( |
| 350 | + shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value |
| 351 | + ) |
| 352 | + for shard in padded_shards |
| 353 | + ] |
| 354 | + if pad_left > 0: |
| 355 | + padded_shards[0] = torch.nn.functional.pad( |
| 356 | + padded_shards[0], |
| 357 | + [pad_left, 0, 0, 0], |
| 358 | + mode="constant", |
| 359 | + value=pad_value, |
| 360 | + ) |
| 361 | + if pad_right > 0: |
| 362 | + padded_shards[-1] = torch.nn.functional.pad( |
| 363 | + padded_shards[-1], |
| 364 | + [0, pad_right, 0, 0], |
| 365 | + mode="constant", |
| 366 | + value=pad_value, |
| 367 | + ) |
| 368 | + elif local_shards[0].ndim == 1: |
| 369 | + # 1D Row-wise sharding: [pad_top, pad_bottom] |
| 370 | + if len(pad_spec) != 2: |
| 371 | + raise ValueError( |
| 372 | + f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}" |
| 373 | + ) |
| 374 | + pad_top, pad_bottom = pad_spec[0], pad_spec[1] |
| 375 | + |
| 376 | + if pad_top > 0: |
| 377 | + padded_shards[0] = torch.nn.functional.pad( |
| 378 | + padded_shards[0], [pad_top, 0], mode="constant", value=pad_value |
| 379 | + ) |
| 380 | + if pad_bottom > 0: |
| 381 | + padded_shards[-1] = torch.nn.functional.pad( |
| 382 | + padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value |
| 383 | + ) |
| 384 | + else: |
| 385 | + raise NotImplementedError( |
| 386 | + f"Padding for {local_shards[0].ndim}D tensors is not supported. " |
| 387 | + f"Only 1D and 2D tensors are currently supported." |
| 388 | + ) |
| 389 | + |
| 390 | + # Update offsets and storage metadata |
| 391 | + original_storage = self_lsw.storage_metadata() |
| 392 | + updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata( |
| 393 | + original_storage, |
| 394 | + self_lsw.local_offsets(), |
| 395 | + pad_spec, |
| 396 | + local_shards[0].ndim, |
| 397 | + padded_shards, |
| 398 | + ) |
| 399 | + |
| 400 | + result = LocalShardsWrapper(padded_shards, updated_offsets) |
| 401 | + result._storage_meta = updated_storage |
| 402 | + return result |
| 403 | + |
| 404 | + @staticmethod |
| 405 | + def _compute_updated_metadata( |
| 406 | + original_storage: TensorStorageMetadata, |
| 407 | + original_offsets: list[torch.Size], |
| 408 | + pad_spec: list[int], |
| 409 | + ndim: int, |
| 410 | + padded_shards: list[torch.Tensor], |
| 411 | + ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]: |
| 412 | + """ |
| 413 | + Compute updated offsets and storage metadata after padding is applied. |
|  414 | + |
| 415 | + Args: |
| 416 | + original_storage: Original storage metadata |
| 417 | + original_offsets: Original shard offsets |
| 418 | + pad_spec: Padding specification |
| 419 | + ndim: Number of dimensions (1=RW or 2=CW) |
| 420 | + padded_shards: Padded shard tensors |
|  421 | + |
| 422 | + Returns: |
| 423 | + Tuple of (updated_offsets, updated_storage_metadata) |
| 424 | + """ |
| 425 | + if ndim == 1: # 1D RW |
| 426 | + pad_top, pad_bottom = pad_spec[0], pad_spec[1] |
| 427 | + |
| 428 | + updated_offsets = [] |
| 429 | + for i, offset in enumerate(original_offsets): |
| 430 | + if i == 0: |
| 431 | + # First shard: offset stays the same (absorbs top padding) |
| 432 | + updated_offsets.append(tuple(offset)) |
| 433 | + else: |
| 434 | + # Subsequent shards: shift by top padding amount |
| 435 | + new_offset = (offset[0] + pad_top,) |
| 436 | + updated_offsets.append(new_offset) |
| 437 | + |
| 438 | + new_global_size = torch.Size( |
| 439 | + [original_storage.size[0] + pad_top + pad_bottom] |
| 440 | + ) |
| 441 | + |
| 442 | + elif ndim == 2: # 2D CW |
| 443 | + pad_left, pad_right, pad_top, pad_bottom = ( |
| 444 | + pad_spec[0], |
| 445 | + pad_spec[1], |
| 446 | + pad_spec[2], |
| 447 | + pad_spec[3], |
| 448 | + ) |
| 449 | + |
| 450 | + updated_offsets = [] |
| 451 | + for i, offset in enumerate(original_offsets): |
| 452 | + row_offset = offset[0] |
| 453 | + col_offset = offset[1] |
| 454 | + |
| 455 | + # Top/bottom padding doesn't affect offsets |
| 456 | + # Left padding affects column offsets |
| 457 | + if i == 0: |
| 458 | + # First shard: column offset stays the same (absorbs left padding) |
| 459 | + new_2d_offset = (row_offset, col_offset) |
| 460 | + else: |
| 461 | + # Subsequent shards: shift column offset by left padding amount |
| 462 | + new_2d_offset = (row_offset, col_offset + pad_left) |
| 463 | + |
| 464 | + updated_offsets.append(new_2d_offset) |
| 465 | + |
| 466 | + new_global_size = torch.Size( |
| 467 | + [ |
| 468 | + original_storage.size[0] + pad_top + pad_bottom, |
| 469 | + original_storage.size[1] + pad_left + pad_right, |
| 470 | + ] |
| 471 | + ) |
| 472 | + |
| 473 | + else: |
| 474 | + raise NotImplementedError(f"Metadata computation for {ndim}D not supported") |
| 475 | + |
| 476 | + updated_chunks = [ |
| 477 | + ChunkStorageMetadata( |
| 478 | + offsets=torch.Size(offset), |
| 479 | + sizes=shard.size(), |
| 480 | + ) |
| 481 | + for offset, shard in zip(updated_offsets, padded_shards) |
| 482 | + ] |
| 483 | + |
| 484 | + updated_storage = TensorStorageMetadata( |
| 485 | + properties=original_storage.properties, |
| 486 | + size=new_global_size, |
| 487 | + chunks=updated_chunks, |
| 488 | + ) |
| 489 | + |
| 490 | + return updated_offsets, updated_storage |
| 491 | + |
282 | 492 | @property
|
283 | 493 | def device(self) -> torch._C.device: # type: ignore[override]
|
284 | 494 | return (
|
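To make the padding distribution in `handle_constant_pad_nd` concrete, here is a rough stand-alone illustration using plain tensors in place of CW shards (shapes and pad values are made up for the example, mirroring the per-shard `F.pad` calls above):

```python
import torch
import torch.nn.functional as F

# Two column-wise shards of a logical 2 x 4 tensor.
shards = [torch.ones(2, 2), torch.full((2, 2), 2.0)]
pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 0  # aten pad spec [1, 2, 1, 0]

# Row padding (top/bottom) is applied to every shard ...
padded = [F.pad(s, [0, 0, pad_top, pad_bottom], value=0.0) for s in shards]
# ... while column padding only touches the boundary shards.
padded[0] = F.pad(padded[0], [pad_left, 0, 0, 0], value=0.0)
padded[-1] = F.pad(padded[-1], [0, pad_right, 0, 0], value=0.0)

# First shard grew by the top row and the left column, last shard by the top row
# and the two right columns; together they tile the padded 3 x 7 logical tensor.
assert padded[0].shape == (3, 3)
assert padded[-1].shape == (3, 4)
```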
|
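Similarly, a small worked example of the offset bookkeeping in `_compute_updated_metadata` for the 2D CW branch (offsets are hypothetical; the real method additionally rebuilds the `ChunkStorageMetadata` list and the global size from the padded shards):

```python
# Column-wise shards of a 2 x 4 tensor, split into columns [0:2] and [2:4].
original_offsets = [(0, 0), (0, 2)]
pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 0

# First shard absorbs the left padding, so its offset is unchanged;
# later shards shift their column offset right by pad_left.
# Row offsets never move: top padding is applied to every CW shard.
updated_offsets = [
    offset if i == 0 else (offset[0], offset[1] + pad_left)
    for i, offset in enumerate(original_offsets)
]
assert updated_offsets == [(0, 0), (0, 3)]

# The padded global size grows by the full pad spec.
new_global_size = (2 + pad_top + pad_bottom, 4 + pad_left + pad_right)
assert new_global_size == (3, 7)
```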