Skip to content

Commit 4627d30

Browse files
kausvfacebook-github-bot
authored andcommitted
Add configs for write dist (#3346)
Summary: Rollback Plan: Differential Revision: D81366596
1 parent a07bc63 commit 4627d30

File tree

5 files changed

+24
-0
lines changed

5 files changed

+24
-0
lines changed

torchrec/distributed/embedding_sharding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def _group_tables_per_rank(
565565
),
566566
_prefetch_and_cached(table),
567567
table.use_virtual_table if is_inference else None,
568+
table.enable_embedding_update,
568569
)
569570
# micromanage the order of we traverse the groups to ensure backwards compatibility
570571
if grouping_key not in groups:
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
581582
_,
582583
_,
583584
use_virtual_table,
585+
enable_embedding_update,
584586
) = grouping_key
585587
grouped_tables = groups[grouping_key]
586588
# remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
602604
compute_kernel=compute_kernel_type,
603605
embedding_tables=grouped_tables,
604606
fused_params=per_tbe_fused_params,
607+
enable_embedding_update=enable_embedding_update,
605608
)
606609
)
607610
return grouped_embedding_configs

torchrec/distributed/embedding_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ class GroupedEmbeddingConfig:
251251
compute_kernel: EmbeddingComputeKernel
252252
embedding_tables: List[ShardedEmbeddingTable]
253253
fused_params: Optional[Dict[str, Any]] = None
254+
enable_embedding_update: bool = False
254255

255256
def feature_hash_sizes(self) -> List[int]:
256257
feature_hash_sizes = []

torchrec/distributed/sharding/rw_sharding.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def _shard(
220220
total_num_buckets=info.embedding_config.total_num_buckets,
221221
use_virtual_table=info.embedding_config.use_virtual_table,
222222
virtual_table_eviction_policy=info.embedding_config.virtual_table_eviction_policy,
223+
enable_embedding_update=info.embedding_config.enable_embedding_update,
223224
)
224225
)
225226
return tables_per_rank
@@ -275,6 +276,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
275276
feature_hash_sizes.extend(group_config.feature_hash_sizes())
276277
return feature_hash_sizes
277278

279+
def _get_num_writable_features(self) -> int:
280+
return sum(
281+
group_config.num_features()
282+
for group_config in self._grouped_embedding_configs
283+
if group_config.enable_embedding_update
284+
)
285+
286+
def _get_writable_feature_hash_sizes(self) -> List[int]:
287+
feature_hash_sizes: List[int] = []
288+
for group_config in self._grouped_embedding_configs:
289+
if group_config.enable_embedding_update:
290+
feature_hash_sizes.extend(group_config.feature_hash_sizes())
291+
return feature_hash_sizes
292+
278293

279294
class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
280295
"""

torchrec/modules/embedding_configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ class BaseEmbeddingConfig:
354354
total_num_buckets: Optional[int] = None
355355
use_virtual_table: bool = False
356356
virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
357+
enable_embedding_update: bool = False
357358

358359
def get_weight_init_max(self) -> float:
359360
if self.weight_init_max is None:

torchrec/schema/api_tests/test_embedding_config_schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
4343
total_num_buckets: Optional[int] = None
4444
use_virtual_table: bool = False
4545
virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
46+
enable_embedding_update: bool = False
4647
pooling: PoolingType = PoolingType.SUM
4748

4849

@@ -66,6 +67,9 @@ class StableEmbeddingConfig:
6667

6768
class TestEmbeddingConfigSchema(unittest.TestCase):
6869
def test_embedding_bag_config(self) -> None:
70+
import fbvscode
71+
72+
fbvscode.set_trace()
6973
self.assertTrue(
7074
is_signature_compatible(
7175
inspect.signature(StableEmbeddingBagConfig.__init__),

0 commit comments

Comments
 (0)