3 changes: 3 additions & 0 deletions torchrec/distributed/embedding_sharding.py
@@ -565,6 +565,7 @@ def _group_tables_per_rank(
                 ),
                 _prefetch_and_cached(table),
                 table.use_virtual_table if is_inference else None,
+                table.enable_embedding_update,
             )
             # micromanage the order in which we traverse the groups to ensure backwards compatibility
             if grouping_key not in groups:
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
             _,
             _,
             use_virtual_table,
+            enable_embedding_update,
         ) = grouping_key
         grouped_tables = groups[grouping_key]
         # remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
                 compute_kernel=compute_kernel_type,
                 embedding_tables=grouped_tables,
                 fused_params=per_tbe_fused_params,
+                enable_embedding_update=enable_embedding_update,
             )
         )
     return grouped_embedding_configs
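Note: adding enable_embedding_update to the grouping key means tables that allow embedding updates are never fused into the same TBE group as read-only tables. A minimal standalone sketch of that grouping behavior, using a hypothetical Table stand-in rather than TorchRec's ShardedEmbeddingTable:

# Standalone sketch, not TorchRec code: tables are grouped by a key tuple,
# so differing in any key field -- including enable_embedding_update --
# places tables in separate groups (and hence separate TBEs).
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class Table:  # hypothetical stand-in for ShardedEmbeddingTable
    name: str
    data_type: str
    enable_embedding_update: bool

def group_tables(tables: List[Table]) -> Dict[Tuple[str, bool], List[Table]]:
    groups: Dict[Tuple[str, bool], List[Table]] = defaultdict(list)
    for t in tables:
        groups[(t.data_type, t.enable_embedding_update)].append(t)
    return groups

tables = [Table("t0", "fp16", False), Table("t1", "fp16", True)]
assert len(group_tables(tables)) == 2  # same dtype, still split by updatability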
1 change: 1 addition & 0 deletions torchrec/distributed/embedding_types.py
@@ -251,6 +251,7 @@ class GroupedEmbeddingConfig:
     compute_kernel: EmbeddingComputeKernel
     embedding_tables: List[ShardedEmbeddingTable]
     fused_params: Optional[Dict[str, Any]] = None
+    enable_embedding_update: bool = False

     def feature_hash_sizes(self) -> List[int]:
         feature_hash_sizes = []
15 changes: 15 additions & 0 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -223,6 +223,7 @@ def _shard(
                     total_num_buckets=info.embedding_config.total_num_buckets,
                     use_virtual_table=info.embedding_config.use_virtual_table,
                     virtual_table_eviction_policy=info.embedding_config.virtual_table_eviction_policy,
+                    enable_embedding_update=info.embedding_config.enable_embedding_update,
                 )
             )
         return tables_per_rank
@@ -278,6 +279,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
             feature_hash_sizes.extend(group_config.feature_hash_sizes())
         return feature_hash_sizes

+    def _get_num_writable_features(self) -> int:
+        return sum(
+            group_config.num_features()
+            for group_config in self._grouped_embedding_configs
+            if group_config.enable_embedding_update
+        )
+
+    def _get_writable_feature_hash_sizes(self) -> List[int]:
+        feature_hash_sizes: List[int] = []
+        for group_config in self._grouped_embedding_configs:
+            if group_config.enable_embedding_update:
+                feature_hash_sizes.extend(group_config.feature_hash_sizes())
+        return feature_hash_sizes


 class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
     """
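Note: the two new helpers mirror _get_feature_hash_sizes but consider only groups with enable_embedding_update set. A standalone sketch of that filtering, with GroupConfig as a hypothetical stand-in for GroupedEmbeddingConfig:

# Standalone sketch of the filtering performed by the two helpers above.
from dataclasses import dataclass, field
from typing import List

@dataclass
class GroupConfig:  # hypothetical stand-in for GroupedEmbeddingConfig
    hash_sizes: List[int] = field(default_factory=list)
    enable_embedding_update: bool = False

    def num_features(self) -> int:
        return len(self.hash_sizes)

    def feature_hash_sizes(self) -> List[int]:
        return list(self.hash_sizes)

groups = [
    GroupConfig([100, 200]),                           # read-only group: skipped
    GroupConfig([300], enable_embedding_update=True),  # writable group: counted
]
num_writable = sum(g.num_features() for g in groups if g.enable_embedding_update)
writable_sizes = [s for g in groups if g.enable_embedding_update for s in g.feature_hash_sizes()]
assert num_writable == 1 and writable_sizes == [300]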
1 change: 1 addition & 0 deletions torchrec/modules/embedding_configs.py
@@ -370,6 +370,7 @@ class BaseEmbeddingConfig:
     total_num_buckets: Optional[int] = None
     use_virtual_table: bool = False
     virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
+    enable_embedding_update: bool = False

     def get_weight_init_max(self) -> float:
         if self.weight_init_max is None:
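Note: since the field lives on BaseEmbeddingConfig with a False default, existing configs are unaffected and callers opt in per table. A hypothetical usage sketch (the table name and sizes are invented for illustration):

# Hypothetical usage sketch: the flag defaults to False, so existing configs
# keep their behavior; a table opts in explicitly.
from torchrec.modules.embedding_configs import EmbeddingConfig

updatable = EmbeddingConfig(
    name="t_example",
    embedding_dim=64,
    num_embeddings=1_000_000,
    feature_names=["f_example"],
    enable_embedding_update=True,  # new field added in this diff
)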
1 change: 1 addition & 0 deletions torchrec/schema/api_tests/test_embedding_config_schema.py
@@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
     total_num_buckets: Optional[int] = None
     use_virtual_table: bool = False
     virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
+    enable_embedding_update: bool = False
     pooling: PoolingType = PoolingType.SUM