From b16f0dc5e4b3534cff3cc5b19bd2ea1ee9d70d80 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 14 Aug 2024 21:42:37 -0600 Subject: [PATCH] Add BSR subclass +torch.compile and clean up superblock (#680) This PR adds in torch.compile support for block sparsity. In a custom op, we create the `sparse_bsr_tensor` from the explicit `crow_indices, col_indices, values` tensors that are passed in to the custom op. I also created a tensor subclass which holds these same values. At dispatch, when we see a `torch.nn.functional.linear` call, we dispatch into our custom op `torch.ops.blocksparse.linear`, using the tensors stored in the subclass. This will allow us to add a public API similar to `semi_sparse_weight()`, which I plan to do in a future PR. This PR also cleans up the superblock prototype implementation, as there was a lot of repeated code, and also adds in kernel tuning for BSR. For bfloat16 I see the following numbers, for a 1.23x gain: ``` New compile baseline: 63.431 ms New compile + bsr: 53.514 ms New compile + bsr + tuning: 51.485 ms ``` --- .../sparsity/prototype/superblock/.gitignore | 3 + .../sparsity/prototype/superblock/__init__.py | 0 .../prototype/superblock/benchmark.py | 58 +++----- .../prototype/superblock/blocksparse.py | 138 ++++++++++++++++++ .../sparsity/prototype/superblock/evaluate.py | 51 ++----- .../prototype/superblock/supermask.py | 29 +--- .../sparsity/prototype/superblock/utils.py | 19 +++ 7 files changed, 195 insertions(+), 103 deletions(-) create mode 100644 torchao/sparsity/prototype/superblock/__init__.py create mode 100644 torchao/sparsity/prototype/superblock/blocksparse.py diff --git a/torchao/sparsity/prototype/superblock/.gitignore b/torchao/sparsity/prototype/superblock/.gitignore index cf2b7c4b2..dd0446104 100644 --- a/torchao/sparsity/prototype/superblock/.gitignore +++ b/torchao/sparsity/prototype/superblock/.gitignore @@ -1,5 +1,8 @@ */*.pyc +# Model checkpoints +*.pth + # Editor temporaries *.swa *.swb diff --git 
a/torchao/sparsity/prototype/superblock/__init__.py b/torchao/sparsity/prototype/superblock/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py index 65d16c91a..d849fc3d3 100644 --- a/torchao/sparsity/prototype/superblock/benchmark.py +++ b/torchao/sparsity/prototype/superblock/benchmark.py @@ -12,9 +12,12 @@ import torch.utils.data import utils from torch import nn +from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from supermask import apply_supermask, SupermaskLinear +from blocksparse import BlockSparseTensor +from utils import benchmark_inference def apply_sparsity(model): @@ -25,20 +28,12 @@ def apply_sparsity(model): def apply_bsr(model, blocksize): for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear) and "mlp" in name: - try: - module.weight = torch.nn.Parameter(to_bsr(module.weight.data, blocksize)) - print(f"Converted {name} to bsr format.") - except ValueError as e: - print(f"Unable to convert weight of {name} to bsr format: {e}") - - -def to_bsr(tensor, blocksize): - if tensor.ndim != 2: - raise ValueError("to_bsr expects 2D tensor") - if tensor.size(0) % blocksize or tensor.size(1) % blocksize: - raise ValueError("Tensor dimensions must be divisible by blocksize") - return tensor.to_sparse_bsr(blocksize) + if isinstance(module, torch.nn.Linear) and "mlp" in name: + try: + module.weight = torch.nn.Parameter(BlockSparseTensor.from_dense(module.weight.data, blocksize)) + print(f"Converted {name} to bsr format.") + except ValueError as e: + print(f"Unable to convert weight of {name} to bsr format: {e}") def verify_sparsity(model): @@ -49,23 +44,7 @@ def verify_sparsity(model): sparsity_percentage = (sparse_weights / total_weights) * 100 print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") - -def 
benchmark_in_ms(warmup, iters, f, *args, **kwargs): - for _ in range(warmup): - f(*args, **kwargs) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - for _ in range(iters): - f(*args, **kwargs) - - end_event.record() - torch.cuda.synchronize() - return start_event.elapsed_time(end_event) / float(iters) - - +@torch.inference_mode def main(args): print(args) device = torch.device(args.device) @@ -83,8 +62,11 @@ def main(args): print("Using float16") dtype = torch.float16 - # Sample input - # input = torch.rand(32, 3, 224, 224, dtype=dtype).to(device) + if args.bsr and args.tune_kernel_params: + print("Tuning kernel params") + assert args.model == "vit_b_16", "--tune-kernel-params only supported for vit-b-16!" + optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True) + optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True) print("Creating model") model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) @@ -112,7 +94,6 @@ def main(args): raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.") model.to(device) - # output0 = model(input) if args.sparsify_weights: apply_sparsity(model) @@ -134,9 +115,11 @@ def main(args): # output2 = model(input) # assert torch.allclose(output2, output1), "Output of model before and after changing format to BSR should be equal" + model = torch.compile(model, mode='max-autotune') + image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device) - # model = torch.compile(model, mode='max-autotune') - return benchmark_in_ms(10, 100, model, image) + + return benchmark_inference(10, 100, model, image) def get_args_parser(add_help=True): @@ -169,6 +152,7 @@ def get_args_parser(add_help=True): 
parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16") parser.add_argument("--float16", action="store_true", help="Use float16") + parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params") return parser @@ -176,4 +160,4 @@ def get_args_parser(add_help=True): if __name__ == "__main__": args = get_args_parser().parse_args() result = main(args) - print(f"{result} ms", file=sys.stderr) + print(f"{result:.3f} ms", file=sys.stderr) diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py new file mode 100644 index 000000000..b57ed5635 --- /dev/null +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -0,0 +1,138 @@ +import torch +from typing import Optional, Tuple, List, Dict, Any, Callable +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import ( + _implements, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, +) +aten = torch.ops.aten + +# bsr wrapper custom op +@torch.library.custom_op("blocksparse::linear", mutates_args=()) +def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor: + weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) + return torch.nn.functional.linear(A, weight_bsr, bias) + +@torch.library.register_fake("blocksparse::linear") +def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K:int , bias: torch.Tensor) -> torch.Tensor: + new_shape = A.shape[:-1] + (bias.shape[0],) + return torch.empty(new_shape, dtype=A.dtype, device=A.device) + +# Subclass definition +class 
BlockSparseTensor(torch.Tensor): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + + __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"] + + implements = classmethod(_implements) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + __torch_function__ = classmethod(_dispatch__torch_function__) + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + tensor.bsr_crow_indices = bsr_crow_indices + tensor.bsr_col_indices = bsr_col_indices + tensor.bsr_values = bsr_values + return tensor + + def __repr__(self) -> str: # type: ignore[override] + assert hasattr(self, "shape") + return f"{self.__class__.__name__}(shape={self.shape})" + + def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]: + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + requires_grad=requires_grad, + ) + + @classmethod 
+ def from_dense(cls, dense_tensor, blocksize): + bsr_tensor = dense_tensor.to_sparse_bsr(blocksize) + return cls( + shape=dense_tensor.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + requires_grad=False, + ) + + def apply_fn_to_shard(self, func): + return BlockSparseTensor( + shape = self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + requires_grad=self.requires_grad, + ) + +# Subclass op dispatch registration +implements = BlockSparseTensor.implements + +@implements(aten.detach.default) +def block_sparse_detach(func, types, args, kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0].apply_fn_to_shard(torch.detach)) + +@implements(aten.values.default) +def block_sparse_values(func, types, args, kwargs): + return args[0].bsr_values.detach() + +@implements(aten.crow_indices.default) +def block_sparse_crow_indices(func, types, args, kwargs): + return args[0].bsr_crow_indices.detach() + +@implements(aten.col_indices.default) +def block_sparse_col_indices(func, types, args, kwargs): + return args[0].bsr_col_indices.detach() + +@implements(aten._nnz.default) +def block_sparse__nnz(func, types, args, kwargs): + return args[0].bsr_values.shape[0] + +@implements(torch.nn.functional.linear) +def block_sparse_linear(func, types, args, kwargs): + x, w, bias = args + return torch.ops.blocksparse.linear(x, + w.crow_indices(), + w.col_indices(), + w.values(), + w.shape[0], w.shape[1], bias) diff --git a/torchao/sparsity/prototype/superblock/evaluate.py b/torchao/sparsity/prototype/superblock/evaluate.py index 23e825f65..5b1542cd1 100644 --- a/torchao/sparsity/prototype/superblock/evaluate.py +++ b/torchao/sparsity/prototype/superblock/evaluate.py @@ -15,40 +15,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from supermask import apply_supermask, SupermaskLinear - - 
-def apply_sparsity(model): - for name, module in model.named_modules(): - if isinstance(module, SupermaskLinear) and "mlp" in name: - module.sparsify_offline() - - -def apply_bsr(model): - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear) and "mlp" in name: - try: - module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) - print(f"Converted {name} to bsr format.") - except ValueError as e: - print(f"Unable to convert weight of {name} to bsr format: {e}") - - -def to_bsr(tensor, blocksize): - if tensor.ndim != 2: - raise ValueError("to_bsr expects 2D tensor") - if tensor.size(0) % blocksize or tensor.size(1) % blocksize: - raise ValueError("Tensor dimensions must be divisible by blocksize") - return tensor.to_sparse_bsr(blocksize) - - -def verify_sparsity(model): - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - total_weights = module.weight.numel() - sparse_weights = (module.weight == 0).sum().item() - sparsity_percentage = (sparse_weights / total_weights) * 100 - print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") - +from benchmark import apply_sparsity, apply_bsr, verify_sparsity def _get_cache_path(filepath): h = hashlib.sha1(filepath.encode()).hexdigest() @@ -82,16 +49,16 @@ def load_data(valdir, args): ) # for META internal - dataset_test = torchvision.datasets.ImageFolder( - valdir, - preprocessing, - ) - # for OSS - # dataset_test = torchvision.datasets.ImageNet( + # dataset_test = torchvision.datasets.ImageFolder( # valdir, - # split='val', - # transform=preprocessing + # preprocessing, # ) + #for OSS + dataset_test = torchvision.datasets.ImageNet( + valdir, + split='val', + transform=preprocessing + ) if args.cache_dataset: print(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py index 6c2a314f7..e3cf2c6c9 
100644 --- a/torchao/sparsity/prototype/superblock/supermask.py +++ b/torchao/sparsity/prototype/superblock/supermask.py @@ -7,7 +7,6 @@ import torch.nn.functional as F import numpy as np - # original supermask scores_min=None scores_max=9e9 @@ -21,34 +20,16 @@ def percentile(t, q): """Return the value that is larger than q% of t""" k = 1 + round(.01 * float(q) * (t.numel() - 1)) - return t.view(-1).kthvalue(k).values.item() - - -def to_bsr(tensor, blocksize=256): - if tensor.ndim != 2: - print("Tensor is not 2D, skipping BSR conversion.") - return tensor - - if tensor.size(0) % blocksize or tensor.size(1) % blocksize: - print("Tensor dimensions are not divisible by blocksize, skipping BSR conversion.") - return tensor - - try: - converted_tensor = tensor.to_sparse_bsr(blocksize=blocksize) - print(f"Converted tensor to BSR format with blocksize: {blocksize}") - return converted_tensor - except ValueError as e: - print(f"Unable to convert tensor to BSR format: {e}") - return tensor + return t.view(-1).kthvalue(k).values class GetSubnet(torch.autograd.Function): """Supermask STE function""" @staticmethod def forward(ctx, scores, zeros, ones, sparsity): - scores.clamp_(min=scores_min,max=scores_max) - k_val = percentile(scores, sparsity*100) - return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device)) + clamped_scores = scores.clamp(min=scores_min,max=scores_max) + k_val = percentile(clamped_scores, sparsity*100) + return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)) @staticmethod def backward(ctx, g): return g, None, None, None @@ -130,7 +111,7 @@ def forward(self, x): subnet = self.get_mask() w = (self.weight*self.scale+self.shift) * subnet else: - w = self.weight.data + w = self.weight return F.linear(x, w, self.bias) diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index 8f4a5a8ed..f71ef389c 100644 --- 
a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -12,6 +12,25 @@ import torch import torch.distributed as dist +### IMAGENET UTILS +@torch.inference_mode +def benchmark_inference(warmup, iters, f, *args, **kwargs): + for _ in range(warmup): + f(*args, **kwargs) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + for _ in range(iters): + f(*args, **kwargs) + + end_event.record() + torch.cuda.synchronize() + return start_event.elapsed_time(end_event) / float(iters) + + class SmoothedValue: """Track a series of values and provide access to smoothed values over a