diff --git a/torchao/sparsity/prototype/superblock/README.md b/torchao/sparsity/prototype/superblock/README.md
index d08e8cc6e..d449d50e2 100644
--- a/torchao/sparsity/prototype/superblock/README.md
+++ b/torchao/sparsity/prototype/superblock/README.md
@@ -38,7 +38,7 @@ At least one GPU:
   ```
 * Install PyTorch. For best performance, we recommend the pytorch nightlies
   ```
-
+  pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
   ```
   We ran our experiments with torch==2.6.0.dev20240924+cu124

@@ -46,7 +46,11 @@ At least one GPU:

 ## Benchmarking
 For all our benchmarking results, you can run `benchmark.sh`. This will run benchmarks with random weights, only testing speedup.
-Please use the evaluation script to validate accuracy.
+
+These benchmarks were run on an NVIDIA A100 80GB GPU:
+
+
+

 ## Training
 Please refer to [TRAINING.md](TRAINING.md) for training from scratch. We use [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification) as our framework for training. Supermask can be applied during training.
diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py
index 83eab0fac..c893fe2cb 100644
--- a/torchao/sparsity/prototype/superblock/benchmark.py
+++ b/torchao/sparsity/prototype/superblock/benchmark.py
@@ -1,17 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import os
-import time
-import sys
-import warnings
-import hashlib

+import torch
 import torchvision
-import presets
-import torch
-import torch.utils.data
-import utils
-from torch import nn
 from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
 from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
 from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity

@@ -21,7 +12,6 @@

 @torch.inference_mode
 def main(args):
-    print(args)
     device = torch.device(args.device)

     # We disable the cudnn benchmarking because it can noticeably affect the accuracy
@@ -30,11 +20,9 @@ def main(args):

     num_classes = 1000
     dtype = getattr(torch, args.dtype)
-    print(f"Using dtype: {dtype}")

     # BSR kernel tuning
     if args.bsr and args.tune_kernel_params:
-        print("Tuning kernel params")
         kwargs = dict(
             dtype=torch.int8 if args.quantization else dtype,
             sparsity=args.sparsity_linear, verbose=True,
@@ -59,23 +47,22 @@ def main(args):
         # by default) but when used, it'll enables reusing the tuned
         # parameters in subsequent runs of this script:
         # store_tuned_kernel_params()
-    print("Creating model")
-    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
+    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes).eval()

-    # Fake sparsity necessary for BSR
-    simulate_sparsity(model, args)
+    # Fake sparsity necessary for BSR, since the sparsity pattern is found via SuperBlock
+    sparsifier_or_none = simulate_sparsity(model, args)
+    if sparsifier_or_none is not None:
+        sparsifier_or_none.squash_mask()

     if args.weights_path:
         try:
             checkpoint = torch.load(args.weights_path, map_location="cpu")
             model.load_state_dict(checkpoint["model"])
-            print(f"Loaded checkpoint successfully from: {args.weights_path}")
         except FileNotFoundError:
             raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.")

     model.to(device).to(dtype)

-    # Fake sparsity necessary for BSR
     accelerate_with_sparsity(model, args)

     # compile
@@ -96,8 +83,8 @@ def main(args):
 def get_args_parser(add_help=True):
     import argparse

-    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
-    parser.add_argument("--model", default="resnet18", type=str, help="model name")
+    parser = argparse.ArgumentParser(description="PyTorch ImageNet Sparsity Benchmarking", add_help=add_help)
+    parser.add_argument("--model", default="vit_b_16", choices=["vit_b_16", "vit_h_14"], type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
         "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
@@ -109,20 +96,17 @@ def get_args_parser(add_help=True):
     parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load")
     # NOTE: sparsity args
     parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply')
+    parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
     parser.add_argument("--sparsity-linear", type=float, default=0.0)
-    parser.add_argument("--sp-linear-tile-size", type=int, default=1)
     parser.add_argument("--sparsity-conv1x1", type=float, default=0.0)
-    parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1)
     parser.add_argument("--sparsity-conv", type=float, default=0.0)
-    parser.add_argument("--sp-conv-tile-size", type=int, default=1)
     parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)")
     parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)")
-    parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
-    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="data type", default="bfloat16")
-    parser.add_argument("--float16", action="store_true", help="Use float16")
+    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="Data type", default="bfloat16")
     parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params")
-    parser.add_argument("--profile", action="store_true", help="Profile the run and dump Prefetto trace")
-    parser.add_argument("--quantization", action="store_true", help="whether to run with quantization or not")
+    parser.add_argument("--quantization", action="store_true", help="Run with int8 dynamic quantization")
+    parser.add_argument("--profile", action="store_true", help="Dump Perfetto trace")
+    parser.add_argument("--header", action="store_true", help="Print header for first run")
     return parser
@@ -131,7 +115,9 @@ def get_args_parser(add_help=True):
     args = get_args_parser().parse_args()
     result = main(args)
     header = ["model", "batch_size", "dtype", "sparsity", "bsr", "sparsity_level", "quantization", "tune_kernel_params", "latency", "img/s"]
-    result_string = " | ".join(str(_) for _ in [args.model, args.batch_size, args.dtype, args.sparsity, args.bsr, args.sparsity_linear, args.quantization, args.tune_kernel_params, result, 1000/result])
-    with open("benchmark_results.txt", "w+") as f:
-        f.write(result_string)
+    result_string = ",".join(str(_) for _ in [args.model, args.batch_size, args.dtype, args.sparsity, args.bsr, args.sparsity_linear, args.quantization, args.tune_kernel_params, result, 1000/result])
+    with open("benchmark_results.txt", "a") as f:
+        if args.header:
+            f.write(",".join(header)+"\n")
+        f.write(result_string+"\n")
     print(result_string)
diff --git a/torchao/sparsity/prototype/superblock/benchmark.sh b/torchao/sparsity/prototype/superblock/benchmark.sh
index c4a6a0f71..3fc2a9869 100644
--- a/torchao/sparsity/prototype/superblock/benchmark.sh
+++ b/torchao/sparsity/prototype/superblock/benchmark.sh
@@ -1,7 +1,39 @@
 MODEL=vit_h_14
 BATCH_SIZE=256

-python benchmark.py --model $MODEL --batch-size $BATCH_SIZE
-python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.8 --sp-linear-tile-size 64 --bsr 64 --sparsity bsr
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization
+
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr
+
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization
+
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
+
+MODEL=vit_b_16
+BATCH_SIZE=256
+
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization
+
 python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured
-python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.8 --sp-linear-tile-size 64 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr
+
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization
+
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
diff --git a/torchao/sparsity/prototype/superblock/benchmark_results.txt b/torchao/sparsity/prototype/superblock/benchmark_results.txt
new file mode 100644
index 000000000..81bbf20b8
--- /dev/null
+++ b/torchao/sparsity/prototype/superblock/benchmark_results.txt
@@ -0,0 +1,23 @@
+model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s
+vit_b_16,256,bfloat16,None,None,0.0,False,False,59.8181787109375,16.717326096341917
+vit_b_16,256,bfloat16,None,None,0.0,True,False,60.42583984375,16.5492114397717
+vit_b_16,256,bfloat16,semi_structured,None,0.0,False,False,64.60806640625,15.47794347708977
+vit_b_16,256,bfloat16,bsr,64,0.8,False,False,47.4581591796875,21.07119233625919
+vit_b_16,256,bfloat16,bsr,64,0.84,False,False,45.565,21.94666959288928
+vit_b_16,256,bfloat16,bsr,64,0.9,False,False,42.566181640625,23.492828378235455
+vit_b_16,256,bfloat16,semi_structured,None,0.0,True,False,60.38123046875,16.561437920970242
+vit_b_16,256,bfloat16,bsr,64,0.8,True,False,52.3312890625,19.109026701132585
+vit_b_16,256,bfloat16,bsr,64,0.84,True,False,51.0024755859375,19.606891401085672
+vit_b_16,256,bfloat16,bsr,64,0.9,True,False,48.94609375,20.430639574787314
+vit_b_16,256,bfloat16,bsr,64,0.8,True,True,49.3368310546875,20.26883321491703
+vit_b_16,256,bfloat16,bsr,64,0.84,True,True,48.820908203125,20.483027391448456
+vit_b_16,256,bfloat16,bsr,64,0.9,True,True,47.728173828125,20.951985374532082
+model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s
+vit_h_14,256,bfloat16,None,None,0.0,False,False,488.05046875,2.0489684244361257
+vit_h_14,256,bfloat16,None,None,0.0,True,False,451.03265625,2.217134360767253
+vit_h_14,256,bfloat16,semi_structured,None,0.0,False,False,505.3465234375,1.97884016931142
+vit_h_14,256,bfloat16,bsr,64,0.8,False,False,361.4376953125,2.7667285758210616
+model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s
+vit_h_14,256,bfloat16,None,None,0.0,False,False,491.9237890625,2.032835211945702
+vit_h_14,256,bfloat16,None,None,0.0,True,False,454.1960546875,2.2016923962230996
+vit_h_14,256,bfloat16,semi_structured,None,0.0,False,False,507.211640625,1.9715635839267664
diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py
index 4015217f5..6f293ddd5 100644
--- a/torchao/sparsity/prototype/superblock/blocksparse.py
+++ b/torchao/sparsity/prototype/superblock/blocksparse.py
@@ -5,9 +5,10 @@
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.utils import TorchAOBaseTensor
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
-from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims, bsr_dense_addmm
+from torch.sparse._triton_ops import bsr_dense_mm, bsr_dense_addmm, broadcast_batch_dims

 aten = torch.ops.aten
+
 # quantization support
 @torch.library.custom_op("blocksparse::_int_mm", mutates_args=())
 def blocksparse_int_mm(crow_indices: torch.Tensor,
@@ -18,11 +19,6 @@ def blocksparse_int_mm(crow_indices: torch.Tensor,
                        A: torch.Tensor) -> torch.Tensor:
     assert values.dtype == torch.int8
     weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
-
-    N = A.shape[-1]
-
-    # original_batch_dims_broadcasted = broadcast_batch_dims("_int_bsr_dense_addmm", weight_bsr, A)
-    # input = torch.zeros(M, N, dtype=torch.int32, device=A.device)
     return bsr_dense_mm(weight_bsr, A).t().contiguous()

 @torch.library.register_fake("blocksparse::_int_mm")
diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py
index e75390442..0912334de 100644
--- a/torchao/sparsity/prototype/superblock/utils.py
+++ b/torchao/sparsity/prototype/superblock/utils.py
@@ -33,16 +33,6 @@ def apply_sparsity(model):
         if isinstance(module, SupermaskLinear) and "mlp" in name:
             module.sparsify_offline()

-
-def verify_sparsity(model):
-    for name, module in model.named_modules():
-        if isinstance(module, torch.nn.Linear):
-            total_weights = module.weight.numel()
-            sparse_weights = (module.weight == 0).sum().item()
-            sparsity_percentage = (sparse_weights / total_weights) * 100
-            print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")
-
-
 # filter functions
 def mlp_0_only(mod, name):
     return isinstance(mod, torch.nn.Linear) and "mlp.0" in name
@@ -78,7 +68,6 @@ def mlp_only_with_args(

 def accelerate_with_sparsity(model, args):
     if args.sparsity == "bsr":
         apply_sparsity(model)
-        verify_sparsity(model)
         if args.quantization:
             from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
@@ -94,10 +83,10 @@ def accelerate_with_sparsity(model, args):
             sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only)
     elif args.sparsity == "semi_structured":
         if args.quantization:
+            from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType
             quantize_(
-                model, int8_dynamic_activation_int8_semi_sparse_weight(), mlp_0_only
+                model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), mlp_only
             )
-            sparsify_(model, semi_sparse_weight(), mlp_3_only)
         else:
             sparsify_(model, semi_sparse_weight(), mlp_only)
     else:
@@ -110,11 +99,11 @@ def simulate_sparsity(model, args):
         apply_supermask(
             model,
             linear_sparsity=args.sparsity_linear,
-            linear_sp_tilesize=args.sp_linear_tile_size,
+            linear_sp_tilesize=args.bsr,
             conv1x1_sparsity=args.sparsity_conv1x1,
-            conv1x1_sp_tilesize=args.sp_conv1x1_tile_size,
+            conv1x1_sp_tilesize=args.bsr,
             conv_sparsity=args.sparsity_conv,
-            conv_sp_tilesize=args.sp_conv_tile_size,
+            conv_sp_tilesize=args.bsr,
             skip_last_layer_sparsity=args.skip_last_layer_sparsity,
             skip_first_transformer_sparsity=args.skip_first_transformer_sparsity,
             device=args.device,
@@ -135,12 +124,8 @@ def simulate_sparsity(model, args):
             sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
         )
         sparsifier.prepare(model, sparse_config)
-        for line in sparse_config:
-            print(line)
         sparsifier.step()
         return sparsifier
-    else:
-        print("No sparsity applied!")


 ### Existing torchvision utils
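
For readers reproducing the tables above, the sketch below (not part of this patch) shows one way to summarize the `benchmark_results.txt` CSV that `benchmark.py` now appends to, reporting each configuration's speedup over the dense, non-quantized baseline of the same model. It assumes the comma-separated layout written by the script, with header rows repeated whenever a run is launched with `--header`; the grouping and the speedup definition are illustrative additions, not part of the benchmark code.

```
# Illustrative summary of benchmark_results.txt (not part of the patch above).
# Assumes the CSV columns written by benchmark.py; header rows may repeat because
# every run launched with --header re-emits them.
import csv
from collections import defaultdict

HEADER = ["model", "batch_size", "dtype", "sparsity", "bsr", "sparsity_level",
          "quantization", "tune_kernel_params", "latency", "img/s"]

rows = []
with open("benchmark_results.txt") as f:
    for record in csv.reader(f):
        if not record or record == HEADER:   # skip blanks and repeated header rows
            continue
        rows.append(dict(zip(HEADER, record)))

# The dense, non-quantized run of each model serves as the speedup baseline.
baseline = {}
for r in rows:
    if r["sparsity"] == "None" and r["quantization"] == "False":
        baseline.setdefault(r["model"], float(r["latency"]))

by_model = defaultdict(list)
for r in rows:
    by_model[r["model"]].append(r)

for model, entries in by_model.items():
    print(f"== {model} ==")
    for r in entries:
        speedup = baseline[model] / float(r["latency"])
        print(f"sparsity={r['sparsity']:>15} bsr={r['bsr']:>4} "
              f"level={r['sparsity_level']:>4} quant={r['quantization']:>5} "
              f"tuned={r['tune_kernel_params']:>5} "
              f"latency={float(r['latency']):8.2f} speedup={speedup:.2f}x")
```

On the numbers committed above, for example, the vit_b_16 BSR configurations come out roughly 1.25-1.4x faster than the dense bfloat16 baseline, and the 0.8-sparse vit_h_14 BSR run about 1.35x faster.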