Support sending lengths to TBE instead of just offsets
Summary:
X-link: facebookresearch/FBGEMM#508

X-link: pytorch/torchrec#2557

Here we modify the FBGEMM inference TBE so that the forward logic lives in `_forward_impl`. This makes it easy for subclasses of TBE to extend the forward pass without calling `super()`, which is not TorchScriptable. An example of a subclass that uses this is D66515313.

Keeping subclasses TorchScriptable is vital for inference, as TBE is generally not FX-traced through.
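
To make the motivation concrete, here is a minimal sketch of the pattern this unlocks. The `LengthsTBE` class name, the lengths-to-offsets cumsum, and the import path are illustrative assumptions, not code from this commit (the actual subclass is in D66515313, not shown here):

```python
# Sketch: a TorchScriptable subclass that accepts per-bag lengths and reuses
# the base forward logic via `_forward_impl`. Calling super().forward(...)
# here instead would not be TorchScriptable.
from typing import Optional

import torch
from torch import Tensor

# Assumed import path for the FBGEMM inference TBE class.
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)


class LengthsTBE(IntNBitTableBatchedEmbeddingBagsCodegen):  # hypothetical
    def forward(
        self,
        indices: Tensor,
        lengths: Tensor,
        per_sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        # Convert per-bag lengths to offsets with plain torch ops:
        # offsets = [0, cumsum(lengths)].
        zero = torch.zeros(1, dtype=lengths.dtype, device=lengths.device)
        offsets = torch.cat([zero, torch.cumsum(lengths, dim=0)])
        return self._forward_impl(
            indices=indices,
            offsets=offsets,
            per_sample_weights=per_sample_weights,
        )
```

Because `_forward_impl` is an ordinary method on the base class, `torch.jit.script` can compile this override directly, whereas the equivalent `super().forward(...)` call would fail to script.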

Reviewed By: sryap

Differential Revision: D64906767

fbshipit-source-id: 41bab272c2611fc97dece7f89fdbbd820671843c
PaulZhang12 authored and facebook-github-bot committed Nov 27, 2024
1 parent cffa05a commit 6eb379a
Showing 1 changed file with 11 additions and 1 deletion.
@@ -913,7 +913,7 @@ def _update_tablewise_cache_miss(
 
         self.table_wise_cache_miss[i] += miss_count
 
-    def forward(
+    def _forward_impl(
         self,
         indices: Tensor,
         offsets: Tensor,
@@ -1016,6 +1016,16 @@ def forward(
             fp8_exponent_bias=self.fp8_exponent_bias,
         )
 
+    def forward(
+        self,
+        indices: Tensor,
+        offsets: Tensor,
+        per_sample_weights: Optional[Tensor] = None,
+    ) -> Tensor:
+        return self._forward_impl(
+            indices=indices, offsets=offsets, per_sample_weights=per_sample_weights
+        )
+
     def initialize_logical_weights_placements_and_offsets(
         self,
     ) -> None:
