sparse benchmarking numbers (pytorch#303)
- Updated benchmark script for standalone sparse numbers.
- Switched from segment-anything to segment-anything-fast
- Updated README with results for segment-anything and BERT
jcaip committed Jun 3, 2024
1 parent bf26fd3 commit 07efae2
Showing 2 changed files with 60 additions and 14 deletions.
36 changes: 22 additions & 14 deletions benchmarks/benchmark_sam.py
@@ -1,6 +1,11 @@
+import argparse
+from itertools import product
+
 import pandas as pd
+# to install segment-anything-fast you can run:
+# pip install git+https://github.com/pytorch-labs/segment-anything-fast.git
+from segment_anything_fast import sam_model_registry
 import torch
-from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
 from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
 from torchao.quantization.quant_api import (
@@ -16,7 +21,6 @@
     apply_fake_sparsity,
 )
 from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
-from itertools import product
 from tqdm import tqdm
 
 sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
@@ -112,18 +116,22 @@ def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True,

 if __name__ == "__main__":
     print("BENCHMARKING")
-    ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
-    # for option in tqdm(SUBCLASSES)]
-    # ALL_RUNS = [
-    #     run_once(),
-    #     run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
-    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
-    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
-    #     run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
-    #     run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
-    #     run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
-    #     run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
-    # ]
+    parser = argparse.ArgumentParser(description='Process some integers.')
+    parser.add_argument('--eager', action='store_true', help='enable/disable torch.compile')
+    args = parser.parse_args()
+    # ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
+    ALL_RUNS = [
+        run_once(compile=not args.eager),
+        run_once(compile=not args.eager, lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
+        run_once(compile=not args.eager, lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
+        run_once(compile=not args.eager, qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
+        run_once(compile=not args.eager, qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
+        # run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
+        # run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
+        # run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
+        # run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
+        # run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
+    ]
     df = pd.DataFrame(ALL_RUNS)
     df.to_csv("sam_benchmark_results.csv")
     print(df)
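
The `--eager` flag introduced in this hunk is a `store_true` switch, so `torch.compile` stays enabled by default and is turned off only when the flag is passed. A minimal standalone sketch of that pattern follows; `run_once` here is a hypothetical stub that only echoes its arguments, and `build_runs` is an illustrative helper, neither is the benchmark's real code.

```python
import argparse


def run_once(compile=True, **layer_configs):
    # Hypothetical stand-in for the benchmark's run_once: it only records
    # the settings it was called with instead of timing a model.
    return {"compile": compile, **layer_configs}


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Toy sparse benchmark driver")
    # store_true defaults to False, so compilation is on unless --eager is given.
    parser.add_argument("--eager", action="store_true", help="disable torch.compile")
    return parser.parse_args(argv)


def build_runs(argv=None):
    args = parse_args(argv)
    return [
        run_once(compile=not args.eager),
        run_once(compile=not args.eager, lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
    ]


# Passing an explicit argument list keeps the example independent of sys.argv.
for row in build_runs(["--eager"]):
    print(row)
```

Negating the flag at each call site keeps the default behaviour (compiled) unchanged for anyone who runs the script without arguments.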
38 changes: 38 additions & 0 deletions torchao/sparsity/README.md
@@ -18,6 +18,44 @@ More concretely, we hope to provide tutorials and APIs for both sparse kernels (
2. Recover accuracy loss of pruned model with custom pruning algorithm.
3. Accelerate masked/pruned models on sparsity-supported hardware to realize performance improvements.

## Success Stories

#### segment-anything
We applied 2:4 sparsity to accelerate segment-anything, as part of [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast).
The results mentioned in the README of the repo compose sparsity with a suite of other inference acceleration techniques.
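
For readers unfamiliar with the term, 2:4 (semi-structured) sparsity keeps at most two nonzero values in every contiguous group of four weights. The sketch below illustrates the pattern with magnitude-based pruning in plain Python. The helper name `prune_2_4` and the magnitude criterion are assumptions made for illustration, not the pruning method used by torchao or segment-anything-fast.

```python
def prune_2_4(row):
    """Zero the two smallest-magnitude values in each contiguous group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


weights = [0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.0, 1.5]
print(prune_2_4(weights))
# prints [0.0, -0.9, 0.4, 0.0, 2.0, 0.0, 0.0, 1.5]
```

Because exactly half of each group is zero, hardware with sparse tensor-core support can skip those positions, which is where the speedups reported here come from.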

From our [benchmarking](https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_sam.py), we see a 1.1x speedup when running with `SEGMENT_ANYTHING_FAST_USE_FLASH_4` enabled.

The inference acceleration of semi-structured sparsity depends on the matmul shapes, which is why we don't see additional speedups when applying it to all linear layers (attn + mlp) of segment-anything.
We find that accelerating the MLP linear layers (`lin1`, `lin2`) provided the largest speedup. To reproduce our benchmarks you can run the following command:

```
python benchmarks/benchmark_sam.py
```

The following benchmarks were run on an A100, with batch_size=32 and `bfloat16` dtype:

| qkv | proj | lin1 | lin2 | time (ms) | memory | img/s |
| ---- | ---- | ---- | ---- | ---- | ------ | ----- |
| None | None | None | None | 1361.73 | 15.81 | 23.50 |
| None | None | sparse (cusparselt) | sparse (cusparselt) | 1245.15 | 15.46 | 25.70 |
| None | None | sparse (cutlass) | sparse (cutlass) | 1251.05 | 15.41 | 25.59 |
| sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | 1265.43 | 12.71 | 25.29 |
| sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | 1274.96 | 12.70 | 25.10 |
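
The `img/s` column is consistent with reading `time` as milliseconds per batch of 32 images. That reading is inferred from the numbers rather than stated by the benchmark script. Under that assumption, throughput and speedup over the dense baseline can be recomputed as follows (the helper names and row labels are illustrative):

```python
BATCH_SIZE = 32  # batch size stated for these benchmarks

# (label, time per batch) copied from the table; time is assumed to be in ms.
rows = [
    ("dense baseline", 1361.73),
    ("lin1/lin2 sparse (cusparselt)", 1245.15),
    ("all linear layers sparse (cusparselt)", 1265.43),
]


def images_per_second(time_ms, batch_size=BATCH_SIZE):
    # Convert milliseconds per batch into images processed per second.
    return batch_size / (time_ms / 1000.0)


def speedup(baseline_ms, time_ms):
    # Ratio of baseline latency to the latency of the sparse configuration.
    return baseline_ms / time_ms


baseline_ms = rows[0][1]
for label, time_ms in rows:
    print(f"{label}: {images_per_second(time_ms):.2f} img/s, "
          f"{speedup(baseline_ms, time_ms):.2f}x vs. baseline")
```

Sparsifying only the MLP layers comes out at roughly 1.09x under this calculation, which rounds to the 1.1x figure quoted above.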

#### BERT

We were able to accelerate BERT 1.23x on an A100 with a negligible accuracy drop on SQuAD.
For more information about accelerating BERT with semi-structured sparsity, please see our [tutorial](https://pytorch.org/tutorials/advanced/semi_structured_sparse.html?highlight=beta).

| Metrics | fp16 | 2:4 sparse | delta / speedup |
| --- | --- | --- | --- |
| Exact Match (%) | 78.53 | 78.44 | -0.09 |
| F1 (%) | 86.93 | 86.49 | -0.44 |
| Time (bs=16) | 19.35 | 15.74 | 1.23x |
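
The last column of this table can be rechecked from the other two: the accuracy rows report a difference in percentage points, and the time row reports a ratio. A small sketch, with illustrative variable names:

```python
fp16 = {"exact_match": 78.53, "f1": 86.93, "time": 19.35}
sparse_2_4 = {"exact_match": 78.44, "f1": 86.49, "time": 15.74}

# Accuracy metrics: sparse minus dense, in percentage points.
em_delta = round(sparse_2_4["exact_match"] - fp16["exact_match"], 2)
f1_delta = round(sparse_2_4["f1"] - fp16["f1"], 2)

# Latency: dense time divided by sparse time gives the speedup factor.
time_speedup = round(fp16["time"] / sparse_2_4["time"], 2)

print(em_delta, f1_delta, f"{time_speedup}x")
```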


# Design

Sparsity, like quantization, is an accuracy/performance trade-off, where we care not only about the speedup but also about the accuracy degradation of our architecture optimization technique.