222 changes: 216 additions & 6 deletions torchrec/distributed/shards_wrapper.py
@@ -9,6 +9,7 @@

# COPY of the code from torch.distributed._tensor._shards_wrapper - for package compat

import logging
from typing import Any, List, Tuple

import torch
@@ -24,6 +25,7 @@
WriteItemType,
)

logger: logging.Logger = logging.getLogger(__name__)
aten = torch.ops.aten # pyre-ignore[5]


@@ -73,7 +75,7 @@ def __new__(
cat_tensor_shape[1] += shard.size()[1]

# in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
- if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+ if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[0] += shard.size()[0]
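
As a quick illustration of the row-wise size computation above (a standalone sketch with made-up shard sizes, not part of this change):

import torch

# Two hypothetical 1D (row-wise) shards of sizes 4 and 6.
local_shards = [torch.zeros(4), torch.zeros(6)]
cat_tensor_shape = list(local_shards[0].size())
if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
    for shard in local_shards[1:]:
        cat_tensor_shape[0] += shard.size()[0]
assert cat_tensor_shape == [10]  # logical size is the sum of shard sizes on dim 0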

@@ -119,6 +121,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
aten.copy_.default: cls.handle_copy_,
aten.zeros_like.default: cls.handle_zeros_like,
aten.empty_like.default: cls.handle_empty_like,
aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
}

if func in dispatcher:
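
For context on the new dispatcher entry above, here is a minimal sketch of how a constant-pad call would reach handle_constant_pad_nd; the shard shapes, offsets, and the direct constructor call are assumptions for illustration, not taken from this diff.

import torch
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# Hypothetical: two column-wise shards of a [4, 6] logical tensor at column offsets 0 and 3.
lsw = LocalShardsWrapper([torch.zeros(4, 3), torch.zeros(4, 3)], [(0, 0), (0, 3)])

# torch.constant_pad_nd lowers to aten.constant_pad_nd.default, which
# __torch_dispatch__ now routes to handle_constant_pad_nd.
padded = torch.constant_pad_nd(lsw, [1, 0, 0, 0], 0.0)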
@@ -162,12 +165,14 @@ def handle_copy_(args, kwargs):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_all_gather_into_tensor(args, kwargs):
- dim = args[0].local_sizes()[0][1]
- cat_tensor = torch.cat(
-     [t.view(-1) for t in args[0].local_shards()], dim=0
- ).view(-1, dim)
+ local_shards = args[0].local_shards()
+ if len(local_shards) == 1:
+     result_tensor = local_shards[0]
+ # 2D CW sharding: concat columns, 1D RW sharding: concat rows
+ result_tensor = torch.cat(local_shards, dim=-1)
+ logger.info(f"resulting tensor before all gather: {result_tensor}")
return torch.ops._c10d_functional.all_gather_into_tensor.default(
- cat_tensor, *args[1:], **kwargs
+ result_tensor, *args[1:], **kwargs
)
)
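
A standalone sketch (with made-up shapes) of what the new concatenation produces before the all-gather call:

import torch

# 2D column-wise shards: concatenating on the last dim stitches the columns back together.
cw_shards = [torch.ones(4, 2), torch.ones(4, 3)]
assert torch.cat(cw_shards, dim=-1).shape == (4, 5)

# 1D row-wise shards: the last dim is the only dim, so the rows are concatenated.
rw_shards = [torch.ones(4), torch.ones(6)]
assert torch.cat(rw_shards, dim=-1).shape == (10,)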

@staticmethod
@@ -279,6 +284,211 @@ def handle_new_empty(args, kwargs):
self_ls.local_offsets(),
)

@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_constant_pad_nd(args, kwargs):
"""
Apply constant padding to LocalShardsWrapper.

The padding follows these ideas:
- The resulting wrapper represents the padded version of the logical tensor.
- Each shard is padded according to the sharding type and the dimension being padded.
- For instance, padding the left-most column of a CW-sharded tensor pads only the first CW shard.
- Padding the top row applies to every CW shard.
"""
self_lsw = args[0]
pad_spec = args[1]
pad_value = args[2] if len(args) > 2 else 0.0
logger.info(
f"padding {self_lsw} with {pad_spec} and value: {pad_value}, current shards: {self_lsw.local_shards()} with offsets: {self_lsw.local_offsets()}. tensor storage metadata: {self_lsw.storage_metadata()}"
)

if len(self_lsw.local_shards()) == 0:
raise NotImplementedError(
"Padding empty LocalShardsWrapper is not supported."
)

local_shards = self_lsw.local_shards()

if len(local_shards) == 1:
padded_shard = torch.nn.functional.pad(
local_shards[0], pad_spec, mode="constant", value=pad_value
)
return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())

padded_shards = list(local_shards)

if local_shards[0].ndim == 2:
# 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
if len(pad_spec) == 2:
# A length-2 pad spec pads only the last (column) dimension; extend it to the 4-entry form
pad_spec = pad_spec + [0, 0]

if len(pad_spec) != 4:
raise ValueError(
f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}"
)

pad_left, pad_right, pad_top, pad_bottom = (
pad_spec[0],
pad_spec[1],
pad_spec[2],
pad_spec[3],
)

if pad_top > 0:
padded_shards = [
torch.nn.functional.pad(
shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
)
for shard in padded_shards
]
if pad_bottom > 0:
padded_shards = [
torch.nn.functional.pad(
shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
)
for shard in padded_shards
]
if pad_left > 0:
padded_shards[0] = torch.nn.functional.pad(
padded_shards[0],
[pad_left, 0, 0, 0],
mode="constant",
value=pad_value,
)
if pad_right > 0:
padded_shards[-1] = torch.nn.functional.pad(
padded_shards[-1],
[0, pad_right, 0, 0],
mode="constant",
value=pad_value,
)
elif local_shards[0].ndim == 1:
# 1D Row-wise sharding: [pad_top, pad_bottom]
if len(pad_spec) != 2:
raise ValueError(
f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}"
)
pad_top, pad_bottom = pad_spec[0], pad_spec[1]

if pad_top > 0:
padded_shards[0] = torch.nn.functional.pad(
padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
)
if pad_bottom > 0:
padded_shards[-1] = torch.nn.functional.pad(
padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
)
else:
raise NotImplementedError(
f"Padding for {local_shards[0].ndim}D tensors is not supported. "
f"Only 1D and 2D tensors are currently supported."
)

# Update offsets and storage metadata
original_storage = self_lsw.storage_metadata()
updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
original_storage,
self_lsw.local_offsets(),
pad_spec,
local_shards[0].ndim,
padded_shards,
)

result = LocalShardsWrapper(padded_shards, updated_offsets)
result._storage_meta = updated_storage
return result
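
A worked sketch of the CW padding rules above, using plain tensors to stand in for two column-wise shards; the shapes and pad values are illustrative, not from this diff. With shards of shape [4, 3] at column offsets 0 and 3 and pad_spec [1, 2, 1, 0], only the first shard takes the left pad, only the last shard takes the right pad, and both take the top pad.

import torch
import torch.nn.functional as F

# Two hypothetical column-wise shards of a [4, 6] logical tensor.
shards = [torch.zeros(4, 3), torch.zeros(4, 3)]
pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 0

# Top/bottom padding applies to every shard.
shards = [F.pad(s, [0, 0, pad_top, pad_bottom], value=0.0) for s in shards]
# Left padding applies only to the first shard, right padding only to the last.
shards[0] = F.pad(shards[0], [pad_left, 0, 0, 0], value=0.0)
shards[-1] = F.pad(shards[-1], [0, pad_right, 0, 0], value=0.0)

assert shards[0].shape == (5, 4) and shards[1].shape == (5, 5)
# The padded logical tensor is [4 + 1 + 0, 6 + 1 + 2] = [5, 9]; the second
# shard's column offset shifts from 3 to 3 + pad_left = 4.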

@staticmethod
def _compute_updated_metadata(
original_storage: TensorStorageMetadata,
original_offsets: list[torch.Size],
pad_spec: list[int],
ndim: int,
padded_shards: list[torch.Tensor],
) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]:
"""
Compute updated offsets and storage metadata after padding is applied.

Args:
original_storage: Original storage metadata
original_offsets: Original shard offsets
pad_spec: Padding specification
ndim: Number of dimensions (1=RW or 2=CW)
padded_shards: Padded shard tensors

Returns:
Tuple of (updated_offsets, updated_storage_metadata)
"""
if ndim == 1: # 1D RW
pad_top, pad_bottom = pad_spec[0], pad_spec[1]

updated_offsets = []
for i, offset in enumerate(original_offsets):
if i == 0:
# First shard: offset stays the same (absorbs top padding)
updated_offsets.append(tuple(offset))
else:
# Subsequent shards: shift by top padding amount
new_offset = (offset[0] + pad_top,)
updated_offsets.append(new_offset)

new_global_size = torch.Size(
[original_storage.size[0] + pad_top + pad_bottom]
)

elif ndim == 2: # 2D CW
pad_left, pad_right, pad_top, pad_bottom = (
pad_spec[0],
pad_spec[1],
pad_spec[2],
pad_spec[3],
)

updated_offsets = []
for i, offset in enumerate(original_offsets):
row_offset = offset[0]
col_offset = offset[1]

# Top/bottom padding doesn't affect offsets
# Left padding affects column offsets
if i == 0:
# First shard: column offset stays the same (absorbs left padding)
new_2d_offset = (row_offset, col_offset)
else:
# Subsequent shards: shift column offset by left padding amount
new_2d_offset = (row_offset, col_offset + pad_left)

updated_offsets.append(new_2d_offset)

new_global_size = torch.Size(
[
original_storage.size[0] + pad_top + pad_bottom,
original_storage.size[1] + pad_left + pad_right,
]
)

else:
raise NotImplementedError(f"Metadata computation for {ndim}D not supported")

updated_chunks = [
ChunkStorageMetadata(
offsets=torch.Size(offset),
sizes=shard.size(),
)
for offset, shard in zip(updated_offsets, padded_shards)
]

updated_storage = TensorStorageMetadata(
properties=original_storage.properties,
size=new_global_size,
chunks=updated_chunks,
)

return updated_offsets, updated_storage
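
To make the offset bookkeeping above concrete, a small worked example for the 1D row-wise case (sizes are made up): shards of sizes 4 and 6 at offsets (0,) and (4,), padded with pad_spec [2, 3].

# Hypothetical row-wise layout: shard sizes 4 and 6, global size 10, pad [2, 3].
original_offsets = [(0,), (4,)]
pad_top, pad_bottom = 2, 3

# First shard keeps its offset (it absorbs the top padding); later shards shift by pad_top.
updated_offsets = [original_offsets[0]] + [
    (off[0] + pad_top,) for off in original_offsets[1:]
]
assert updated_offsets == [(0,), (6,)]

# Padded shard sizes 4 + 2 = 6 and 6 + 3 = 9 tile the new global size 10 + 2 + 3 = 15.
assert (6 + 9) == (10 + pad_top + pad_bottom)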

@property
def device(self) -> torch._C.device: # type: ignore[override]
return (