updated README and script
jcaip committed Sep 25, 2024
1 parent 898aa08 commit e304b9c
Showing 6 changed files with 89 additions and 63 deletions.
8 changes: 6 additions & 2 deletions torchao/sparsity/prototype/superblock/README.md
@@ -38,15 +38,19 @@ At least one GPU:
```
* Install PyTorch. For best performance, we recommend the PyTorch nightlies:
```
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
```
We ran our experiments with `torch==2.6.0.dev20240924+cu124`.
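A quick sanity check that the nightly build is the one being picked up (a minimal sketch, not part of the repo; the exact dev version string depends on the date you install):
```
# Hypothetical sanity check, not part of this repo: confirm the nightly
# build is installed and CUDA is visible before benchmarking.
import torch

print(torch.__version__)          # e.g. 2.6.0.devYYYYMMDD+cu124
print(torch.cuda.is_available())  # should be True on a GPU machine
```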


## Benchmarking

To reproduce our benchmark results, run `benchmark.sh`. This runs the benchmarks with randomly initialized weights and only measures speedup.
Please use the evaluation script to validate accuracy.

These benchmarks were run on an NVIDIA A100 80GB:

## Training
Please refer to [TRAINING.md](TRAINING.md) for training from scratch. We use [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification) as our framework for training. Supermask can be applied during training.
50 changes: 18 additions & 32 deletions torchao/sparsity/prototype/superblock/benchmark.py
@@ -1,17 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import time
import sys
import warnings
import hashlib
import torch
import torchvision

import presets
import torch
import torch.utils.data
import utils
from torch import nn
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
@@ -21,7 +12,6 @@

@torch.inference_mode
def main(args):
print(args)
device = torch.device(args.device)

# We disable the cudnn benchmarking because it can noticeably affect the accuracy
@@ -30,11 +20,9 @@ def main(args):
num_classes = 1000

dtype = getattr(torch, args.dtype)
print(f"Using dtype: {dtype}")

# BSR kernel tuning
if args.bsr and args.tune_kernel_params:
print("Tuning kernel params")
kwargs = dict(
dtype=torch.int8 if args.quantization else dtype,
sparsity=args.sparsity_linear, verbose=True,
@@ -59,23 +47,22 @@ def main(args):
# by default) but when used, it'll enable reusing the tuned
# parameters in subsequent runs of this script:
# store_tuned_kernel_params()
print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes).eval()

# Fake sparsity necessary for BSR
simulate_sparsity(model, args)
# Fake sparsity necessary for BSR, since we find based on SuperBlock
sparsifier_or_none = simulate_sparsity(model, args)
if sparsifier_or_none is not None:
sparsifier_or_none.squash_mask()

if args.weights_path:
try:
checkpoint = torch.load(args.weights_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
print(f"Loaded checkpoint successfully from: {args.weights_path}")
except FileNotFoundError:
raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.")

model.to(device).to(dtype)

# Fake sparsity necessary for BSR
accelerate_with_sparsity(model, args)

# compile
@@ -96,8 +83,8 @@
def get_args_parser(add_help=True):
import argparse

parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
parser.add_argument("--model", default="resnet18", type=str, help="model name")
parser = argparse.ArgumentParser(description="PyTorch ImageNet Sparsity Benchmarking", add_help=add_help)
parser.add_argument("--model", default="vit_b_16", choices=["vit_b_16", "vit_h_14"], type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
@@ -109,20 +96,17 @@ def get_args_parser(add_help=True):
parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load")
# NOTE: sparsity args
parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply')
parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
parser.add_argument("--sparsity-linear", type=float, default=0.0)
parser.add_argument("--sp-linear-tile-size", type=int, default=1)
parser.add_argument("--sparsity-conv1x1", type=float, default=0.0)
parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1)
parser.add_argument("--sparsity-conv", type=float, default=0.0)
parser.add_argument("--sp-conv-tile-size", type=int, default=1)
parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)")
parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)")
parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="data type", default="bfloat16")
parser.add_argument("--float16", action="store_true", help="Use float16")
parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="Data type", default="bfloat16")
parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params")
parser.add_argument("--profile", action="store_true", help="Profile the run and dump Prefetto trace")
parser.add_argument("--quantization", action="store_true", help="whether to run with quantization or not")
parser.add_argument("--quantization", action="store_true", help="Run with int8 dynamic quantization")
parser.add_argument("--profile", action="store_true", help="Dump Prefetto trace")
parser.add_argument("--header", action="store_true", help="Print header for first run")

return parser

@@ -131,7 +115,9 @@ def get_args_parser(add_help=True):
args = get_args_parser().parse_args()
result = main(args)
header = ["model", "batch_size", "dtype", "sparsity", "bsr", "sparsity_level", "quantization", "tune_kernel_params", "latency", "img/s"]
result_string = " | ".join(str(_) for _ in [args.model, args.batch_size, args.dtype, args.sparsity, args.bsr, args.sparsity_linear, args.quantization, args.tune_kernel_params, result, 1000/result])
with open("benchmark_results.txt", "w+") as f:
f.write(result_string)
result_string = ",".join(str(_) for _ in [args.model, args.batch_size, args.dtype, args.sparsity, args.bsr, args.sparsity_linear, args.quantization, args.tune_kernel_params, result, 1000/result])
with open("benchmark_results.txt", "a") as f:
if args.header:
f.write(",".join(header)+"\n")
f.write(result_string+"\n")
print(result_string)
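As a note on the `--bsr` flag defined above: because it is declared with `nargs='?'` and `const=256`, it can be omitted, passed bare, or passed with a value. A small standalone sketch (a throwaway parser reusing the same argument definition, not code from the repo) of how the three spellings parse:
```
# Throwaway parser reproducing only the --bsr definition from get_args_parser
# above, to show how argparse handles nargs='?' with const=256 and default=None.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--bsr', type=int, nargs='?', const=256, default=None,
               help='Convert sparsified weights to BSR format with optional block size (default: 256)')

print(p.parse_args([]).bsr)               # None -> no BSR conversion
print(p.parse_args(['--bsr']).bsr)        # 256  -> flag given without a value
print(p.parse_args(['--bsr', '64']).bsr)  # 64   -> explicit block size, as used in benchmark.sh
```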
38 changes: 35 additions & 3 deletions torchao/sparsity/prototype/superblock/benchmark.sh
@@ -1,7 +1,39 @@
MODEL=vit_h_14
BATCH_SIZE=256

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.8 --sp-linear-tile-size 64 --bsr 64 --sparsity bsr
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params

MODEL=vit_b_16
BATCH_SIZE=256

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.8 --sp-linear-tile-size 64 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization

python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
23 changes: 23 additions & 0 deletions torchao/sparsity/prototype/superblock/benchmark_results.txt
@@ -0,0 +1,23 @@
model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s
vit_b_16,256,bfloat16,None,None,0.0,False,False,59.8181787109375,16.717326096341917
vit_b_16,256,bfloat16,None,None,0.0,True,False,60.42583984375,16.5492114397717
vit_b_16,256,bfloat16,semi_structured,None,0.0,False,False,64.60806640625,15.47794347708977
vit_b_16,256,bfloat16,bsr,64,0.8,False,False,47.4581591796875,21.07119233625919
vit_b_16,256,bfloat16,bsr,64,0.84,False,False,45.565,21.94666959288928
vit_b_16,256,bfloat16,bsr,64,0.9,False,False,42.566181640625,23.492828378235455
vit_b_16,256,bfloat16,semi_structured,None,0.0,True,False,60.38123046875,16.561437920970242
vit_b_16,256,bfloat16,bsr,64,0.8,True,False,52.3312890625,19.109026701132585
vit_b_16,256,bfloat16,bsr,64,0.84,True,False,51.0024755859375,19.606891401085672
vit_b_16,256,bfloat16,bsr,64,0.9,True,False,48.94609375,20.430639574787314
vit_b_16,256,bfloat16,bsr,64,0.8,True,True,49.3368310546875,20.26883321491703
vit_b_16,256,bfloat16,bsr,64,0.84,True,True,48.820908203125,20.483027391448456
vit_b_16,256,bfloat16,bsr,64,0.9,True,True,47.728173828125,20.951985374532082
model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s
vit_h_14,256,bfloat16,None,None,0.0,False,False,488.05046875,2.0489684244361257
vit_h_14,256,bfloat16,None,None,0.0,True,False,451.03265625,2.217134360767253
vit_h_14,256,bfloat16,semi_structured,None,0.0,False,False,505.3465234375,1.97884016931142
vit_h_14,256,bfloat16,bsr,64,0.8,False,False,361.4376953125,2.7667285758210616
model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s
vit_h_14,256,bfloat16,None,None,0.0,False,False,491.9237890625,2.032835211945702
vit_h_14,256,bfloat16,None,None,0.0,True,False,454.1960546875,2.2016923962230996
vit_h_14,256,bfloat16,semi_structured,None,0.0,False,False,507.211640625,1.9715635839267664
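Because `benchmark.sh` passes `--header` once per model, the results file above contains repeated header rows. A minimal sketch (hypothetical helper, not part of this commit) for loading the file while skipping those rows:
```
# Hypothetical helper, not part of this commit: read benchmark_results.txt,
# drop the repeated "model,batch_size,..." header rows, and print a summary.
import csv

with open("benchmark_results.txt") as f:
    rows = [r for r in csv.reader(f) if r and r[0] != "model"]

for model, batch, dtype, sparsity, bsr, level, quant, tuned, latency, imgs in rows:
    print(f"{model} sparsity={sparsity} level={level} quant={quant} "
          f"latency={float(latency):.2f} throughput={float(imgs):.2f} img/s")
```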
8 changes: 2 additions & 6 deletions torchao/sparsity/prototype/superblock/blocksparse.py
@@ -5,9 +5,10 @@
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TorchAOBaseTensor
from torchao.quantization.quant_api import _get_linear_subclass_inserter
from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims, bsr_dense_addmm
from torch.sparse._triton_ops import bsr_dense_mm, bsr_dense_addmm, broadcast_batch_dims

aten = torch.ops.aten

# quantization support
@torch.library.custom_op("blocksparse::_int_mm", mutates_args=())
def blocksparse_int_mm(crow_indices: torch.Tensor,
@@ -18,11 +19,6 @@ def blocksparse_int_mm(crow_indices: torch.Tensor,
A: torch.Tensor) -> torch.Tensor:
assert values.dtype == torch.int8
weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))

N = A.shape[-1]

# original_batch_dims_broadcasted = broadcast_batch_dims("_int_bsr_dense_addmm", weight_bsr, A)
# input = torch.zeros(M, N, dtype=torch.int32, device=A.device)
return bsr_dense_mm(weight_bsr, A).t().contiguous()

@torch.library.register_fake("blocksparse::_int_mm")
25 changes: 5 additions & 20 deletions torchao/sparsity/prototype/superblock/utils.py
@@ -33,16 +33,6 @@ def apply_sparsity(model):
if isinstance(module, SupermaskLinear) and "mlp" in name:
module.sparsify_offline()


def verify_sparsity(model):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
total_weights = module.weight.numel()
sparse_weights = (module.weight == 0).sum().item()
sparsity_percentage = (sparse_weights / total_weights) * 100
print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")


# filter functions
def mlp_0_only(mod, name):
return isinstance(mod, torch.nn.Linear) and "mlp.0" in name
@@ -78,7 +68,6 @@ def mlp_only_with_args(
def accelerate_with_sparsity(model, args):
if args.sparsity == "bsr":
apply_sparsity(model)
verify_sparsity(model)
if args.quantization:
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType

@@ -94,10 +83,10 @@ def accelerate_with_sparsity(model, args):
sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only)
elif args.sparsity == "semi_structured":
if args.quantization:
from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType
quantize_(
model, int8_dynamic_activation_int8_semi_sparse_weight(), mlp_0_only
model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), mlp_only
)
sparsify_(model, semi_sparse_weight(), mlp_3_only)
else:
sparsify_(model, semi_sparse_weight(), mlp_only)
else:
@@ -110,11 +99,11 @@ def simulate_sparsity(model, args):
apply_supermask(
model,
linear_sparsity=args.sparsity_linear,
linear_sp_tilesize=args.sp_linear_tile_size,
linear_sp_tilesize=args.bsr,
conv1x1_sparsity=args.sparsity_conv1x1,
conv1x1_sp_tilesize=args.sp_conv1x1_tile_size,
conv1x1_sp_tilesize=args.bsr,
conv_sparsity=args.sparsity_conv,
conv_sp_tilesize=args.sp_conv_tile_size,
conv_sp_tilesize=args.bsr,
skip_last_layer_sparsity=args.skip_last_layer_sparsity,
skip_first_transformer_sparsity=args.skip_first_transformer_sparsity,
device=args.device,
@@ -135,12 +124,8 @@
sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
)
sparsifier.prepare(model, sparse_config)
for line in sparse_config:
print(line)
sparsifier.step()
return sparsifier
else:
print("No sparsity applied!")


### Existing torchvision utils
