diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py
new file mode 100644
index 000000000..14c2d8bc5
--- /dev/null
+++ b/benchmarks/benchmark_sam.py
@@ -0,0 +1,129 @@
+import pandas as pd
+import torch
+from segment_anything import sam_model_registry
+from torch.utils.benchmark import Timer
+from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    _get_subclass_inserter,
+    _is_linear,
+    QuantizedLinearWeightBase,
+    Int8DynamicallyQuantizedLinearWeight,
+)
+from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.sparsity import (
+    apply_sparse_semi_structured,
+    apply_fake_sparsity,
+)
+from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
+from itertools import product
+from tqdm import tqdm
+
+sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
+model_type = 'vit_h'
+model_name = 'sam_vit_h_4b8939.pth'
+checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
+
+torch._inductor.config.epilogue_fusion = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.coordinate_descent_check_all_directions = True
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+
+@torch.no_grad()
+def benchmark(f, *args, **kwargs):
+    for _ in range(3):
+        f(*args, **kwargs)
+    torch.cuda.synchronize()
+
+    torch.cuda.reset_peak_memory_stats()
+    t0 = Timer(
+        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+    )
+    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
+    return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9}
+
+def get_sam_model(only_one_block=False, batchsize=1):
+    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
+    model = sam.image_encoder.eval()
+    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')
+
+    # code to use just a single block of the model
+    if only_one_block:
+        model = model.blocks[0]
+        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
+    return model, image
+
+def qkv_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'qkv' in name
+
+def proj_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'proj' in name
+
+def lin1_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+
+def lin2_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+
+SUBCLASSES = {
+    "quant" : Int8DynamicallyQuantizedLinearWeight,
+    "quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
+    "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
+    "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS,
+    "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT,
+}
+
+def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None):
+    res = {
+        "block_only": block_only,
+        "batchsize": batchsize,
+        "dtype": dtype,
+        "compile": compile,
+        "qkv" : qkv,
+        "proj": proj,
+        "lin1": lin1,
+        "lin2": lin2,
+    }
+    with torch.no_grad():
+        model, image = get_sam_model(block_only, batchsize)
+        model = model.to(dtype)
+        image = image.to(dtype)
+
+        # 2:4 prune model
+        apply_fake_sparsity(model)
+        option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only])
+
+        for option, filter_fn in option_and_filter_fn:
+            subclass = SUBCLASSES.get(option, None)
+            if subclass and issubclass(subclass, SparseSemiStructuredTensor):
+                # replace with to_sparse_semi_structured
+                for name, mod in model.named_modules():
+                    if filter_fn(mod, name):
+                        mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
+            elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
+                _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)
+
+        if compile:
+            model = torch.compile(model, mode='max-autotune')
+
+        res.update(benchmark(model, image))
+        res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize'])
+        return res
+
+if __name__ == "__main__":
+    print("BENCHMARKING")
+    ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
+    # for option in tqdm(SUBCLASSES)]
+    # ALL_RUNS = [
+    #     run_once(),
+    #     run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
+    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
+    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
+    #     run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
+    #     run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
+    #     run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
+    #     run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
+    # ]
+    df = pd.DataFrame(ALL_RUNS)
+    df.to_csv("sam_benchmark_results.csv")
+    print(df)
diff --git a/benchmarks/sam_benchmark_results.csv b/benchmarks/sam_benchmark_results.csv
new file mode 100644
index 000000000..7cfb27faf
--- /dev/null
+++ b/benchmarks/sam_benchmark_results.csv
@@ -0,0 +1,5 @@
+,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s
+0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177
+1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254
+2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433
+3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
new file mode 100644
index 000000000..83c0544f6
--- /dev/null
+++ b/test/sparsity/test_sparse_api.py
@@ -0,0 +1,71 @@
+import logging
+import unittest
+
+import torch
+from torch import nn
+
+from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
+from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    _get_subclass_inserter,
+    _is_linear,
+)
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torch.testing._internal.common_utils import TestCase
+
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+
+class TestSemiStructuredSparse(TestCase):
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_sparse(self):
+        input = torch.rand((128, 128)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .half()
+            .cuda()
+        )
+
+        apply_fake_sparsity(model)
+        dense_result = model(input)
+
+        apply_sparse_semi_structured(model)
+        sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+
+class TestQuantSemiSparse(TestCase):
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quant_semi_sparse(self):
+        input = torch.rand((128, 128)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .half()
+            .cuda()
+        )
+
+        apply_fake_sparsity(model)
+        dense_result = model(input)
+
+        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
+        sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index f4e2838f4..6621d086d 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -6,8 +6,11 @@
 
 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
+from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity
 
 __all__ = [
     "WandaSparsifier",
-    "PerChannelNormObserver"
+    "PerChannelNormObserver",
+    "apply_sparse_semi_structured",
+    "apply_fake_sparsity",
 ]
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
new file mode 100644
index 000000000..aced0b82c
--- /dev/null
+++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -0,0 +1,300 @@
+import torch
+import torch.nn as nn
+from typing import Tuple, Optional
+
+from torchao.quantization.quant_primitives import (
+    dynamically_quantize_per_channel,
+    quant_int8_dynamic_per_token_linear,
+    quantize_activation_per_token_absmax,
+    dequantize_per_channel,
+)
+
+from torchao.quantization.subclass import (
+    Int8DynamicallyQuantizedLinearWeight,
+    QuantizedLinearWeightBase,
+)
+
+from torch.sparse import to_sparse_semi_structured
+
+# Quant + Sparse helper functions
+def sparse_quant_int8_dynamic_linear(
+    x : torch.Tensor,
+    w_vals_int8_packed : torch.Tensor,
+    w_meta_int32 : Optional[torch.Tensor],
+    w_scales : torch.Tensor,
+    bias : Optional[torch.Tensor],
+    out_dtype : torch.dtype,
+    fuse_mul=False,
+):
+    x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
+    # w_meta_int32 is either None or meta tensor
+    if w_meta_int32 is None:
+        if fuse_mul:
+            mm_out = sparse_quant_int8_cslt_matmul_fuse_mul(
+                x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype,
+            )
+        else:
+            mm_out = sparse_quant_int8_cslt_matmul(
+                x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype,
+            )
+    else:
+        mm_out = sparse_quant_int8_cutlass_matmul(
+            x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype,
+        )
+
+    if bias is not None:
+        mm_out += bias
+    return mm_out
+
+def sparse_quant_int8_cslt_matmul_fuse_mul(
+    x_vals_int8,
+    x_scales,
+    w_vals_int8,
+    w_scales,
+    out_dtype,
+):
+
+    assert (
+        x_vals_int8.dtype == torch.int8
+    ), f"x dtype {x_vals_int8.dtype} not yet supported"
+    assert (
+        w_vals_int8.dtype == torch.int8
+    ), f"w dtype {w_vals_int8.dtype} not yet supported"
+    # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}'
+
+    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
+
+    assert x_scales.dtype in [
+        torch.float,
+        torch.bfloat16,
+    ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
+
+    y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
+        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
+    ).t()
+    y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
+        *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
+    )
+    y = y.to(out_dtype)
+
+    return y
+
+def sparse_quant_int8_cslt_matmul(
+    x_vals_int8,
+    x_scales,
+    w_vals_int8,
+    w_scales,
+    out_dtype,
+):
+
+    assert (
+        x_vals_int8.dtype == torch.int8
+    ), f"x dtype {x_vals_int8.dtype} not yet supported"
+    assert (
+        w_vals_int8.dtype == torch.int8
+    ), f"w dtype {w_vals_int8.dtype} not yet supported"
+    # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}'
+
+    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
+
+    assert x_scales.dtype in [
+        torch.float,
+        torch.bfloat16,
+    ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
+
+    y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
+        w_vals_int8, tmp.t(), out_dtype=torch.bfloat16
+    ).t()
+    y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1) * w_scales).reshape(
+        *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
+    )
+    y = y.to(out_dtype)
+
+    return y
+
+
+def sparse_quant_int8_cutlass_matmul(
+    x_vals_int8,
+    x_scales,
+    w_vals_int8,
+    w_meta_int32,
+    w_scales,
+    out_dtype,
+):
+    assert (
+        x_vals_int8.dtype == torch.int8
+    ), f"x dtype {x_vals_int8.dtype} not yet supported"
+    assert (
+        w_vals_int8.dtype == torch.int8
+    ), f"w dtype {w_vals_int8.dtype} not yet supported"
+    assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}"
+    assert w_meta_int32.dtype == torch.int32, f"{w_meta_int32.dtype} not yet supported"
+
+    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
+
+    assert x_scales.dtype in [
+        torch.float,
+        torch.bfloat16,
+    ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
+
+    y_dot_int32 = torch._sparse_semi_structured_linear(
+        tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32
+    )
+    y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape(
+        *x_vals_int8.shape[:-1], y_dot_int32.shape[-1]
+    )
+    y = y.to(out_dtype)
+    return y
+
+class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight(
+    Int8DynamicallyQuantizedLinearWeight
+):
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        return sparse_quant_int8_dynamic_linear(
+            act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype,
+            fuse_mul=True
+        )
+
+    @classmethod
+    def from_float(cls, input_float, qmin=-8, qmax=7):
+
+        assert input_float.is_cuda
+
+        w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
+            input_float, qmin, qmax, torch.int8
+        )
+
+        int_data = w_int_repr.contiguous()
+        int_data = torch._cslt_compress(int_data)
+
+        return cls(
+            int_data,
+            w_scales,
+            False,
+            input_float.shape,
+            dtype=input_float.dtype,
+        )
+
+
+class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase):
+
+    @staticmethod
+    def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs):
+        kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
+        return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, int_data, mask_meta, q_scales, transposed, shape, **kwargs):
+        self.q_scales = q_scales
+        self.mask_meta = mask_meta
+        super().__init__(int_data, transposed)
+
+    def dequantize(self, dtype=None):
+        """
+        Obtain the dequantized version of the quantized tensor subclass
+        """
+        dq_t = dequantize_per_channel(
+            self.int_data, self.q_scales, 0, self.dtype if dtype is None else dtype
+        ).to(self.dtype)
+        # data was transposed to dequantize so make sure shape is correct
+        return dq_t if not self.transposed else dq_t.t()
+
+    def int_repr(self):
+        """
+        Get the internal integer representation of the quantized tensor
+        """
+        return self.int_data if self.transposed else self.int_data.t()
+
+    def q_params(self):
+        """
+        Get the quantization scales for the quantized tensor
+        """
+        return {"q_scales": self.q_scales}
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.mask_meta.to(kwargs["device"]),
+            self.q_scales.to(kwargs["device"]),
+            self.transposed,
+            self.shape,
+            **kwargs,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.mask_meta),
+            fn(self.q_scales),
+            self.transposed,
+            self.shape,
+            dtype=self.dtype,
+        )
+
+    def _change_shape(self, shape):
+        return self.__class__(
+            self.int_data,
+            self.mask_meta,
+            self.q_scales,
+            self.transposed,
+            shape,
+            dtype=self.dtype,
+        )
+
+    def __tensor_flatten__(self):
+        return ["int_data", "mask_meta", "q_scales"], [
+            self.transposed,
+            self.dtype,
+            self.shape,
+        ]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"]
+        mask_meta = tensor_data_dict["mask_meta"]
+        transposed, dtype, shape = tensor_attributes
+        return cls(
+            int_data,
+            mask_meta,
+            q_scales,
+            transposed,
+            shape if outer_size is None else outer_size,
+            dtype=dtype,
+            strides=outer_stride,
+        )
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        return sparse_quant_int8_dynamic_linear(
+            act_mat,
+            w_qtensor.int_data,
+            w_qtensor.mask_meta,
+            w_qtensor.q_scales,
+            bias,
+            act_mat.dtype,
+        )
+
+    @classmethod
+    def from_float(cls, input_float, qmin=-128, qmax=127):
+
+        assert input_float.is_cuda
+
+        w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
+            input_float, qmin, qmax, torch.int8
+        )
+
+        int_data = w_int_repr.contiguous()
+        sparse_tensor = to_sparse_semi_structured(int_data)
+
+        return cls(
+            sparse_tensor.packed,
+            sparse_tensor.meta,
+            w_scales,
+            False,
+            input_float.shape,
+            dtype=input_float.dtype,
+        )
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
new file mode 100644
index 000000000..90e35a412
--- /dev/null
+++ b/torchao/sparsity/sparse_api.py
@@ -0,0 +1,32 @@
+import torch
+from torch.ao.pruning import WeightNormSparsifier
+from torch.sparse import to_sparse_semi_structured
+from torchao.quantization.quant_api import _is_linear
+
+# Sparsity helper functions
+def apply_fake_sparsity(model):
+    """
+    This function simulates 2:4 sparsity on all linear layers in a model.
+    It uses the torch.ao.pruning flow.
+    """
+    # torch.ao.pruning flow
+    sparse_config = []
+    for name, mod in model.named_modules():
+        if isinstance(mod, torch.nn.Linear):
+            sparse_config.append({"tensor_fqn": f"{name}.weight"})
+
+    sparsifier = WeightNormSparsifier(
+        sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
+    )
+    sparsifier.prepare(model, sparse_config)
+    sparsifier.step()
+    sparsifier.squash_mask()
+
+
+def apply_sparse_semi_structured(model, **kwargs):
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+
+    apply_fake_sparsity(model)
+    for name, mod in model.named_modules():
+        if filter_fn(mod, name):
+            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
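
For reference, a minimal usage sketch of the two entry points added in this change (apply_sparse_semi_structured from sparse_api.py, and the quantized 2:4-sparse weight subclass from dynamic_quant_sparse.py), assuming a CUDA device and fp16 weights as in test/sparsity/test_sparse_api.py:

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import (
    Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
)
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
)

# toy fp16 model on CUDA; any stack of nn.Linear layers works the same way
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()
x = torch.rand(64, 128).half().cuda()

# path 1: 2:4 prune the weights (apply_fake_sparsity runs internally) and swap
# them for accelerated semi-structured sparse tensors
apply_sparse_semi_structured(model)
sparse_out = model(x)

# path 2: prune first, then swap in the int8-dynamic-quant + 2:4-sparse subclass
model2 = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()
apply_fake_sparsity(model2)
_replace_with_custom_fn_if_matches_filter(
    model2,
    _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight),
    _is_linear,
)
quant_sparse_out = model2(x)

Both paths rely on apply_fake_sparsity, which uses WeightNormSparsifier to zero out two of every four weight elements, so the accelerated outputs stay close to the pruned dense baseline checked by the tests above.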