From 2007686265fea70311cdb2d22c2b2cc1263acaea Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Fri, 15 Nov 2024 11:35:00 +0000
Subject: [PATCH] 2024-11-15 nightly release
 (abbb5dcde4e94dca483c68f60b04237dead400f8)

---
 ...dding_split_host_pt2_autograd_template.cpp |  1 +
 ...t_table_batched_embeddings_ops_training.py | 31 ++++++++++++-------
 src/EmbeddingSpMDMAutovec.cc                  | 18 ++++++-----
 3 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index 5a7d73260..8b81f5f44 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -526,6 +526,7 @@ enum SSDTensor {
 class {{ autograd_func }} :
     public torch::autograd::Function<{{ autograd_func }}> {
  public:
+  static constexpr bool is_traceable = true;
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
       const Tensor& placeholder_autograd_tensor,
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index ba98b69c9..8eb3a3de5 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -580,7 +580,7 @@ def __init__(  # noqa C901
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
         cache_reserved_memory: float = 0.0,
-        cache_precision: SparseType = SparseType.FP32,
+        cache_precision: Optional[SparseType] = None,
         weights_precision: SparseType = SparseType.FP32,
         output_dtype: SparseType = SparseType.FP32,
         enforce_hbm: bool = False,
@@ -619,6 +619,7 @@ def __init__(  # noqa C901
         uvm_host_mapped: bool = False,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
+        self.uuid = str(uuid.uuid4())
 
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
         self.pooling_mode = pooling_mode
@@ -627,6 +628,9 @@ def __init__(  # noqa C901
             os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
         )
         self.weights_precision = weights_precision
+        cache_precision = (
+            weights_precision if cache_precision is None else cache_precision
+        )
         self.output_dtype: int = output_dtype.as_int()
         assert (
             not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
@@ -1175,20 +1179,13 @@ def __init__(  # noqa C901
             ),
         )
 
-        if cache_precision == SparseType.FP32:
-            cache_embedding_dtype = torch.float32
-        elif cache_precision == SparseType.FP16:
-            cache_embedding_dtype = torch.float16
-        else:
-            raise AssertionError(f"cache_precision {cache_precision} not supported!")
-
         self._apply_cache_state(
             cache_state,
             cache_algorithm,
             cache_load_factor,
             cache_sets,
             cache_reserved_memory,
-            dtype=cache_embedding_dtype,
+            cache_precision,
         )
 
         self.log(f"Contents: {table_names}")
@@ -2643,7 +2640,7 @@ def _apply_cache_state(
         cache_load_factor: float,
         cache_sets: int,
         cache_reserved_memory: float,
-        dtype: torch.dtype,
+        cache_precision: SparseType,
     ) -> None:
         self.cache_algorithm = cache_algorithm
         self.timestep = 1
@@ -2663,6 +2660,17 @@ def _apply_cache_state(
 
         self._init_uvm_cache_stats()
 
+        if cache_precision == SparseType.FP32:
+            dtype = torch.float32
+        elif cache_precision == SparseType.FP16:
+            dtype = torch.float16
+        else:
+            dtype = torch.float32  # not relevant, but setting it to keep linter happy
+            if not self.use_cpu > 0:
+                raise AssertionError(
+                    f"cache_precision {cache_precision} not supported!"
+                )
+
         # NOTE: no cache for CPU mode!
         if cache_state.total_cache_hash_size == 0 or self.use_cpu:
             self.register_buffer(
@@ -2740,7 +2748,8 @@ def _apply_cache_state(
             f"{cache_algorithm}, {cache_sets} sets, "
             f"load_factor: {cache_load_factor : .3f}, "
             f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
-            f"cache_precision: {dtype}"
+            f"cache_precision: {dtype}, "
+            f"weights_precision: {self.weights_precision}"
         )
 
         self.total_cache_hash_size = cache_state.total_cache_hash_size
diff --git a/src/EmbeddingSpMDMAutovec.cc b/src/EmbeddingSpMDMAutovec.cc
index 67a720e88..29489165d 100644
--- a/src/EmbeddingSpMDMAutovec.cc
+++ b/src/EmbeddingSpMDMAutovec.cc
@@ -24,6 +24,8 @@
 
 namespace fbgemm {
 
+static constexpr size_t LOCAL_STORAGE_SIZE = 512;
+
 template <typename OutType>
 static inline void fill_output(
     OutType* out,
@@ -99,10 +101,10 @@ bool EmbeddingSpMDM8Bit_autovec(
   const int64_t scale_bias_offset = scale_bias_last ? block_size : 0;
   const int64_t input_offset = scale_bias_last ? 0 : scale_bias_size;
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (block_size <= 256) {
+  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[block_size]);
@@ -363,10 +365,10 @@ bool EmbeddingSpMDMNBit_autovec(
   int64_t current = 0;
   const int64_t rounded_block_size = round_up(block_size, num_elem_per_byte);
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (rounded_block_size <= 256) {
+  if (static_cast<size_t>(rounded_block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[rounded_block_size]);
@@ -504,10 +506,10 @@ bool EmbeddingSpMDM_autovec(
     output_stride = block_size;
   }
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (block_size <= 256) {
+  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[block_size]);
@@ -862,10 +864,10 @@ bool EmbeddingSpMDMFP8_autovec(
     output_stride = block_size;
   }
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
  std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (block_size <= 256) {
+  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[block_size]);
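
Illustrative sketch (not part of the patch): the EmbeddingSpMDMAutovec.cc hunks above replace the hard-coded 256-float stack buffer with a shared LOCAL_STORAGE_SIZE constant of 512 and cast the signed block size to size_t before comparing against it. The self-contained C++ sketch below shows that stack-or-heap scratch-buffer selection pattern in isolation; process_block() and its use of the buffer are hypothetical stand-ins for the real embedding kernels.

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>

// Mirrors the constant introduced by the patch.
static constexpr size_t LOCAL_STORAGE_SIZE = 512;

// Hypothetical stand-in for the EmbeddingSpMDM*_autovec kernels.
void process_block(int64_t block_size) {
  std::array<float, LOCAL_STORAGE_SIZE> local_storage;  // stack scratch space
  std::unique_ptr<float[]> heap_storage;                // fallback allocation
  float* buf;
  // Cast to size_t so comparing a signed length against the unsigned
  // constant does not raise signed/unsigned-comparison warnings.
  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
    buf = local_storage.data();  // common case: block fits on the stack
  } else {
    heap_storage.reset(new float[block_size]);  // rare large-block fallback
    buf = heap_storage.get();
  }
  // ... accumulate embedding rows into buf, then write the pooled result ...
  (void)buf;
}

The point of the pattern is that typical block sizes avoid a per-call heap allocation entirely, while unusually large blocks still work correctly via the unique_ptr-owned buffer.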