Set cache_precision = weights_precision in TBE if it is not explicitly set (#3370)

Summary:
Pull Request resolved: #3370

X-link: facebookresearch/FBGEMM#461

For historical reasons, we have decoupled cache_precision and weights_precision. This was to allow lower-precision embedding weights while the hot rows held in the cache use higher precision, but that capability is not used.

This decoupling has been a source of unintentional mismatches between cache_precision and weights_precision in typical EMO use cases, where we do want cache_precision == weights_precision. We had enforced the equality in torchrec (see [this](https://www.internalfb.com/code/fbsource/[3868325cdafd]/fbcode/torchrec/distributed/batched_embedding_kernel.py?lines=962-963%2C1446-1447)), but as new stacks are enabled, that enforcement could be overwritten.

Here we enforce cache_precision = weights_precision if cache_precision is not explicitly set.
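In other words, the cache precision now resolves to the weights precision unless the caller overrides it; an explicitly passed value still wins. A minimal sketch of that resolution rule (the helper name resolve_cache_precision is illustrative only, not part of the TBE API; SparseType is assumed to come from fbgemm_gpu.split_embedding_configs):

from typing import Optional

from fbgemm_gpu.split_embedding_configs import SparseType


def resolve_cache_precision(
    weights_precision: SparseType,
    cache_precision: Optional[SparseType] = None,
) -> SparseType:
    # Fall back to the weights precision when no cache precision is given.
    return weights_precision if cache_precision is None else cache_precision


# FP16 weights and no explicit cache precision -> FP16 cache.
assert resolve_cache_precision(SparseType.FP16) == SparseType.FP16
# An explicitly requested cache precision is still honored.
assert resolve_cache_precision(SparseType.FP16, SparseType.FP32) == SparseType.FP32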

Reviewed By: sryap, zhangruiskyline

Differential Revision: D65865527

fbshipit-source-id: a79e6aad3b30c46c80f607f406930589d089f37a
ehsanardestani authored and facebook-github-bot committed Nov 14, 2024
1 parent 6dd2d31 commit 10ae4f8
Showing 1 changed file with 20 additions and 11 deletions.
@@ -580,7 +580,7 @@ def __init__(  # noqa C901
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
         cache_reserved_memory: float = 0.0,
-        cache_precision: SparseType = SparseType.FP32,
+        cache_precision: Optional[SparseType] = None,
         weights_precision: SparseType = SparseType.FP32,
         output_dtype: SparseType = SparseType.FP32,
         enforce_hbm: bool = False,
@@ -619,6 +619,7 @@ def __init__(  # noqa C901
         uvm_host_mapped: bool = False,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
+
         self.uuid = str(uuid.uuid4())
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
         self.pooling_mode = pooling_mode
@@ -627,6 +628,9 @@ def __init__(  # noqa C901
             os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
         )
         self.weights_precision = weights_precision
+        cache_precision = (
+            weights_precision if cache_precision is None else cache_precision
+        )
         self.output_dtype: int = output_dtype.as_int()
         assert (
             not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
@@ -1175,20 +1179,13 @@ def __init__(  # noqa C901
             ),
         )
 
-        if cache_precision == SparseType.FP32:
-            cache_embedding_dtype = torch.float32
-        elif cache_precision == SparseType.FP16:
-            cache_embedding_dtype = torch.float16
-        else:
-            raise AssertionError(f"cache_precision {cache_precision} not supported!")
-
         self._apply_cache_state(
             cache_state,
             cache_algorithm,
             cache_load_factor,
             cache_sets,
             cache_reserved_memory,
-            dtype=cache_embedding_dtype,
+            cache_precision,
         )
 
         self.log(f"Contents: {table_names}")
@@ -2643,7 +2640,7 @@ def _apply_cache_state(
         cache_load_factor: float,
         cache_sets: int,
         cache_reserved_memory: float,
-        dtype: torch.dtype,
+        cache_precision: SparseType,
    ) -> None:
        self.cache_algorithm = cache_algorithm
        self.timestep = 1
@@ -2663,6 +2660,17 @@ def _apply_cache_state(
 
         self._init_uvm_cache_stats()
 
+        if cache_precision == SparseType.FP32:
+            dtype = torch.float32
+        elif cache_precision == SparseType.FP16:
+            dtype = torch.float16
+        else:
+            dtype = torch.float32  # not relevant, but setting it to keep linter happy
+            if not self.use_cpu > 0:
+                raise AssertionError(
+                    f"cache_precision {cache_precision} not supported!"
+                )
+
         # NOTE: no cache for CPU mode!
         if cache_state.total_cache_hash_size == 0 or self.use_cpu:
             self.register_buffer(
Expand Down Expand Up @@ -2740,7 +2748,8 @@ def _apply_cache_state(
f"{cache_algorithm}, {cache_sets} sets, "
f"load_factor: {cache_load_factor : .3f}, "
f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
f"cache_precision: {dtype}"
f"cache_precision: {dtype}, "
f"weights_precision: {self.weights_precision}"
)

self.total_cache_hash_size = cache_state.total_cache_hash_size
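For callers, the practical effect is that a construction which only sets weights_precision now gets a matching cache precision by default. A hedged usage sketch (the embedding_specs values are made up for illustration, and the import paths may differ slightly across fbgemm_gpu versions):

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Hypothetical two-table setup: (num_embeddings, embedding_dim, location, device).
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (10_000, 128, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
        (20_000, 64, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
    ],
    weights_precision=SparseType.FP16,
    # cache_precision is left unset, so with this change the UVM cache
    # also uses FP16 instead of silently defaulting to FP32.
)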
