diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 18fde7d87..ba0d522d1 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -565,6 +565,7 @@ def _group_tables_per_rank(
             ),
             _prefetch_and_cached(table),
             table.use_virtual_table if is_inference else None,
+            table.enable_embedding_update,
         )
         # micromanage the order of we traverse the groups to ensure backwards compatibility
         if grouping_key not in groups:
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
             _,
             _,
             use_virtual_table,
+            enable_embedding_update,
         ) = grouping_key
         grouped_tables = groups[grouping_key]
         # remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
                 compute_kernel=compute_kernel_type,
                 embedding_tables=grouped_tables,
                 fused_params=per_tbe_fused_params,
+                enable_embedding_update=enable_embedding_update,
             )
         )
     return grouped_embedding_configs
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index b346222fb..f461d222a 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -251,6 +251,7 @@ class GroupedEmbeddingConfig:
     compute_kernel: EmbeddingComputeKernel
     embedding_tables: List[ShardedEmbeddingTable]
     fused_params: Optional[Dict[str, Any]] = None
+    enable_embedding_update: bool = False
 
     def feature_hash_sizes(self) -> List[int]:
         feature_hash_sizes = []
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index 2eafebe49..cb16822c1 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -223,6 +223,7 @@ def _shard(
                         total_num_buckets=info.embedding_config.total_num_buckets,
                         use_virtual_table=info.embedding_config.use_virtual_table,
                         virtual_table_eviction_policy=info.embedding_config.virtual_table_eviction_policy,
+                        enable_embedding_update=info.embedding_config.enable_embedding_update,
                     )
                 )
         return tables_per_rank
@@ -278,6 +279,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
             feature_hash_sizes.extend(group_config.feature_hash_sizes())
         return feature_hash_sizes
 
+    def _get_num_writable_features(self) -> int:
+        return sum(
+            group_config.num_features()
+            for group_config in self._grouped_embedding_configs
+            if group_config.enable_embedding_update
+        )
+
+    def _get_writable_feature_hash_sizes(self) -> List[int]:
+        feature_hash_sizes: List[int] = []
+        for group_config in self._grouped_embedding_configs:
+            if group_config.enable_embedding_update:
+                feature_hash_sizes.extend(group_config.feature_hash_sizes())
+        return feature_hash_sizes
+
 
 class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
     """
diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py
index c1bd6325b..ebb41a6fe 100644
--- a/torchrec/modules/embedding_configs.py
+++ b/torchrec/modules/embedding_configs.py
@@ -370,6 +370,7 @@ class BaseEmbeddingConfig:
     total_num_buckets: Optional[int] = None
     use_virtual_table: bool = False
     virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
+    enable_embedding_update: bool = False
 
     def get_weight_init_max(self) -> float:
         if self.weight_init_max is None:
diff --git a/torchrec/schema/api_tests/test_embedding_config_schema.py b/torchrec/schema/api_tests/test_embedding_config_schema.py
index ee938ecf2..388132ced 100644
--- a/torchrec/schema/api_tests/test_embedding_config_schema.py
+++ b/torchrec/schema/api_tests/test_embedding_config_schema.py
@@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
     total_num_buckets: Optional[int] = None
     use_virtual_table: bool = False
     virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
+    enable_embedding_update: bool = False
     pooling: PoolingType = PoolingType.SUM
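
For context, a minimal usage sketch (not part of the diff above): since enable_embedding_update is added as an ordinary dataclass field on BaseEmbeddingConfig with a False default, any subclass such as EmbeddingConfig should accept it as a keyword argument. The table name, sizes, and feature names below are illustrative placeholders.

# Hypothetical usage sketch, assuming EmbeddingConfig still inherits from
# BaseEmbeddingConfig; the new flag defaults to False when omitted.
from torchrec.modules.embedding_configs import EmbeddingConfig

table_config = EmbeddingConfig(
    name="t1",
    num_embeddings=1_000_000,
    embedding_dim=64,
    feature_names=["f1"],
    enable_embedding_update=True,  # new field introduced by this diff
)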