Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add semi-structured sparse + dynamic int8 subclasses #36

Merged
merged 33 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b3b2b6d
wip
jcaip Feb 6, 2024
5cc766e
test
jcaip Mar 21, 2024
cd785b5
wip
jcaip Mar 27, 2024
d084b91
wip
jcaip Mar 29, 2024
646a157
wip
jcaip Mar 29, 2024
5f0d450
test
jcaip Mar 29, 2024
f1a947f
refactor
jcaip Mar 29, 2024
af1bbbe
fix quant api
jcaip Mar 29, 2024
e193f2f
wip
jcaip Mar 29, 2024
db7d98d
updated script and cleaned up api
jcaip Mar 29, 2024
f9e3449
update
jcaip Mar 29, 2024
65e290c
wip
jcaip Apr 1, 2024
31706d5
formatted api and added per-linear tuning to script
jcaip Apr 1, 2024
0ce93c8
update
jcaip Apr 1, 2024
a7d5359
clean up imports
jcaip Apr 1, 2024
4648202
Merge branch 'main' into jcaip/quant+sparse_subclasses
jcaip Apr 19, 2024
3b89696
updated files
jcaip Apr 22, 2024
a7b4f8b
remove file
jcaip Apr 22, 2024
dafcd63
remove fuse mul API
jcaip Apr 23, 2024
070d773
add test
jcaip Apr 23, 2024
0a8e226
move to prototype
jcaip Apr 24, 2024
be6b387
Merge branch 'main' into jcaip/quant+sparse_subclasses
jcaip Apr 24, 2024
4e0c8b3
fix tests
jcaip Apr 24, 2024
db9ca9c
fix test
jcaip Apr 24, 2024
7a9b6f9
added init
jcaip Apr 25, 2024
fb3beb6
updated test
jcaip Apr 25, 2024
1a1edcc
fix test
jcaip Apr 25, 2024
0523042
Merge branch 'main' into jcaip/quant+sparse_subclasses
msaroufim Apr 26, 2024
461a306
skip on pt 2.2
jcaip Apr 26, 2024
5f787c0
Merge branch 'jcaip/quant+sparse_subclasses' of github.com:pytorch-la…
jcaip Apr 26, 2024
5962825
Merge branch 'main' into jcaip/quant+sparse_subclasses
jcaip Apr 26, 2024
ddc5dea
typo
jcaip Apr 26, 2024
a5f188a
Merge branch 'jcaip/quant+sparse_subclasses' of github.com:pytorch-la…
jcaip Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions benchmarks/benchmark_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import pandas as pd
import torch
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
QuantizedLinearWeightBase,
Int8DynamicallyQuantizedLinearWeight,
)
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from torchao.sparsity import (
apply_sparse_semi_structured,
apply_fake_sparsity,
)
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from itertools import product
from tqdm import tqdm

sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"

torch._inductor.config.epilogue_fusion = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

@torch.no_grad()
def benchmark(f, *args, **kwargs):
for _ in range(3):
f(*args, **kwargs)
torch.cuda.synchronize()

torch.cuda.reset_peak_memory_stats()
t0 = Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9}

def get_sam_model(only_one_block=False, batchsize=1):
sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
model = sam.image_encoder.eval()
image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

# code to use just a single block of the model
if only_one_block:
model = model.blocks[0]
image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
return model, image

def qkv_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'qkv' in name

def proj_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'proj' in name

def lin1_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin1' in name

def lin2_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin2' in name

SUBCLASSES = {
"quant" : Int8DynamicallyQuantizedLinearWeight,
"quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
"quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
"sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS,
"sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT,
}

def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None):
res = {
"block_only": block_only,
"batchsize": batchsize,
"dtype": dtype,
"compile": compile,
"qkv" : qkv,
"proj": proj,
"lin1": lin1,
"lin2": lin2,
}
with torch.no_grad():
model, image = get_sam_model(block_only, batchsize)
model = model.to(dtype)
image = image.to(dtype)

# 2:4 prune model
apply_fake_sparsity(model)
option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only])

for option, filter_fn in option_and_filter_fn:
subclass = SUBCLASSES.get(option, None)
if subclass and issubclass(subclass, SparseSemiStructuredTensor):
# replace with to_sparse_semi_structured
for name, mod in model.named_modules():
if filter_fn(mod, name):
mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
_replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)

if compile:
model = torch.compile(model, mode='max-autotune')

res.update(benchmark(model, image))
res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize'])
return res

if __name__ == "__main__":
print("BENCHMARKING")
ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
# for option in tqdm(SUBCLASSES)]
# ALL_RUNS = [
# run_once(),
# run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
# run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
# run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
# run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# ]
df = pd.DataFrame(ALL_RUNS)
df.to_csv("sam_benchmark_results.csv")
print(df)
5 changes: 5 additions & 0 deletions benchmarks/sam_benchmark_results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s
0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177
1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254
2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433
3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258
65 changes: 65 additions & 0 deletions test/sparsity/test_sparse_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging
import unittest

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
)
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

class TestSemiStructuredSparse(TestCase):

def test_sparse(self):
input = torch.rand((128, 128), device="cuda").half()
model = (
nn.Sequential(
nn.Linear(128, 256),
nn.Linear(256, 128),
)
.half()
.cuda()
)

apply_fake_sparsity(model)
dense_result = model(input)

apply_sparse_semi_structured(model)
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

def test_quant_semi_sparse(self):
input = torch.rand((128, 128), device="cuda").half()
model = (
nn.Sequential(
nn.Linear(128, 256),
nn.Linear(256, 128),
)
.half()
.cuda()
)

apply_fake_sparsity(model)
dense_result = model(input)

_replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)


if __name__ == "__main__":
unittest.main()
5 changes: 4 additions & 1 deletion torchao/sparsity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

from .wanda import WandaSparsifier # noqa: F403
from .utils import PerChannelNormObserver # noqa: F403
from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity

__all__ = [
"WandaSparsifier",
"PerChannelNormObserver"
"PerChannelNormObserver",
"apply_sparse_semi_structured",
"apply_fake_sparsity",
]
Loading
Loading