File tree Expand file tree Collapse file tree 5 files changed +24
-0
lines changed Expand file tree Collapse file tree 5 files changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -565,6 +565,7 @@ def _group_tables_per_rank(
565
565
),
566
566
_prefetch_and_cached (table ),
567
567
table .use_virtual_table if is_inference else None ,
568
+ table .enable_embedding_update ,
568
569
)
569
570
# micromanage the order of we traverse the groups to ensure backwards compatibility
570
571
if grouping_key not in groups :
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
581
582
_ ,
582
583
_ ,
583
584
use_virtual_table ,
585
+ enable_embedding_update ,
584
586
) = grouping_key
585
587
grouped_tables = groups [grouping_key ]
586
588
# remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
602
604
compute_kernel = compute_kernel_type ,
603
605
embedding_tables = grouped_tables ,
604
606
fused_params = per_tbe_fused_params ,
607
+ enable_embedding_update = enable_embedding_update ,
605
608
)
606
609
)
607
610
return grouped_embedding_configs
Original file line number Diff line number Diff line change @@ -251,6 +251,7 @@ class GroupedEmbeddingConfig:
251
251
compute_kernel : EmbeddingComputeKernel
252
252
embedding_tables : List [ShardedEmbeddingTable ]
253
253
fused_params : Optional [Dict [str , Any ]] = None
254
+ enable_embedding_update : bool = False
254
255
255
256
def feature_hash_sizes (self ) -> List [int ]:
256
257
feature_hash_sizes = []
Original file line number Diff line number Diff line change @@ -220,6 +220,7 @@ def _shard(
220
220
total_num_buckets = info .embedding_config .total_num_buckets ,
221
221
use_virtual_table = info .embedding_config .use_virtual_table ,
222
222
virtual_table_eviction_policy = info .embedding_config .virtual_table_eviction_policy ,
223
+ enable_embedding_update = info .embedding_config .enable_embedding_update ,
223
224
)
224
225
)
225
226
return tables_per_rank
@@ -275,6 +276,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
275
276
feature_hash_sizes .extend (group_config .feature_hash_sizes ())
276
277
return feature_hash_sizes
277
278
279
+ def _get_num_writable_features (self ) -> int :
280
+ return sum (
281
+ group_config .num_features ()
282
+ for group_config in self ._grouped_embedding_configs
283
+ if group_config .enable_embedding_update
284
+ )
285
+
286
+ def _get_writable_feature_hash_sizes (self ) -> List [int ]:
287
+ feature_hash_sizes : List [int ] = []
288
+ for group_config in self ._grouped_embedding_configs :
289
+ if group_config .enable_embedding_update :
290
+ feature_hash_sizes .extend (group_config .feature_hash_sizes ())
291
+ return feature_hash_sizes
292
+
278
293
279
294
class RwSparseFeaturesDist (BaseSparseFeaturesDist [KeyedJaggedTensor ]):
280
295
"""
Original file line number Diff line number Diff line change @@ -354,6 +354,7 @@ class BaseEmbeddingConfig:
354
354
total_num_buckets : Optional [int ] = None
355
355
use_virtual_table : bool = False
356
356
virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
357
+ enable_embedding_update : bool = False
357
358
358
359
def get_weight_init_max (self ) -> float :
359
360
if self .weight_init_max is None :
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
43
43
total_num_buckets : Optional [int ] = None
44
44
use_virtual_table : bool = False
45
45
virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
46
+ enable_embedding_update : bool = False
46
47
pooling : PoolingType = PoolingType .SUM
47
48
48
49
@@ -66,6 +67,9 @@ class StableEmbeddingConfig:
66
67
67
68
class TestEmbeddingConfigSchema (unittest .TestCase ):
68
69
def test_embedding_bag_config (self ) -> None :
70
+ import fbvscode
71
+
72
+ fbvscode .set_trace ()
69
73
self .assertTrue (
70
74
is_signature_compatible (
71
75
inspect .signature (StableEmbeddingBagConfig .__init__ ),
You can’t perform that action at this time.
0 commit comments