40 changes: 39 additions & 1 deletion torchrec/distributed/batched_embedding_kernel.py
@@ -341,6 +341,8 @@ def _populate_zero_collision_tbe_params(
tbe_params.pop("kvzch_eviction_trigger_mode")
else:
eviction_trigger_mode = 2 # 2 means mem_util based eviction

fs_eviction_enabled: bool = False
for i, table in enumerate(config.embedding_tables):
policy_t = table.virtual_table_eviction_policy
if policy_t is not None:
@@ -369,6 +371,7 @@ def _populate_zero_collision_tbe_params(
raise ValueError(
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}"
)
fs_eviction_enabled = True
elif isinstance(policy_t, TimestampBasedEvictionPolicy):
training_id_eviction_trigger_count[i] = (
policy_t.training_id_eviction_trigger_count
@@ -440,6 +443,7 @@ def _populate_zero_collision_tbe_params(
backend_return_whole_row=(backend_type == BackendType.DRAM),
eviction_policy=eviction_policy,
embedding_cache_mode=embedding_cache_mode_,
feature_score_collection_enabled=fs_eviction_enabled,
)


@@ -2872,6 +2876,7 @@ def __init__(
_populate_zero_collision_tbe_params(
ssd_tbe_params, self._bucket_spec, config, backend_type
)
self._kv_zch_params: KVZCHParams = ssd_tbe_params["kv_zch_params"]
compute_kernel = config.embedding_tables[0].compute_kernel
embedding_location = compute_kernel_to_embedding_location(compute_kernel)

@@ -3155,7 +3160,40 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
self._split_weights_res = None
self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

return super().forward(features)
weights = features.weights_or_none()
per_sample_weights = None
score_weights = None
if weights is not None and weights.dtype == torch.float64:
fp32_weights = weights.view(torch.float32)
per_sample_weights = fp32_weights[:, 0]
score_weights = fp32_weights[:, 1]
elif weights is not None and weights.dtype == torch.float32:
if self._kv_zch_params.feature_score_collection_enabled:
score_weights = weights.view(-1)
else:
per_sample_weights = weights.view(-1)
if features.variable_stride_per_key() and isinstance(
self.emb_module,
(
SplitTableBatchedEmbeddingBagsCodegen,
DenseTableBatchedEmbeddingBagsCodegen,
SSDTableBatchedEmbeddingBags,
),
):
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
weights=score_weights,
per_sample_weights=per_sample_weights,
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
)
else:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
weights=score_weights,
per_sample_weights=per_sample_weights,
)
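
For reference, a minimal standalone sketch of the dtype-based dispatch above, assuming the packed weights arrive as an [N, 1] float64 column produced by bit-reinterpreting two float32 columns (the function and variable names here are illustrative, not part of the diff):

import torch

def split_packed_weights(weights, feature_score_collection_enabled):
    # float64 weights carry both columns packed bitwise; float32 weights carry
    # either feature scores only or per-sample pooling weights only.
    per_sample_weights, score_weights = None, None
    if weights is None:
        return per_sample_weights, score_weights
    if weights.dtype == torch.float64:
        fp32 = weights.view(torch.float32)   # [N, 1] float64 -> [N, 2] float32
        per_sample_weights = fp32[:, 0]
        score_weights = fp32[:, 1]
    elif feature_score_collection_enabled:
        score_weights = weights.view(-1)
    else:
        per_sample_weights = weights.view(-1)
    return per_sample_weights, score_weights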


class BatchedFusedEmbeddingBag(
53 changes: 33 additions & 20 deletions torchrec/distributed/embedding.py
@@ -46,6 +46,11 @@
ShardedEmbeddingModule,
ShardingType,
)

from torchrec.distributed.feature_score_utils import (
create_sharding_type_to_feature_score_mapping,
may_collect_feature_scores,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_SSD_TABLE,
FUSED_PARAM_SSD_TABLE_LIST,
@@ -90,7 +95,6 @@
from torchrec.modules.embedding_configs import (
EmbeddingConfig,
EmbeddingTableConfig,
FeatureScoreBasedEvictionPolicy,
PoolingType,
)
from torchrec.modules.embedding_modules import (
@@ -460,12 +464,12 @@ def __init__(
] = {
sharding_type: self.create_embedding_sharding(
sharding_type=sharding_type,
sharding_infos=embedding_confings,
sharding_infos=embedding_configs,
env=env,
device=device,
qcomm_codecs_registry=self.qcomm_codecs_registry,
)
for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items()
}

self.enable_embedding_update: bool = any(
@@ -487,16 +491,20 @@ def __init__(
self._has_uninitialized_input_dist: bool = True
logger.info(f"EC index dedup enabled: {self._use_index_dedup}.")

for config in self._embedding_configs:
virtual_table_eviction_policy = config.virtual_table_eviction_policy
if virtual_table_eviction_policy is not None and isinstance(
virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
):
self._enable_feature_score_weight_accumulation = True
break

self._enable_feature_score_weight_accumulation: bool = False
self._enabled_feature_score_auto_collection: bool = False
self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
(
self._enable_feature_score_weight_accumulation,
self._enabled_feature_score_auto_collection,
self._sharding_type_feature_score_mapping,
) = create_sharding_type_to_feature_score_mapping(
self._embedding_configs, sharding_type_to_sharding_infos
)
logger.info(
f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
)

# Get all fused optimizers and combine them.
@@ -1357,22 +1365,22 @@ def _dedup_indices(
source_weights.dtype == torch.float32
), "Only float32 weights are supported for feature score eviction weights."

acc_weights = torch.ops.fbgemm.jagged_acc_weights_and_counts(
source_weights.view(-1),
reverse_indices,
# Accumulate weights using scatter_add
acc_weights = torch.zeros(
unique_indices.numel(),
dtype=torch.float32,
device=source_weights.device,
)

# Use PyTorch's scatter_add to accumulate weights
acc_weights.scatter_add_(0, reverse_indices, source_weights)

dedup_features = KeyedJaggedTensor(
keys=input_feature.keys(),
lengths=lengths,
offsets=offsets,
values=unique_indices,
weights=(
acc_weights.view(torch.float64).view(-1)
if acc_weights is not None
else None
),
weights=(acc_weights.view(-1) if acc_weights is not None else None),
)

ctx.input_features.append(input_feature)
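
A small illustration of the scatter_add_ accumulation above, with toy values standing in for the deduplicated ids and reverse indices (all names here are hypothetical):

import torch

values = torch.tensor([5, 3, 5, 7, 3, 5])
weights = torch.tensor([1.0, 0.5, 2.0, 1.0, 0.5, 1.0], dtype=torch.float32)

# torch.unique returns the sorted unique ids plus, for each original position,
# the slot of its unique id -- the "reverse" index used as the scatter target.
unique_indices, reverse_indices = torch.unique(values, return_inverse=True)

acc_weights = torch.zeros(unique_indices.numel(), dtype=torch.float32)
acc_weights.scatter_add_(0, reverse_indices, weights)

# unique_indices -> tensor([3, 5, 7])
# acc_weights    -> tensor([1., 4., 1.])  (the three weights of id 5 summed into one slot)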
@@ -1491,6 +1499,11 @@ def input_dist(
self._features_order_tensor,
)
features_by_shards = features.split(self._feature_splits)
features_by_shards = may_collect_feature_scores(
features_by_shards,
self._enabled_feature_score_auto_collection,
self._sharding_type_feature_score_mapping,
)
if self._use_index_dedup:
features_by_shards = self._dedup_indices(ctx, features_by_shards)

39 changes: 35 additions & 4 deletions torchrec/distributed/embedding_lookup.py
@@ -66,6 +66,7 @@
QuantBatchedEmbeddingBag,
)
from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType
from torchrec.modules.embedding_configs import FeatureScoreBasedEvictionPolicy
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

logger: logging.Logger = logging.getLogger(__name__)
@@ -490,6 +491,23 @@ def __init__(
) -> None:
super().__init__()
self._emb_modules: nn.ModuleList = nn.ModuleList()
self._feature_score_auto_collections: List[bool] = []
for config in grouped_configs:
collection = False
for table in config.embedding_tables:
if table.use_virtual_table and isinstance(
table.virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
):
if (
table.virtual_table_eviction_policy.enable_auto_feature_score_collection
):
collection = True
self._feature_score_auto_collections.append(collection)

logger.info(
f"GroupedPooledEmbeddingsLookup: {self._feature_score_auto_collections=}"
)

for config in grouped_configs:
self._emb_modules.append(
self._create_embedding_kernel(config, device, pg, sharding_type)
@@ -663,8 +681,11 @@ def forward(
features_by_group = sparse_features.split(
self._feature_splits,
)
for config, emb_op, features in zip(
self.grouped_configs, self._emb_modules, features_by_group
for config, emb_op, features, fs_auto_collection in zip(
self.grouped_configs,
self._emb_modules,
features_by_group,
self._feature_score_auto_collections,
):
if (
config.has_feature_processor
@@ -674,9 +695,19 @@
features = self._feature_processor(features)

if config.is_weighted:
features._weights = CommOpGradientScaling.apply(
feature_weights = CommOpGradientScaling.apply(
features._weights, self._scale_gradient_factor
)
).float()

if fs_auto_collection and features.weights_or_none() is not None:
score_weights = features.weights().float()
assert (
feature_weights.numel() == score_weights.numel()
), f"feature_weights.numel() {feature_weights.numel()} != score_weights.numel() {score_weights.numel()}"
cat_weights = torch.cat(
[feature_weights, score_weights], dim=1
).view(torch.float64)
features._weights = cat_weights

embeddings.append(emb_op(features))
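
The packing side of the same float32/float64 reinterpretation, sketched under the assumption that both weight tensors are contiguous [N, 1] float32 columns (names are illustrative, not taken from the diff):

import torch

n = 4
feature_weights = torch.rand(n, 1, dtype=torch.float32)  # gradient-scaled pooling weights
score_weights = torch.rand(n, 1, dtype=torch.float32)    # per-id feature scores

# Concatenate the two float32 columns, then bit-reinterpret each adjacent pair
# of float32 values as one float64 value, so a single KJT weights tensor can
# carry both pooling weights and feature scores downstream.
packed = torch.cat([feature_weights, score_weights], dim=1).view(torch.float64)  # [n, 1]

# The consuming kernel recovers the two columns with the inverse reinterpretation.
assert torch.equal(packed.view(torch.float32)[:, 0], feature_weights.view(-1))
assert torch.equal(packed.view(torch.float32)[:, 1], score_weights.view(-1))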

27 changes: 27 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -51,6 +51,10 @@
KJTList,
ShardedEmbeddingModule,
)
from torchrec.distributed.feature_score_utils import (
create_sharding_type_to_feature_score_mapping,
may_collect_feature_scores,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_SSD_TABLE,
FUSED_PARAM_SSD_TABLE_LIST,
@@ -565,6 +569,24 @@ def __init__(
# forward pass flow control
self._has_uninitialized_input_dist: bool = True
self._has_features_permute: bool = True

self._enable_feature_score_weight_accumulation: bool = False
self._enabled_feature_score_auto_collection: bool = False
self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
(
self._enable_feature_score_weight_accumulation,
self._enabled_feature_score_auto_collection,
self._sharding_type_feature_score_mapping,
) = create_sharding_type_to_feature_score_mapping(
self._embedding_bag_configs, self.sharding_type_to_sharding_infos
)

logger.info(
f"EBC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
)

# Get all fused optimizers and combine them.
optims = []
for lookup in self._lookups:
@@ -1565,6 +1587,11 @@ def input_dist(
features_by_shards = features.split(
self._feature_splits,
)
features_by_shards = may_collect_feature_scores(
features_by_shards,
self._enabled_feature_score_auto_collection,
self._sharding_type_feature_score_mapping,
)
awaitables = []
for input_dist, features_by_shard, sharding_type in zip(
self._input_dists,