From 6eb379a96d55a2373c0a3ad4d7b46567d3b307d6 Mon Sep 17 00:00:00 2001
From: Paul Zhang
Date: Wed, 27 Nov 2024 12:19:18 -0800
Subject: [PATCH] Support sending lengths to TBE instead of just offsets

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/508

X-link: https://github.com/pytorch/torchrec/pull/2557

Here we modify the FBGEMM inference TBE so that the forward logic lives in
`_forward_impl`, with `forward` now a thin wrapper around it. This makes it
easy for subclasses of TBE to extend the forward pass without having to call
`super()`, which is not TorchScriptable. An example of a subclass using this
is D66515313.

Keeping subclasses TorchScriptable is vital for inference, as TBE is
generally not FX-traced through.

Reviewed By: sryap

Differential Revision: D64906767

fbshipit-source-id: 41bab272c2611fc97dece7f89fdbbd820671843c
---
 .../split_table_batched_embeddings_ops_inference.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 63a646dfbc..3ff2e62712 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -913,7 +913,7 @@ def _update_tablewise_cache_miss(
 
         self.table_wise_cache_miss[i] += miss_count
 
-    def forward(
+    def _forward_impl(
         self,
         indices: Tensor,
         offsets: Tensor,
@@ -1016,6 +1016,16 @@ def forward(
             fp8_exponent_bias=self.fp8_exponent_bias,
         )
 
+    def forward(
+        self,
+        indices: Tensor,
+        offsets: Tensor,
+        per_sample_weights: Optional[Tensor] = None,
+    ) -> Tensor:
+        return self._forward_impl(
+            indices=indices, offsets=offsets, per_sample_weights=per_sample_weights
+        )
+
     def initialize_logical_weights_placements_and_offsets(
         self,
     ) -> None:
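
For reference, a minimal sketch (not part of the patch above) of how a
subclass might use `_forward_impl` to accept per-bag lengths instead of
offsets. The class name `LengthsTBE` and the plain-torch lengths-to-offsets
conversion are illustrative assumptions, not the approach taken in D66515313;
only `IntNBitTableBatchedEmbeddingBagsCodegen` and `_forward_impl` come from
the patched file.

import torch
from torch import Tensor
from typing import Optional

from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)


class LengthsTBE(IntNBitTableBatchedEmbeddingBagsCodegen):
    # Hypothetical subclass that takes per-bag lengths rather than offsets.
    # Because the shared forward logic lives in `_forward_impl`, this
    # override never calls `super().forward()`, so it stays TorchScriptable.
    def forward(
        self,
        indices: Tensor,
        lengths: Tensor,
        per_sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        # Complete cumulative sum: offsets[0] = 0, offsets[i] = sum(lengths[:i]),
        # giving len(lengths) + 1 entries as TBE expects.
        zero = torch.zeros(1, dtype=lengths.dtype, device=lengths.device)
        offsets = torch.cat([zero, torch.cumsum(lengths, dim=0)])
        return self._forward_impl(
            indices=indices,
            offsets=offsets,
            per_sample_weights=per_sample_weights,
        )

In production code, FBGEMM's fused `torch.ops.fbgemm.asynchronous_complete_cumsum`
op could likely replace the plain-torch conversion; the version above keeps the
sketch self-contained.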