From 10ae4f84b95692aa10a35760290501ddf177d2db Mon Sep 17 00:00:00 2001 From: "Ehsan K. Ardestani" Date: Thu, 14 Nov 2024 14:58:38 -0800 Subject: [PATCH] Set cache_precision = weights_precision in TBE if it is not explicitly set (#3370) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3370 X-link: https://github.com/facebookresearch/FBGEMM/pull/461 For historical reasons, we have decoupled cache_precision and weights_precision. This was to allow for lower precision embedding, while the hot ones in the cache are using higher precision. But it is not used. This decoupling has been source of unintentional difference in cache_precision and weights_precision for typical EMO usecases where we do want cache_precision == weights_precision. We had enforced this in torchrec (see [this](https://www.internalfb.com/code/fbsource/[3868325cdafd]/fbcode/torchrec/distributed/batched_embedding_kernel.py?lines=962-963%2C1446-1447)), but as new stacks are enabled, this could be overwritten. Here we enforce cache_precision = weights_precision if cache precision is not explicitly set. Reviewed By: sryap, zhangruiskyline Differential Revision: D65865527 fbshipit-source-id: a79e6aad3b30c46c80f607f406930589d089f37a --- ...t_table_batched_embeddings_ops_training.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index ba98b69c9..8eb3a3de5 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -580,7 +580,7 @@ def __init__( # noqa C901 cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, - cache_precision: SparseType = SparseType.FP32, + cache_precision: Optional[SparseType] = None, weights_precision: SparseType = SparseType.FP32, output_dtype: SparseType = SparseType.FP32, enforce_hbm: bool = False, @@ -619,6 +619,7 @@ def __init__( # noqa C901 uvm_host_mapped: bool = False, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() + self.uuid = str(uuid.uuid4()) self.logging_table_name: str = self.get_table_name_for_logging(table_names) self.pooling_mode = pooling_mode @@ -627,6 +628,9 @@ def __init__( # noqa C901 os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) ) self.weights_precision = weights_precision + cache_precision = ( + weights_precision if cache_precision is None else cache_precision + ) self.output_dtype: int = output_dtype.as_int() assert ( not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU @@ -1175,20 +1179,13 @@ def __init__( # noqa C901 ), ) - if cache_precision == SparseType.FP32: - cache_embedding_dtype = torch.float32 - elif cache_precision == SparseType.FP16: - cache_embedding_dtype = torch.float16 - else: - raise AssertionError(f"cache_precision {cache_precision} not supported!") - self._apply_cache_state( cache_state, cache_algorithm, cache_load_factor, cache_sets, cache_reserved_memory, - dtype=cache_embedding_dtype, + cache_precision, ) self.log(f"Contents: {table_names}") @@ -2643,7 +2640,7 @@ def _apply_cache_state( cache_load_factor: float, cache_sets: int, cache_reserved_memory: float, - dtype: torch.dtype, + cache_precision: SparseType, ) -> None: self.cache_algorithm = cache_algorithm self.timestep = 1 @@ -2663,6 +2660,17 @@ def _apply_cache_state( self._init_uvm_cache_stats() + if cache_precision == SparseType.FP32: + dtype = torch.float32 + elif cache_precision == SparseType.FP16: + dtype = torch.float16 + else: + dtype = torch.float32 # not relevant, but setting it to keep linter happy + if not self.use_cpu > 0: + raise AssertionError( + f"cache_precision {cache_precision} not supported!" + ) + # NOTE: no cache for CPU mode! if cache_state.total_cache_hash_size == 0 or self.use_cpu: self.register_buffer( @@ -2740,7 +2748,8 @@ def _apply_cache_state( f"{cache_algorithm}, {cache_sets} sets, " f"load_factor: {cache_load_factor : .3f}, " f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, " - f"cache_precision: {dtype}" + f"cache_precision: {dtype}, " + f"weights_precision: {self.weights_precision}" ) self.total_cache_hash_size = cache_state.total_cache_hash_size