Skip to content

Commit

Permalink
Add --pooling in TBE nbit_cpu benchmark (#2200)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2200

As title

Reviewed By: YazhiGao

Differential Revision: D51963691

fbshipit-source-id: 45604dc2a7e4a029bc6172da4ef39d7ab648dc34
  • Loading branch information
sryap authored and facebook-github-bot committed Dec 8, 2023
1 parent 90e81f5 commit 8724d89
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ def benchmark_cpu_requests(
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
@click.option("--fp8-exponent-bits", type=int, default=None)
@click.option("--fp8-exponent-bias", type=int, default=None)
@click.option("--pooling", type=str, default="sum")
def nbit_cpu( # noqa C901
alpha: float,
bag_size: int,
Expand All @@ -807,6 +808,7 @@ def nbit_cpu( # noqa C901
output_dtype: SparseType,
fp8_exponent_bits: Optional[int],
fp8_exponent_bias: Optional[int],
pooling: str,
) -> None:
np.random.seed(42)
torch.manual_seed(42)
Expand All @@ -825,11 +827,23 @@ def nbit_cpu( # noqa C901
else:
Ds = [D] * T

if pooling is None or pooling == "sum":
pooling = "sum"
pooling_mode = PoolingMode.SUM
do_pooling = True
elif pooling == "mean":
pooling_mode = PoolingMode.MEAN
do_pooling = True
else: # "none"
pooling_mode = PoolingMode.NONE
do_pooling = False

emb = IntNBitTableBatchedEmbeddingBagsCodegen(
[("", E, d, weights_precision, EmbeddingLocation.HOST) for d in Ds],
device="cpu",
index_remapping=[torch.arange(E) for _ in Ds] if index_remapping else None,
output_dtype=output_dtype,
pooling_mode=pooling_mode,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
).cpu()
Expand All @@ -839,9 +853,16 @@ def nbit_cpu( # noqa C901
nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
param_size_multiplier = weights_precision.bit_rate() / 8.0
output_size_multiplier = output_dtype.bit_rate() / 8.0
read_write_bytes = (
output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D
)
if do_pooling:
read_write_bytes = (
output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D
)
else:
read_write_bytes = (
output_size_multiplier * B * T * L * D
+ param_size_multiplier * B * T * L * D
)

logging.info(
f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, "
f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage
Expand Down

0 comments on commit 8724d89

Please sign in to comment.