4 changes: 4 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -1944,6 +1944,10 @@ def __init__(
assert (
config.is_using_virtual_table
), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables"
assert embedding_cache_mode == config.enable_embedding_update, (
f"Embedding_cache kernel is {embedding_cache_mode} "
f"but embedding config has enable_embedding_update {config.enable_embedding_update}"
)
for table in config.embedding_tables:
assert table.local_cols % 4 == 0, (
f"table {table.name} has local_cols={table.local_cols} "
119 changes: 117 additions & 2 deletions torchrec/distributed/dist_data.py
@@ -368,8 +368,12 @@ def __init__(
# https://github.com/pytorch/pytorch/issues/122788
with record_function(f"## all2all_data:kjt {label} ##"):
if self._pg._get_backend_name() == "custom":
if input_tensor.dim() == 2:
output_size = [sum(output_split), input_tensor.size(1)]
else:
output_size = [sum(output_split)]
output_tensor = torch.empty(
sum(output_split),
output_size,
device=self._device,
dtype=input_tensor.dtype,
)
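The dimension check above exists because a KJT exchange can carry 2-D per-value tensors (for example, weights holding embedding rows) as well as 1-D lengths/values; `all_to_all_single` only splits along dim 0, so the output buffer must keep the trailing dimension. A minimal sketch of that sizing, assuming an already-initialized process group; the helper name is illustrative:

from typing import List

import torch
import torch.distributed as dist


def exchange_with_trailing_dim(
    input_tensor: torch.Tensor,
    input_split: List[int],
    output_split: List[int],
    pg: dist.ProcessGroup,
    device: torch.device,
) -> torch.Tensor:
    # 2-D inputs (e.g. per-value embedding rows) keep their trailing dimension;
    # 1-D inputs (lengths / values) only need the summed first dimension.
    if input_tensor.dim() == 2:
        output_size = [sum(output_split), input_tensor.size(1)]
    else:
        output_size = [sum(output_split)]
    output_tensor = torch.empty(output_size, device=device, dtype=input_tensor.dtype)
    # all_to_all_single splits and concatenates along dim 0 only.
    dist.all_to_all_single(
        output_tensor,
        input_tensor,
        output_split_sizes=output_split,
        input_split_sizes=input_split,
        group=pg,
    )
    return output_tensor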
@@ -391,8 +395,12 @@
)
self._output_tensors.append(output_tensor)
else:
if input_tensor.dim() == 2:
output_size = [sum(output_split), input_tensor.size(1)]
else:
output_size = [sum(output_split)]
output_tensor = torch.empty(
sum(output_split), device=self._device, dtype=input_tensor.dtype
output_size, device=self._device, dtype=input_tensor.dtype
)
with record_function(f"## all2all_data:kjt {label} ##"):
awaitable = dist.all_to_all_single(
@@ -542,6 +550,113 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
)


class KJEAllToAll(nn.Module):
"""
Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.

Implementation utilizes AlltoAll collective as part of torch.distributed.

The input provides the necessary tensors, embedding weights, and input splits to distribute.
The first collective call in `KJTAllToAllSplitsAwaitable` transmits the output
splits (to allocate the correct space for the tensors) and the batch size per rank. The
following collective calls in `KJTAllToAllTensorsAwaitable` transmit the actual
tensors asynchronously.
This module is used for embedding updates, where the weights of the input KJT are written into the embedding tables.

Args:
pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
splits (List[int]): List of len(pg.size()) which indicates how many features to
send to each pg.rank(). It is assumed the `KeyedJaggedTensor` is ordered by
destination rank. Same for all ranks.
stagger (int): stagger value to apply to recat tensor, see `_get_recat` function
for more detail.

Example::

keys=['A','B','C']
splits=[2,1]
kjeA2A = KJEAllToAll(pg, splits)
awaitable = kjeA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding

# 0 1 2
# 'A' [A.V0] None [A.V1, A.V2]
# 'B' None [B.V0] [B.V1]
# 'C' [C.V0] [C.V1] None

# rank1_input is KeyedJaggedTensor holding

# 0 1 2
# 'A' [A.V3] [A.V4] None
# 'B' None [B.V2] [B.V3, B.V4]
# 'C' [C.V2] [C.V3] None

# The output is None since this is a write operation, but the call is still awaited for synchronization.
awaitable.wait().wait()

# where the input after the distribution is:
# rank0

# 0 1 2 3 4 5
# 'A' [A.V0] None [A.V1, A.V2] [A.V3] [A.V4] None
# 'B' None [B.V0] [B.V1] None [B.V2] [B.V3, B.V4]

# rank1
# 0 1 2 3 4 5
# 'C' [C.V0] [C.V1] None [C.V2] [C.V3] None
"""

def __init__(
self,
pg: dist.ProcessGroup,
splits: List[int],
stagger: int = 1,
) -> None:
super().__init__()
torch._check(len(splits) == pg.size())
self._pg: dist.ProcessGroup = pg
self._splits = splits
self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
self._stagger = stagger

def forward(
self, input: KeyedJaggedTensor
) -> Awaitable[KJTAllToAllTensorsAwaitable]:
"""
Sends input to relevant `ProcessGroup` ranks.

The first wait will get the output splits for the provided tensors and issue
the tensors AlltoAll. The second wait will wait for the update to complete.

Args:
input (KeyedJaggedTensor): `KeyedJaggedTensor` of values and weights to distribute.

Returns:
Awaitable[KJTAllToAllTensorsAwaitable]: awaitable of a `KJTAllToAllTensorsAwaitable`.
"""

with torch.no_grad():
assert len(input.keys()) == sum(self._splits)
rank = dist.get_rank(self._pg)
local_keys = input.keys()[
self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
]

return KJTAllToAllSplitsAwaitable(
pg=self._pg,
input=input,
splits=self._splits,
labels=input.dist_labels(),
tensor_splits=input.dist_splits(self._splits),
input_tensors=input.dist_tensors(),
keys=local_keys,
device=input.device(),
stagger=self._stagger,
)
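A minimal driver sketch for the double-wait protocol described above, assuming an initialized default process group and a `KeyedJaggedTensor` whose keys are ordered by destination rank; the helper name is illustrative:

from typing import List

import torch.distributed as dist
from torchrec.distributed.dist_data import KJEAllToAll
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def write_features(kjt: KeyedJaggedTensor, splits: List[int]) -> None:
    pg = dist.group.WORLD  # assumes dist.init_process_group(...) was called
    kje_a2a = KJEAllToAll(pg=pg, splits=splits)
    # First wait: exchanges output splits and batch sizes, then issues the
    # asynchronous tensors AlltoAll.
    tensors_awaitable = kje_a2a(kjt).wait()
    # Second wait: blocks until the redistributed values/weights have landed;
    # the result is only used for synchronization on this write path.
    tensors_awaitable.wait()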


class KJTAllToAll(nn.Module):
"""
Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
46 changes: 45 additions & 1 deletion torchrec/distributed/embedding.py
@@ -468,13 +468,19 @@ def __init__(
for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
}

self.enable_embedding_update: bool = any(
config.enable_embedding_update for config in self._embedding_configs
)
self._device = device
self._input_dists: List[nn.Module] = []
self._write_dists: List[nn.Module] = []
self._lookups: List[nn.Module] = []
self._updates: List[nn.Module] = []
self._create_lookups()
self._output_dists: List[nn.Module] = []
self._create_output_dist()

self._write_splits: List[int] = []
self._feature_splits: List[int] = []
self._features_order: List[int] = []

@@ -631,6 +637,7 @@ def create_grouped_sharding_infos(
total_num_buckets=config.total_num_buckets,
use_virtual_table=config.use_virtual_table,
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
enable_embedding_update=config.enable_embedding_update,
),
param_sharding=parameter_sharding,
param=param,
@@ -1308,7 +1315,10 @@

def _create_lookups(self) -> None:
for sharding in self._sharding_type_to_sharding.values():
self._lookups.append(sharding.create_lookup())
lookup = sharding.create_lookup()
if self.enable_embedding_update and sharding.enable_embedding_update:
self._updates.append(sharding.create_update(lookup))
self._lookups.append(lookup)

def _create_output_dist(
self,
@@ -1627,6 +1637,40 @@ def fused_optimizer(self) -> KeyedOptimizer:
def create_context(self) -> EmbeddingCollectionContext:
return EmbeddingCollectionContext(sharding_contexts=[])

def _create_write_dist(self) -> None:
for sharding in self._sharding_type_to_sharding.values():
if sharding.enable_embedding_update:
self._write_dists.append(sharding.create_write_dist())
self._write_splits.append(sharding._get_num_writable_features())

# pyre-ignore [14]
def write_dist(
self, ctx: EmbeddingCollectionContext, embeddings: KeyedJaggedTensor
) -> Awaitable[Awaitable[KJTList]]:
if not self.enable_embedding_update:
raise ValueError("enable_embedding_update is False for this collection")
if not self._write_dists:
self._create_write_dist()
with torch.no_grad():
embeddings_by_shards = embeddings.split(self._write_splits)
awaitables = []
for write_dist, embeddings in zip(self._write_dists, embeddings_by_shards):
awaitables.append(write_dist(embeddings))

return KJTListSplitsAwaitable(
awaitables,
ctx,
self._module_fqn,
list(self._sharding_type_to_sharding.keys()),
)

def update(self, ctx: EmbeddingCollectionContext, dist_input: KJTList) -> None:
for update, embeddings in zip(
self._updates,
dist_input,
):
update(embeddings)
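This diff does not show the caller of `write_dist`/`update`; a plausible minimal sequence, with the sharded collection and input KJT passed in as assumptions (configs must set `enable_embedding_update=True`):

from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def write_embeddings(
    sharded_ec: ShardedEmbeddingCollection, new_embeddings: KeyedJaggedTensor
) -> None:
    # `new_embeddings` weights hold the embedding rows to write; its keys must be
    # ordered to match the collection's write splits.
    ctx = sharded_ec.create_context()
    # Stage 1: the first wait exchanges splits, the second wait completes the
    # tensors AlltoAll and yields the per-sharding KJTList.
    tensors_awaitable = sharded_ec.write_dist(ctx, new_embeddings).wait()
    dist_input = tensors_awaitable.wait()
    # Stage 2: write the received rows into the local shards.
    sharded_ec.update(ctx, dist_input)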


class EmbeddingCollectionSharder(BaseEmbeddingSharder[EmbeddingCollection]):
def __init__(
51 changes: 44 additions & 7 deletions torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
import logging
from abc import ABC
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -39,6 +39,7 @@
BatchedFusedEmbeddingBag,
KeyValueEmbedding,
KeyValueEmbeddingBag,
ZeroCollisionEmbeddingCache,
ZeroCollisionKeyValueEmbedding,
ZeroCollisionKeyValueEmbeddingBag,
)
@@ -49,6 +50,7 @@
from torchrec.distributed.embedding_kernel import BaseEmbedding
from torchrec.distributed.embedding_types import (
BaseEmbeddingLookup,
BaseEmbeddingUpdate,
BaseGroupedFeatureProcessor,
EmbeddingComputeKernel,
GroupedEmbeddingConfig,
@@ -249,12 +251,20 @@ def _create_embedding_kernel(
)
elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE:
# for dram kv
return ZeroCollisionKeyValueEmbedding(
config=config,
pg=pg,
device=device,
backend_type=BackendType.DRAM,
)
if config.enable_embedding_update:
return ZeroCollisionEmbeddingCache(
config=config,
pg=pg,
device=device,
backend_type=BackendType.DRAM,
)
else:
return ZeroCollisionKeyValueEmbedding(
config=config,
pg=pg,
device=device,
backend_type=BackendType.DRAM,
)
else:
raise ValueError(f"Compute kernel not supported {config.compute_kernel}")

@@ -411,6 +421,33 @@ def purge(self) -> None:
emb_module.purge()


class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
"""
Update modules for Sequence embeddings (i.e. Embeddings). Writes the weights of the input `KeyedJaggedTensor` into the grouped embedding tables.
"""

def __init__(
self,
grouped_embeddings_lookup: GroupedEmbeddingsLookup,
) -> None:
super().__init__()
self._emb_modules: List[BaseEmbedding] = []
self._feature_splits: List[int] = []
for emb_module in grouped_embeddings_lookup._emb_modules:
emb_module = cast(BaseBatchedEmbedding[torch.Tensor], emb_module)
if emb_module.config.enable_embedding_update:
self._feature_splits.append(emb_module.config.num_features())
self._emb_modules.append(emb_module)

def forward(self, embeddings: KeyedJaggedTensor) -> None:
features_by_group = embeddings.split(
self._feature_splits,
)
for emb_module, features in zip(self._emb_modules, features_by_group):
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.update(features)
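The split above is key-wise: `self._feature_splits` holds the number of features owned by each updatable module. A small illustration of that semantics (in the update path the KJT would also carry per-value weights holding the embedding rows); the feature names and values are illustrative:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Three features destined for two grouped modules (2 features + 1 feature).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([10, 11, 12, 13, 14, 15]),
    lengths=torch.tensor([2, 1, 1, 1, 1, 0]),  # batch size 2 per feature
)
group_a, group_b = kjt.split([2, 1])
assert group_a.keys() == ["f1", "f2"]
assert group_b.keys() == ["f3"]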


class CommOpGradientScaling(torch.autograd.Function):
@staticmethod
# pyre-ignore