sparse benchmarking numbers #303

Merged · 9 commits · Jun 3, 2024
36 changes: 22 additions & 14 deletions benchmarks/benchmark_sam.py
@@ -1,6 +1,11 @@
import argparse
from itertools import product

import pandas as pd
# to install segment-anything-fast you can run:
# pip install git+https://github.com/pytorch-labs/segment-anything-fast.git
from segment_anything_fast import sam_model_registry
Member comment: I missed that sam fast isn't included in our benchmarks - suggestion is maybe to put a sam folder under benchmarks with a README on custom dependencies and how to install them, or just add a comment above this line as to how people can install sam fast.
import torch
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
from torchao.quantization.quant_api import (
@@ -16,7 +21,6 @@
apply_fake_sparsity,
)
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from itertools import product
from tqdm import tqdm

sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
@@ -112,18 +116,22 @@ def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True,

if __name__ == "__main__":
print("BENCHMARKING")
ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
# for option in tqdm(SUBCLASSES)]
# ALL_RUNS = [
# run_once(),
# run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
# run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
# run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
# run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# ]
parser = argparse.ArgumentParser(description='Run SAM sparsity benchmarks.')
parser.add_argument('--eager', action='store_true', help='run eagerly instead of with torch.compile')
args = parser.parse_args()
# ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
ALL_RUNS = [
run_once(compile=not args.eager),
run_once(compile=not args.eager, lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
run_once(compile=not args.eager, lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
run_once(compile=not args.eager, qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
run_once(compile=not args.eager, qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
# run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
# run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
]
df = pd.DataFrame(ALL_RUNS)
df.to_csv("sam_benchmark_results.csv")
print(df)
38 changes: 38 additions & 0 deletions torchao/sparsity/README.md
@@ -18,6 +18,44 @@ More concretely, we hope to provide tutorials and APIs for both sparse kernels (
2. Recover the accuracy loss of a pruned model with a custom pruning algorithm.
3. Accelerate masked/pruned models on sparsity-supported hardware to realize performance improvements.

## Success Stories

#### segment-anything
We applied 2:4 sparsity to accelerate segment-anything, as part of [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast).
The results mentioned in the README of the repo compose sparsity with a suite of other inference acceleration techniques.

From our [benchmarking](https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_sam.py), we see a 1.1x speedup when running with `SEGMENT_ANYTHING_FAST_USE_FLASH_4` enabled.

The inference acceleration from semi-structured sparsity depends on the matmul shapes, which is why we don't see additional speedups when applying it to all linear layers (attn + mlp) of segment-anything. We find that accelerating the MLP linear layers (`lin1`, `lin2`) provides the most speedup. To reproduce our benchmarks you can run the following command:

```
python benchmarks/benchmark_sam.py
```

Member comment: Put a direct link to the benchmarks script.
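For context, the `sparse (cutlass)` / `sparse (cusparselt)` options in the table below swap a targeted linear layer's dense weight for a 2:4 semi-structured sparse tensor. Below is a minimal sketch of that swap, assuming a CUDA device with sparse tensor cores (e.g. an A100) and fp16 weights; the layer shape is illustrative, not SAM's:

```python
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Private flag (used in the PyTorch tutorial) to force the CUTLASS backend;
# leave it unset to use cuSPARSELt where available.
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# A stand-in MLP linear layer.
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()

# Prune the weight to a 2:4 pattern: in every group of 4 elements,
# zero out the 2 with the smallest magnitude.
w = linear.weight.detach()
grouped = w.view(-1, 4)
_, idx = grouped.abs().topk(2, dim=1)  # indices of the 2 largest per group
mask = torch.zeros_like(grouped, dtype=torch.bool).scatter_(1, idx, True)
pruned = (grouped * mask).view_as(w)

# Swap in the compressed sparse tensor; linear() now dispatches to the
# 2:4 sparse matmul kernels.
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(pruned))

x = torch.rand(32, 10240, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    y = linear(x)
```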

The following benchmarks were run on an A100, with batch_size=32 and `bfloat16` dtype:

| qkv | proj | lin1 | lin2 | time (ms/batch) | memory (GB) | img/s |
| ---- | ---- | ---- | ---- | ---- | ------ | ----- |
| None | None | None | None | 1361.73 | 15.81 | 23.50 |
| None | None | sparse (cusparselt) | sparse (cusparselt) | 1245.15 | 15.46 | 25.70 |
| None | None | sparse (cutlass) | sparse (cutlass) | 1251.05 | 15.41 | 25.59 |
| sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | 1265.43 | 12.71 | 25.29 |
| sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | 1274.96 | 12.70 | 25.10 |
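The `time` column is measured with `torch.utils.benchmark.Timer`, which `benchmark_sam.py` imports. A hedged sketch of that measurement pattern (`model` and `inputs` here are placeholders):

```python
import torch
from torch.utils.benchmark import Timer

def benchmark_ms(model, *inputs, warmup=3, number=10):
    # Warm up so torch.compile / kernel autotuning doesn't pollute the timing.
    with torch.inference_mode():
        for _ in range(warmup):
            model(*inputs)
        torch.cuda.synchronize()
        t = Timer(stmt="model(*inputs)",
                  globals={"model": model, "inputs": inputs})
        return t.timeit(number).mean * 1e3  # Timer reports seconds; convert to ms
```

At batch_size=32, img/s then follows as `32 / (time_ms / 1000)`; e.g. 32 / 1.36173 s ≈ 23.5 img/s, matching the first row.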

#### BERT

We were able to accelerate BERT 1.23x on an A100 with a negligible accuracy drop on SQuAD.
For more information about accelerating BERT with semi-structured sparsity, please see our [tutorial](https://pytorch.org/tutorials/advanced/semi_structured_sparse.html?highlight=beta).

| Metrics | fp16 | 2:4 sparse | delta / speedup |
| --- | --- | --- | --- |
| Exact Match (%) | 78.53 | 78.44 | -0.09 |
| F1 (%) | 86.93 | 86.49 | -0.44 |
| Time (bs=16) | 19.35 | 15.74 | 1.23x |
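The tutorial's recipe is, roughly: prune each linear weight to a 2:4 pattern with `torch.ao.pruning.WeightNormSparsifier`, fine-tune to recover accuracy, then swap the masked weights for semi-structured sparse tensors. A condensed sketch under those assumptions (the two-layer `model` is a stand-in, not BERT):

```python
import torch
from torch.ao.pruning import WeightNormSparsifier
from torch.sparse import to_sparse_semi_structured

# Stand-in model; in the tutorial this is a fine-tuned BERT.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).half().cuda().eval()

sparsifier = WeightNormSparsifier(
    sparsity_level=1.0,          # apply the pattern to every weight group
    sparse_block_shape=(1, 4),   # groups of 4 along the input dimension
    zeros_per_block=2,           # i.e. a 2:4 semi-structured pattern
)
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, mod in model.named_modules()
    if isinstance(mod, torch.nn.Linear)
]
sparsifier.prepare(model, sparse_config)
sparsifier.step()         # compute the 2:4 masks
# ... fine-tune here to recover accuracy ...
sparsifier.squash_mask()  # fold the masks into the weights

# Swap each masked dense weight for a compressed sparse representation.
for mod in model.modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
```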


# Design

Sparsity, like quantization, is an accuracy/performance trade-off, where we care not only about the speedup but also about the accuracy degradation of our architecture optimization technique.