
Conversation

elokrainz commented Oct 23, 2025

Summary:
Because torch.index_select performs atomic adds in its backward pass, its backward performance is sometimes worse than torch.gather's. This diff gives users control over the indexing operation so they can select the operator that suits their specific case.
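For dim=0 row selection the two operators can produce identical forward results, differing only in how the indices are shaped and how gradients are scattered in backward. A minimal sketch of the equivalence (shapes mirror the benchmark configs below; not code from this diff):

```python
import torch

x = torch.randn(1_000_000, 256, device="cuda", requires_grad=True)
idx = torch.randint(0, x.size(0), (100_000,), device="cuda")

# index_select: backward scatters gradients into x with atomic adds,
# which gets slower as index repetition grows
out_select = torch.index_select(x, 0, idx)

# gather: the same forward result for dim=0 row selection, but the
# indices must first be expanded to the full output shape
idx_2d = idx.unsqueeze(1).expand(-1, x.size(1))
out_gather = torch.gather(x, 0, idx_2d)

assert torch.equal(out_select, out_gather)
```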

Perf comparison on pure operators (forward + backward)

2D Embedding, No Repetition
Config: shape=(1000000, 256), dim=0, indices=100000, unique=95300 (95.3%)
Method                    Time (s)     Speedup    Status
torch.gather              0.9439       1.00x      🏆
torch.index_select        1.0509       0.90x

2D Embedding, Low Repetition
Config: shape=(1000000, 256), dim=0, indices=100000, unique=48732 (48.7%)
Method                    Time (s)     Speedup    Status
torch.gather              0.9076       1.00x      🏆
torch.index_select        1.0415       0.87x

2D Embedding, High Repetition
Config: shape=(1000000, 256), dim=0, indices=250000, unique=9957 (4.0%)
Method                    Time (s)     Speedup    Status
torch.gather              1.2385       1.00x      🏆
torch.index_select        1.6225       0.76x

Small Vocab, Low Repetition
Config: shape=(1000, 256), dim=0, indices=2000, unique=635 (31.8%)
Method                    Time (s)     Speedup    Status
torch.gather              0.1502       1.00x      🏆
torch.index_select        0.1763       0.85x

Small Vocab, Very High Repetition
Config: shape=(1000, 256), dim=0, indices=100000, unique=625 (0.6%)
Method                    Time (s)     Speedup    Status
torch.gather              0.2626       1.00x      🏆
torch.index_select        0.4126       0.64x

Large Vocab, No Repetition
Config: shape=(10000000, 256), dim=0, indices=10000, unique=9996 (100.0%)
Method                    Time (s)     Speedup    Status
torch.gather              5.8014       1.00x      🏆
torch.index_select        5.8184       1.00x

Large Vocab, Low Repetition
Config: shape=(10000000, 256), dim=0, indices=10000, unique=5000 (50.0%)
Method                    Time (s)     Speedup    Status
torch.gather              5.7912       1.00x      🏆
torch.index_select        5.8137       1.00x

Large Vocab, High Repetition
Config: shape=(10000000, 256), dim=0, indices=10000, unique=400 (4.0%)
Method                    Time (s)     Speedup    Status
torch.gather              5.7784       1.00x      🏆
torch.index_select        5.8100       0.99x
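
The benchmark script itself isn't included in this PR; a rough sketch of how one of the high-repetition configs above could be timed (forward + backward, illustrative names and iteration count):

```python
import torch
from torch.utils.benchmark import Timer

def bench(label, fn, x, idx, iters=100):
    # Time forward + backward together, as in the tables above.
    def step():
        out = fn(x, idx)
        out.sum().backward()
        x.grad = None  # reset the accumulated gradient between runs
    t = Timer(stmt="step()", globals={"step": step}).timeit(iters)
    print(f"{label:<20} {t.mean:.4f} s")

x = torch.randn(1_000_000, 256, device="cuda", requires_grad=True)
# ~4% unique indices, mirroring the "2D Embedding, High Repetition" config
idx = torch.randint(0, 10_000, (250_000,), device="cuda")

bench("torch.gather",
      lambda x, i: torch.gather(x, 0, i.unsqueeze(1).expand(-1, x.size(1))), x, idx)
bench("torch.index_select",
      lambda x, i: torch.index_select(x, 0, i), x, idx)
```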

Differential Revision: D85309309

meta-cla bot added the CLA Signed label Oct 23, 2025
meta-codesync bot commented Oct 23, 2025

@elokrainz has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85309309.

elokrainz pushed a commit to elokrainz/torchrec that referenced this pull request Oct 24, 2025
Summary:

Because torch.index_select performs atomic adds in its backward pass, its backward performance is sometimes worse than torch.gather's. This diff gives users control over the indexing operation so they can select the operator that suits their specific case.


Perf comparison on pure operators (forward + backward)

2D Embedding, No Repetition
Config: shape=(1000000, 256), dim=0, indices=100000, unique=95300 (95.3%)
Method                    Time (s)     Speedup    Status
torch.gather              0.9439       1.00x      🏆
torch.index_select        1.0509       0.90x

2D Embedding, Low Repetition
Config: shape=(1000000, 256), dim=0, indices=100000, unique=48732 (48.7%)
Method                    Time (s)     Speedup    Status
torch.gather              0.9076       1.00x      🏆
torch.index_select        1.0415       0.87x

2D Embedding, High Repetition
Config: shape=(1000000, 256), dim=0, indices=250000, unique=9957 (4.0%)
Method                    Time (s)     Speedup    Status
torch.gather              1.2385       1.00x      🏆
torch.index_select        1.6225       0.76x

Small Vocab, Low Repetition
Config: shape=(1000, 256), dim=0, indices=2000, unique=635 (31.8%)
Method                    Time (s)     Speedup    Status
torch.gather              0.1502       1.00x      🏆
torch.index_select        0.1763       0.85x

Small Vocab, Very High Repetition
Config: shape=(1000, 256), dim=0, indices=100000, unique=625 (0.6%)
Method                    Time (s)     Speedup    Status
torch.gather              0.2626       1.00x      🏆
torch.index_select        0.4126       0.64x

Large Vocab, No Repetition
Config: shape=(10000000, 256), dim=0, indices=10000, unique=9996 (100.0%)
Method                    Time (s)     Speedup    Status
torch.gather              5.8014       1.00x      🏆
torch.index_select        5.8184       1.00x

Large Vocab, Low Repetition
Config: shape=(10000000, 256), dim=0, indices=10000, unique=5000 (50.0%)
Method                    Time (s)     Speedup    Status
torch.gather              5.7912       1.00x      🏆
torch.index_select        5.8137       1.00x

Large Vocab, High Repetition
Config: shape=(10000000, 256), dim=0, indices=10000, unique=400 (4.0%)
Method                    Time (s)     Speedup    Status
torch.gather              5.7784       1.00x      🏆
torch.index_select        5.8100       0.99x



Mast Job Test:
baseline: fire-jingchang-f816557933
torch.index_select backward takes ~37ms
{F1982939713}

exp: fire-jingchang-f816355728
torch.gather backward takes ~10ms
{F1982939742}

Reviewed By: TroyGarden

Differential Revision: D85309309
elokrainz pushed a commit to elokrainz/torchrec that referenced this pull request Oct 24, 2025
Summary:
Pull Request resolved: meta-pytorch#3479