12 changes: 9 additions & 3 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -50,6 +50,7 @@
Shard,
ShardedTensor,
ShardedTensorMetadata,
ShardingType,
ShardMetadata,
TensorProperties,
)
@@ -720,13 +721,16 @@ def __init__(
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
sharding_type: Optional[ShardingType] = None,
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
self._config = config
self._pg = pg

self._pooling: PoolingMode = pooling_type_to_pooling_mode(config.pooling)
self._pooling: PoolingMode = pooling_type_to_pooling_mode(
config.pooling, sharding_type # pyre-ignore[6]
)

self._local_rows: List[int] = []
self._weight_init_mins: List[float] = []
@@ -859,8 +863,9 @@ def __init__(
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
sharding_type: Optional[ShardingType] = None,
) -> None:
super().__init__(config, pg, device)
super().__init__(config, pg, device, sharding_type)

managed: List[EmbeddingLocation] = []
compute_devices: List[ComputeDevice] = []
@@ -962,8 +967,9 @@ def __init__(
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
sharding_type: Optional[ShardingType] = None,
) -> None:
super().__init__(config, pg, device)
super().__init__(config, pg, device, sharding_type)

weights_precision = data_type_to_sparse_type(config.data_type)
fused_params = config.fused_params or {}
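Why the batched kernels now take `sharding_type`: under row-wise (RW) or table-row-wise (TWRW) sharding each rank pools over only a slice of a table's rows, so MEAN pooling cannot be completed locally; the kernel pools with SUM and the division by the true lengths is deferred to a callback (added in embeddingbag.py below). A minimal sketch of that dispatch, using hypothetical stand-in enums rather than the real `pooling_type_to_pooling_mode`:

```python
# Illustrative sketch only -- not the torchrec implementation.
# PoolingType / PoolingMode / ShardingType below are local stand-ins for the
# enums imported in this diff.
from enum import Enum
from typing import Optional


class PoolingType(Enum):
    SUM = "SUM"
    MEAN = "MEAN"


class PoolingMode(Enum):
    SUM = 0
    MEAN = 1


class ShardingType(Enum):
    TABLE_WISE = "table_wise"
    ROW_WISE = "row_wise"
    TABLE_ROW_WISE = "table_row_wise"


def pooling_type_to_pooling_mode_sketch(
    pooling_type: PoolingType, sharding_type: Optional[ShardingType] = None
) -> PoolingMode:
    # Rows of a RW/TWRW-sharded table are split across ranks, so a local MEAN
    # would divide by the wrong count; pool with SUM and let the mean-pooling
    # callback divide by the per-sample lengths after the reduce-scatter.
    if pooling_type == PoolingType.MEAN and sharding_type in (
        ShardingType.ROW_WISE,
        ShardingType.TABLE_ROW_WISE,
    ):
        return PoolingMode.SUM
    return PoolingMode.MEAN if pooling_type == PoolingType.MEAN else PoolingMode.SUM
```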
8 changes: 6 additions & 2 deletions torchrec/distributed/embedding_lookup.py
@@ -52,7 +52,7 @@
QuantBatchedEmbedding,
QuantBatchedEmbeddingBag,
)
from torchrec.distributed.types import ShardedTensor
from torchrec.distributed.types import ShardedTensor, ShardingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

logger: logging.Logger = logging.getLogger(__name__)
@@ -344,23 +344,27 @@ def __init__(
pg: Optional[dist.ProcessGroup] = None,
feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
scale_weight_gradients: bool = True,
sharding_type: Optional[ShardingType] = None,
) -> None:
# TODO rename to _create_embedding_kernel
def _create_lookup(
config: GroupedEmbeddingConfig,
device: Optional[torch.device] = None,
sharding_type: Optional[ShardingType] = None,
) -> BaseEmbedding:
if config.compute_kernel == EmbeddingComputeKernel.DENSE:
return BatchedDenseEmbeddingBag(
config=config,
pg=pg,
device=device,
sharding_type=sharding_type,
)
elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
return BatchedFusedEmbeddingBag(
config=config,
pg=pg,
device=device,
sharding_type=sharding_type,
)
else:
raise ValueError(
@@ -370,7 +374,7 @@ def _create_lookup(
super().__init__()
self._emb_modules: nn.ModuleList = nn.ModuleList()
for config in grouped_configs:
self._emb_modules.append(_create_lookup(config, device))
self._emb_modules.append(_create_lookup(config, device, sharding_type))

self._feature_splits: List[int] = []
for config in grouped_configs:
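The lookup change only threads `sharding_type` through to the batched kernels above. The identity that makes the SUM-then-divide approach correct: summing the partial SUM-pooled values from each row-wise shard and then dividing by the per-sample pooling lengths equals MEAN pooling over the unsharded table. A small numeric check in plain PyTorch (illustrative only, not torchrec code):

```python
import torch

torch.manual_seed(0)
table = torch.randn(6, 4)          # 6 rows, embedding dim 4
ids = torch.tensor([0, 2, 5, 1])   # one sample pooling 4 ids
lengths = torch.tensor([4.0])      # pooling length of that sample

# Row-wise sharding across two "ranks": rows 0-2 on rank A, rows 3-5 on rank B.
shard_a, shard_b = table[:3], table[3:]
partial_a = shard_a[ids[ids < 3]].sum(dim=0)        # local SUM pooling on rank A
partial_b = shard_b[ids[ids >= 3] - 3].sum(dim=0)   # local SUM pooling on rank B

# Combine the partial sums (stand-in for the cross-rank reduction), then apply
# the mean-pooling callback's division by lengths.
mean_via_callback = (partial_a + partial_b) / lengths

assert torch.allclose(mean_via_callback, table[ids].mean(dim=0))
```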
189 changes: 177 additions & 12 deletions torchrec/distributed/embeddingbag.py
@@ -8,10 +8,11 @@
# pyre-strict

import copy
from collections import OrderedDict
from collections import defaultdict, OrderedDict
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
@@ -27,6 +28,7 @@
import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import nn, Tensor
from torch.autograd.profiler import record_function
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_sharding import (
@@ -79,7 +81,7 @@
)
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -378,6 +380,7 @@ class EmbeddingBagCollectionContext(Multistreamable):
)
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None
variable_batch_per_feature: bool = False
mean_pooling_callback: Optional[Callable[[KeyedTensor], KeyedTensor]] = None

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for ctx in self.sharding_contexts:
@@ -415,13 +418,22 @@ def __init__(
self._embedding_bag_configs: List[EmbeddingBagConfig] = (
module.embedding_bag_configs()
)
self._table_names: List[str] = [
config.name for config in self._embedding_bag_configs
]

self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {
config.name: config for config in self._embedding_bag_configs
}
self._table_names: List[str] = []
self._pooling_type_to_rs_features: Dict[str, List[str]] = defaultdict(list)
self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {}

for config in self._embedding_bag_configs:
self._table_names.append(config.name)
self._table_name_to_config[config.name] = config

if table_name_to_parameter_sharding[config.name].sharding_type in [
ShardingType.TABLE_ROW_WISE.value,
ShardingType.ROW_WISE.value,
]:
self._pooling_type_to_rs_features[config.pooling.value].extend(
config.feature_names
)

self.module_sharding_plan: EmbeddingModuleShardingPlan = cast(
EmbeddingModuleShardingPlan,
@@ -472,6 +484,16 @@ def __init__(
self._uncombined_embedding_names: List[str] = []
self._uncombined_embedding_dims: List[int] = []
self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
# to support mean pooling callback hook
self._has_mean_pooling_callback: bool = (
True
if PoolingType.MEAN.value in self._pooling_type_to_rs_features
else False
)
self._dim_per_key: Optional[torch.Tensor] = None
self._kjt_key_indices: Dict[str, int] = {}
self._kjt_inverse_order: Optional[torch.Tensor] = None
self._kt_key_ordering: Optional[torch.Tensor] = None
# to support the FP16 hook
self._create_output_dist()

@@ -720,6 +742,38 @@ def _create_input_dist(
persistent=False,
)

def _init_mean_pooling_callback(
self,
input_feature_names: List[str],
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
) -> None:
# account for shared features
feature_names: List[str] = [
feature_name
for sharding in self._sharding_type_to_sharding.values()
for feature_name in sharding.feature_names()
]

for i, key in enumerate(feature_names):
if key not in self._kjt_key_indices:  # index of first occurrence
self._kjt_key_indices[key] = i

keyed_tensor_ordering = []
for key in self._embedding_names:
if "@" in key:
key = key.split("@")[0]
keyed_tensor_ordering.append(self._kjt_key_indices[key])
self._kt_key_ordering = torch.tensor(keyed_tensor_ordering, device=self._device)

if inverse_indices:
key_to_inverse_index = {
name: i for i, name in enumerate(inverse_indices[0])
}
self._kjt_inverse_order = torch.tensor(
[key_to_inverse_index[key] for key in feature_names],
device=self._device,
)

def _create_lookups(
self,
) -> None:
@@ -737,6 +791,7 @@ def _create_output_dist(self) -> None:
)
self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
embedding_shard_offsets: List[int] = [
meta.shard_offsets[1] if meta is not None else 0
for meta in embedding_shard_metadata
@@ -789,12 +844,31 @@ def input_dist(
self._has_uninitialized_input_dist = False
if ctx.variable_batch_per_feature:
self._create_inverse_indices_permute_indices(ctx.inverse_indices)
if self._has_mean_pooling_callback:
self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
with torch.no_grad():
if self._has_features_permute:
features = features.permute(
self._features_order,
self._features_order_tensor,
)
if self._has_mean_pooling_callback:
ctx.mean_pooling_callback = _create_mean_pooling_callback(
lengths=features.lengths(),
stride=features.stride(),
keys=features.keys(),
pooling_type_to_rs_features=self._pooling_type_to_rs_features,
stride_per_key=features.stride_per_key(),
dim_per_key=self._dim_per_key, # pyre-ignore[6]
embedding_names=self._embedding_names,
embedding_dims=self._embedding_dims,
variable_batch_per_feature=ctx.variable_batch_per_feature,
kjt_inverse_order=self._kjt_inverse_order, # pyre-ignore[6]
kjt_key_indices=self._kjt_key_indices,
kt_key_ordering=self._kt_key_ordering, # pyre-ignore[6]
inverse_indices=ctx.inverse_indices,
)

features_by_shards = features.split(
self._feature_splits,
)
@@ -840,7 +914,7 @@ def output_dist(
assert (
ctx.inverse_indices is not None
), "inverse indices must be provided from KJT if using variable batch size per feature."
return VariableBatchEmbeddingBagCollectionAwaitable(
awaitable = VariableBatchEmbeddingBagCollectionAwaitable(
awaitables=awaitables,
inverse_indices=ctx.inverse_indices,
inverse_indices_permute_indices=self._inverse_indices_permute_indices,
@@ -851,12 +925,18 @@
permute_op=self._permute_op,
)
else:
return EmbeddingBagCollectionAwaitable(
awaitable = EmbeddingBagCollectionAwaitable(
awaitables=awaitables,
embedding_dims=self._embedding_dims,
embedding_names=self._embedding_names,
)

# register callback if there are features that need mean pooling
if self._has_mean_pooling_callback:
awaitable.callbacks.append(ctx.mean_pooling_callback)

return awaitable

def compute_and_output_dist(
self, ctx: EmbeddingBagCollectionContext, input: KJTList
) -> LazyAwaitable[KeyedTensor]:
@@ -879,7 +959,7 @@ def compute_and_output_dist(
assert (
ctx.inverse_indices is not None
), "inverse indices must be provided from KJT if using variable batch size per feature."
return VariableBatchEmbeddingBagCollectionAwaitable(
awaitable = VariableBatchEmbeddingBagCollectionAwaitable(
awaitables=awaitables,
inverse_indices=ctx.inverse_indices,
inverse_indices_permute_indices=self._inverse_indices_permute_indices,
@@ -890,12 +970,18 @@
permute_op=self._permute_op,
)
else:
return EmbeddingBagCollectionAwaitable(
awaitable = EmbeddingBagCollectionAwaitable(
awaitables=awaitables,
embedding_dims=self._embedding_dims,
embedding_names=self._embedding_names,
)

# register callback if there are features that need mean pooling
if self._has_mean_pooling_callback:
awaitable.callbacks.append(ctx.mean_pooling_callback)

return awaitable

@property
def fused_optimizer(self) -> KeyedOptimizer:
return self._optim
@@ -1166,3 +1252,82 @@ def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Parameter]:
@property
def module_type(self) -> Type[nn.EmbeddingBag]:
return nn.EmbeddingBag


def _create_mean_pooling_callback(
lengths: torch.Tensor,
keys: List[str],
stride: int,
stride_per_key: List[int],
dim_per_key: torch.Tensor,
pooling_type_to_rs_features: Dict[str, List[str]],
embedding_names: List[str],
embedding_dims: List[int],
variable_batch_per_feature: bool,
kjt_inverse_order: torch.Tensor,
kjt_key_indices: Dict[str, int],
kt_key_ordering: torch.Tensor,
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
) -> Callable[[KeyedTensor], KeyedTensor]:
with record_function("## ebc create mean pooling callback ##"):
batch_size = (
inverse_indices[1].size(dim=1) if variable_batch_per_feature else stride # pyre-ignore[16]
)

if variable_batch_per_feature:
device = inverse_indices[1].device
inverse_indices_t = inverse_indices[1]
if len(keys) != len(inverse_indices[0]):
inverse_indices_t = torch.index_select(
inverse_indices[1], 0, kjt_inverse_order
)
offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[
:-1
].unsqueeze(-1)
indices = (inverse_indices_t + offsets).flatten()
lengths = torch.index_select(input=lengths, dim=0, index=indices)

# only convert the sum pooling features to be 1 lengths
for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
feature_index = kjt_key_indices[feature]
feature_index = feature_index * batch_size
lengths[feature_index : feature_index + batch_size] = 1

if len(embedding_names) != len(keys):
lengths = torch.index_select(
lengths.reshape(-1, batch_size),
0,
kt_key_ordering,
).reshape(-1)

# transpose to align features with keyed tensor dim_per_key
lengths = lengths.reshape(-1, batch_size).T # [batch_size, num_features]
output_size = sum(embedding_dims)

divisor = torch.repeat_interleave(
input=lengths,
repeats=dim_per_key,
dim=1,
output_size=output_size,
)
eps = 1e-6  # used to safeguard against division by zero
divisor = divisor + eps

# pyre-ignore[53]
def _apply_mean_pooling(keyed_tensor: KeyedTensor) -> KeyedTensor:
"""
Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes.
This function is applied as a callback to the awaitable
"""
with record_function("## ebc apply mean pooling ##"):
mean_pooled_values = (
keyed_tensor.values() / divisor
) # [batch size, num_features * embedding dim]
return KeyedTensor(
keys=keyed_tensor.keys(),
values=mean_pooled_values,
length_per_key=keyed_tensor.length_per_key(),
key_dim=1,
)

return _apply_mean_pooling
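A small worked example of how the divisor above is assembled; the shapes and lengths are hypothetical, chosen only to show the `reshape` + `repeat_interleave` step that aligns per-feature lengths with the flattened `KeyedTensor` values:

```python
import torch

batch_size = 2
# KJT lengths are grouped by key: feature "f1" (samples 0, 1) then "f2" (samples 0, 1).
lengths = torch.tensor([3.0, 1.0, 2.0, 4.0])
dim_per_key = torch.tensor([2, 3])  # f1 has embedding dim 2, f2 has dim 3

lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features] = [[3, 2], [1, 4]]
divisor = torch.repeat_interleave(
    input=lengths,
    repeats=dim_per_key,
    dim=1,
    output_size=int(dim_per_key.sum()),
)
print(divisor)
# tensor([[3., 3., 2., 2., 2.],
#         [1., 1., 4., 4., 4.]])
# Dividing the [batch_size, sum(dims)] KeyedTensor values by this tensor turns
# the SUM-pooled outputs of the MEAN-pooled features into true means (features
# that really use SUM pooling have their lengths forced to 1 above, so they
# pass through unchanged).
```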