Set cache_precision = weights_precision in TBE if it is not explicitly set

Summary:
X-link: facebookresearch/FBGEMM#461

For historical reasons, we have decoupled cache_precision and weights_precision. This was meant to allow lower-precision embedding weights while the hot rows held in the cache use higher precision, but that capability is not actually used.

This decoupling has been a source of unintentional mismatches between cache_precision and weights_precision in typical EMO use cases, where we do want cache_precision == weights_precision. We had enforced this in torchrec (see [this](https://www.internalfb.com/code/fbsource/[3868325cdafd]/fbcode/torchrec/distributed/batched_embedding_kernel.py?lines=962-963%2C1446-1447)), but as new stacks are enabled, that enforcement could be overridden.

Here we enforce cache_precision = weights_precision if cache_precision is not explicitly set.
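
As a minimal sketch, the new defaulting rule works as follows (resolve_cache_precision is a hypothetical standalone helper, not part of this commit):

# Hypothetical helper illustrating the defaulting rule: an explicitly
# passed cache_precision wins; otherwise the cache inherits the
# weights precision.
def resolve_cache_precision(weights_precision, cache_precision=None):
    return weights_precision if cache_precision is None else cache_precision

# Previously the default was a hard-coded SparseType.FP32, so FP16
# weights could silently pair with an FP32 cache.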

Differential Revision: D65865527
ehsanardestani authored and facebook-github-bot committed Nov 14, 2024
1 parent 2da2b7a commit 1442708
Showing 1 changed file with 20 additions and 11 deletions.
@@ -580,7 +580,7 @@ def __init__( # noqa C901
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
         cache_reserved_memory: float = 0.0,
-        cache_precision: SparseType = SparseType.FP32,
+        cache_precision: Optional[SparseType] = None,
         weights_precision: SparseType = SparseType.FP32,
         output_dtype: SparseType = SparseType.FP32,
         enforce_hbm: bool = False,
@@ -619,6 +619,7 @@ def __init__( # noqa C901
         uvm_host_mapped: bool = False,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
+
         self.uuid = str(uuid.uuid4())
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
         self.pooling_mode = pooling_mode
@@ -627,6 +628,9 @@ def __init__( # noqa C901
             os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
         )
         self.weights_precision = weights_precision
+        cache_precision = (
+            weights_precision if cache_precision is None else cache_precision
+        )
         self.output_dtype: int = output_dtype.as_int()
         assert (
             not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
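
For callers, the practical effect is that leaving cache_precision unset now produces a cache that matches the weights. A hedged usage sketch follows; the import paths and embedding_specs values are assumptions and may differ across fbgemm_gpu versions:

# Illustrative only: import paths and spec values are assumptions.
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (1_000_000, 128, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
    ],
    weights_precision=SparseType.FP16,
    # cache_precision left unset: with this change it defaults to
    # weights_precision (FP16) instead of the old hard-coded FP32.
)
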
@@ -1175,20 +1179,13 @@ def __init__( # noqa C901
             ),
         )
 
-        if cache_precision == SparseType.FP32:
-            cache_embedding_dtype = torch.float32
-        elif cache_precision == SparseType.FP16:
-            cache_embedding_dtype = torch.float16
-        else:
-            raise AssertionError(f"cache_precision {cache_precision} not supported!")
-
         self._apply_cache_state(
             cache_state,
             cache_algorithm,
             cache_load_factor,
             cache_sets,
             cache_reserved_memory,
-            dtype=cache_embedding_dtype,
+            cache_precision,
         )
 
         self.log(f"Contents: {table_names}")
@@ -2643,7 +2640,7 @@ def _apply_cache_state(
         cache_load_factor: float,
         cache_sets: int,
         cache_reserved_memory: float,
-        dtype: torch.dtype,
+        cache_precision: SparseType,
     ) -> None:
         self.cache_algorithm = cache_algorithm
         self.timestep = 1
@@ -2663,6 +2660,17 @@
 
         self._init_uvm_cache_stats()
 
+        if cache_precision == SparseType.FP32:
+            dtype = torch.float32
+        elif cache_precision == SparseType.FP16:
+            dtype = torch.float16
+        else:
+            dtype = torch.float32  # not relevant, but setting it to keep linter happy
+            if not self.use_cpu:
+                raise AssertionError(
+                    f"cache_precision {cache_precision} not supported!"
+                )
+
         # NOTE: no cache for CPU mode!
         if cache_state.total_cache_hash_size == 0 or self.use_cpu:
             self.register_buffer(
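
The precision-to-dtype mapping that moved into _apply_cache_state can be summarized by this standalone sketch (cache_dtype is a hypothetical helper, not part of the commit; the SparseType import path may vary):

import torch
from fbgemm_gpu.split_embedding_configs import SparseType  # path may vary

# Hypothetical helper mirroring the logic added above: only FP32 and FP16
# caches are supported, and an unsupported precision raises only on GPU,
# since the CPU path has no cache and never uses the dtype.
def cache_dtype(cache_precision: SparseType, use_cpu: bool) -> torch.dtype:
    if cache_precision == SparseType.FP32:
        return torch.float32
    if cache_precision == SparseType.FP16:
        return torch.float16
    if not use_cpu:
        raise AssertionError(f"cache_precision {cache_precision} not supported!")
    return torch.float32  # CPU mode: no cache exists, so the value is unused
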
@@ -2740,7 +2748,8 @@ def _apply_cache_state(
             f"{cache_algorithm}, {cache_sets} sets, "
             f"load_factor: {cache_load_factor : .3f}, "
             f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
-            f"cache_precision: {dtype}"
+            f"cache_precision: {dtype}, "
+            f"weights_precision: {self.weights_precision}"
         )
 
         self.total_cache_hash_size = cache_state.total_cache_hash_size