Commit c02aac4

jeffkbkim authored and facebook-github-bot committed
Implement tensor padding for local shards wrapper (#3382)
Summary:
X-link: pytorch/pytorch#163183

This diff implements the constant padding functionality (aten.constant_pad_nd.default) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification. Depending on the sharding type (RW, CW), padding in the [left, right, top, bottom] directions is applied either to the first/last shard only, or to all local shards.

New unit tests cover:
- 1D (RW) top/bottom paddings
- 2D (CW) left, right, top, bottom paddings
- empty shards, number of dimensions > 2

Differential Revision: D82663766
1 parent c26367f commit c02aac4
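
As a quick orientation for the behavior described in the summary, here is a minimal usage sketch (illustrative only, not part of the commit). It assumes a two-shard column-wise `LocalShardsWrapper` built directly from local tensors and offsets, and that `torch.nn.functional.pad` in constant mode reaches the newly registered `aten.constant_pad_nd.default` handler via `__torch_dispatch__`:

    import torch
    import torch.nn.functional as F

    from torchrec.distributed.shards_wrapper import LocalShardsWrapper

    # Hypothetical column-wise (CW) layout: two 2x2 shards of a logical 2x4 tensor,
    # placed at column offsets 0 and 2.
    lsw = LocalShardsWrapper(
        [torch.ones(2, 2), torch.full((2, 2), 2.0)],
        [(0, 0), (0, 2)],
    )

    # pad_spec is [pad_left, pad_right, pad_top, pad_bottom]: the left column of
    # padding should land only on the first shard, while the top row of padding
    # should be applied to every shard.
    padded = F.pad(lsw, [1, 0, 1, 0], mode="constant", value=0.0)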

File tree

2 files changed, +573 -6 lines


torchrec/distributed/shards_wrapper.py

Lines changed: 216 additions & 6 deletions
@@ -9,6 +9,7 @@
 
 # COPY of the code from torch.distributed._tensor._shards_wrapper - for package compat
 
+import logging
 from typing import Any, List, Tuple
 
 import torch
@@ -24,6 +25,7 @@
     WriteItemType,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
 aten = torch.ops.aten  # pyre-ignore[5]
 
 
@@ -73,7 +75,7 @@ def __new__(
                 cat_tensor_shape[1] += shard.size()[1]
 
         # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
-        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[0] += shard.size()[0]
 
@@ -119,6 +121,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.copy_.default: cls.handle_copy_,
             aten.zeros_like.default: cls.handle_zeros_like,
             aten.empty_like.default: cls.handle_empty_like,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
         }
 
         if func in dispatcher:
@@ -162,12 +165,14 @@ def handle_copy_(args, kwargs):
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
     def handle_all_gather_into_tensor(args, kwargs):
-        dim = args[0].local_sizes()[0][1]
-        cat_tensor = torch.cat(
-            [t.view(-1) for t in args[0].local_shards()], dim=0
-        ).view(-1, dim)
+        local_shards = args[0].local_shards()
+        if len(local_shards) == 1:
+            result_tensor = local_shards[0]
+        # 2D CW sharding: concat columns, 1D RW sharding: concat rows
+        result_tensor = torch.cat(local_shards, dim=-1)
+        logger.info(f"resulting tensor before all gather: {result_tensor}")
         return torch.ops._c10d_functional.all_gather_into_tensor.default(
-            cat_tensor, *args[1:], **kwargs
+            result_tensor, *args[1:], **kwargs
         )
 
     @staticmethod
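
For the revised `handle_all_gather_into_tensor` above, the key change is concatenating the local shards along their last dimension instead of flattening and reshaping. A small standalone sketch (not from the diff) of why `dim=-1` covers both supported layouts:

    import torch

    # 2D column-wise shards: concatenating on the last dim stitches the columns back together.
    cw_shards = [torch.zeros(4, 2), torch.ones(4, 3)]
    assert torch.cat(cw_shards, dim=-1).shape == (4, 5)

    # 1D row-wise shards: the last dim is the only dim, so this concatenates the rows.
    rw_shards = [torch.zeros(8), torch.ones(8)]
    assert torch.cat(rw_shards, dim=-1).shape == (16,)
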
@@ -279,6 +284,211 @@ def handle_new_empty(args, kwargs):
             self_ls.local_offsets(),
         )
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_constant_pad_nd(args, kwargs):
+        """
+        Apply constant padding to LocalShardsWrapper.
+
+        The padding is based off of the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded based on the sharding type + dimension that is padded.
+        - For instance, CW shards padded on the left most col will have only padding on the first CW shard.
+        - Padding the top row will apply to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+        logger.info(
+            f"padding {self_lsw} with {pad_spec} and value: {pad_value}, current shards: {self_lsw.local_shards()} with offsets: {self_lsw.local_offsets()}. tensor storage metadata: {self_lsw.storage_metadata()}"
+        )
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError(
+                "Padding empty LocalShardsWrapper is not supported."
+            )
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            if len(pad_spec) == 2:
+                # Single dimension padding happens on the left most column
+                pad_spec = pad_spec + [0, 0]
+
+            if len(pad_spec) != 4:
+                raise ValueError(
+                    f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}"
+                )
+
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D Row-wise sharding: [pad_top, pad_bottom]
+            if len(pad_spec) != 2:
+                raise ValueError(
+                    f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}"
+                )
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec,
+            local_shards[0].ndim,
+            padded_shards,
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1=RW or 2=CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(tuple(offset))
+                else:
+                    # Subsequent shards: shift by top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(new_offset)
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets
+                # Left padding affects column offsets
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_2d_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by left padding amount
+                    new_2d_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(new_2d_offset)
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right,
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=torch.Size(offset),
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
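
To make the offset and size bookkeeping in `_compute_updated_metadata` concrete, a hedged sketch with hypothetical shard sizes (1D row-wise case; the expected values in the comments follow the rules in the code above, assuming `torch.nn.functional.pad` dispatches to the new handler):

    import torch
    import torch.nn.functional as F

    from torchrec.distributed.shards_wrapper import LocalShardsWrapper

    # Hypothetical RW layout: two length-4 shards of a logical length-8 tensor.
    lsw = LocalShardsWrapper([torch.zeros(4), torch.ones(4)], [(0,), (4,)])

    # pad_spec = [2, 1]: pad_top=2 is added to the front of the first shard,
    # pad_bottom=1 to the end of the last shard.
    padded = F.pad(lsw, [2, 1], mode="constant", value=0.0)

    # Expected bookkeeping per the rules above:
    #   shard sizes:   (4,), (4,)  ->  (6,), (5,)
    #   shard offsets: (0,), (4,)  ->  (0,), (6,)   (second offset shifted by pad_top)
    #   global size:   (8,)        ->  (11,)        (8 + pad_top + pad_bottom)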
