Add BSR subclass +torch.compile and clean up superblock (#680)
This PR adds torch.compile support for block sparsity.

In a custom op, we create the `sparse_bsr_tensor` from the explicit `crow_indices`, `col_indices`, and `values` tensors that are passed in.
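
For reference, here is a minimal sketch (not code from this PR) of the round-trip the op relies on, using only public PyTorch sparse BSR APIs; the shapes and blocksize are illustrative:

```
import torch

# Decompose a BSR weight into the three plain tensors that get passed to
# torch.ops.blocksparse.linear, then rebuild the BSR tensor inside the op.
M, K, blocksize = 128, 256, 64  # illustrative shapes
dense = torch.randn(M, K)
bsr = dense.to_sparse_bsr(blocksize)

crow_indices = bsr.crow_indices()
col_indices = bsr.col_indices()
values = bsr.values()

weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
assert torch.equal(weight_bsr.to_dense(), dense)
```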

I also created a tensor subclass that holds these same component tensors.

At dispatch time, when we see a `torch.nn.functional.linear` call, we dispatch into our custom op `torch.ops.blocksparse.linear`, using the tensors stored in the subclass.
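
A usage sketch of that flow, assuming a CUDA device with the triton BSR kernels available and the prototype `blocksparse` module importable (this mirrors what `apply_bsr` in `benchmark.py` now does; shapes, dtype, and blocksize are illustrative):

```
import torch
from blocksparse import BlockSparseTensor  # prototype module added in this PR

# Swap a Linear weight for the subclass; __torch_function__ then reroutes
# F.linear to torch.ops.blocksparse.linear using the stored BSR components.
lin = torch.nn.Linear(256, 128, dtype=torch.bfloat16, device="cuda")
lin.weight = torch.nn.Parameter(BlockSparseTensor.from_dense(lin.weight.data, 64))

x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
with torch.inference_mode():
    y = lin(x)                   # eager: goes through the custom op
    y_c = torch.compile(lin)(x)  # intended to trace via the op's fake registration
```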

This will allow us to add a public API similar to `semi_sparse_weight()`, which I plan to do in a future PR. 

This PR also cleans up the superblock prototype implementation, which had a lot of repeated code, and adds kernel tuning for BSR.
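
The tuning path pre-tunes the triton `bsr_dense_addmm` kernel for the two ViT-B/16 MLP GEMM shapes before benchmarking (the new `--tune-kernel-params` flag in `benchmark.py`). A sketch of that call, assuming a CUDA device; the blocksize, sparsity, and dtype values here are illustrative, not from a specific run:

```
import torch
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm

blocksize, sparsity, dtype = 64, 0.8, torch.bfloat16

# Tune both MLP GEMM shapes used by vit_b_16, matching the calls added in benchmark.py.
optimize_bsr_dense_addmm(3072, 768, 50432, blocksize, blocksize,
                         dtype=dtype, sparsity=sparsity, verbose=True)
optimize_bsr_dense_addmm(768, 3072, 50432, blocksize, blocksize,
                         dtype=dtype, sparsity=sparsity, verbose=True)
```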

For bfloat16 I see the following numbers, a 1.23x speedup over the compiled baseline:
```
New compile baseline: 63.431 ms
New compile + bsr: 53.514 ms 
New compile + bsr + tuning: 51.485 ms 
```
jcaip authored Aug 15, 2024
1 parent 5998389 commit b16f0dc
Showing 7 changed files with 195 additions and 103 deletions.
3 changes: 3 additions & 0 deletions torchao/sparsity/prototype/superblock/.gitignore
@@ -1,5 +1,8 @@
*/*.pyc

# Model checkpoints
*.pth

# Editor temporaries
*.swa
*.swb
Empty file.
58 changes: 21 additions & 37 deletions torchao/sparsity/prototype/superblock/benchmark.py
@@ -12,9 +12,12 @@
import torch.utils.data
import utils
from torch import nn
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from supermask import apply_supermask, SupermaskLinear
from blocksparse import BlockSparseTensor
from utils import benchmark_inference


def apply_sparsity(model):
@@ -25,20 +28,12 @@ def apply_sparsity(model):

def apply_bsr(model, blocksize):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
module.weight = torch.nn.Parameter(to_bsr(module.weight.data, blocksize))
print(f"Converted {name} to bsr format.")
except ValueError as e:
print(f"Unable to convert weight of {name} to bsr format: {e}")


def to_bsr(tensor, blocksize):
if tensor.ndim != 2:
raise ValueError("to_bsr expects 2D tensor")
if tensor.size(0) % blocksize or tensor.size(1) % blocksize:
raise ValueError("Tensor dimensions must be divisible by blocksize")
return tensor.to_sparse_bsr(blocksize)
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
module.weight = torch.nn.Parameter(BlockSparseTensor.from_dense(module.weight.data, blocksize))
print(f"Converted {name} to bsr format.")
except ValueError as e:
print(f"Unable to convert weight of {name} to bsr format: {e}")


def verify_sparsity(model):
@@ -49,23 +44,7 @@ def verify_sparsity(model):
sparsity_percentage = (sparse_weights / total_weights) * 100
print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")


def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
for _ in range(warmup):
f(*args, **kwargs)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

for _ in range(iters):
f(*args, **kwargs)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / float(iters)


@torch.inference_mode
def main(args):
print(args)
device = torch.device(args.device)
@@ -83,8 +62,11 @@ def main(args):
print("Using float16")
dtype = torch.float16

# Sample input
# input = torch.rand(32, 3, 224, 224, dtype=dtype).to(device)
if args.bsr and args.tune_kernel_params:
print("Tuning kernel params")
assert args.model == "vit_b_16", "--tune-kernel-params only supported for vit-b-16!"
optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)

print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
@@ -112,7 +94,6 @@ def main(args):
raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.")

model.to(device)
# output0 = model(input)

if args.sparsify_weights:
apply_sparsity(model)
@@ -134,9 +115,11 @@ def main(args):
# output2 = model(input)
# assert torch.allclose(output2, output1), "Output of model before and after changing format to BSR should be equal"

model = torch.compile(model, mode='max-autotune')

image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device)
# model = torch.compile(model, mode='max-autotune')
return benchmark_in_ms(10, 100, model, image)

return benchmark_inference(10, 100, model, image)


def get_args_parser(add_help=True):
@@ -169,11 +152,12 @@ def get_args_parser(add_help=True):
parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16")
parser.add_argument("--float16", action="store_true", help="Use float16")
parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params")

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
result = main(args)
print(f"{result} ms", file=sys.stderr)
print(f"{result:.3f} ms", file=sys.stderr)
138 changes: 138 additions & 0 deletions torchao/sparsity/prototype/superblock/blocksparse.py
@@ -0,0 +1,138 @@
import torch
from typing import Optional, Tuple, List, Dict, Any, Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import (
_implements,
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
)
aten = torch.ops.aten

# bsr wrapper custom op
@torch.library.custom_op("blocksparse::linear", mutates_args=())
def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor:
weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
return torch.nn.functional.linear(A, weight_bsr, bias)

@torch.library.register_fake("blocksparse::linear")
def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K:int , bias: torch.Tensor) -> torch.Tensor:
new_shape = A.shape[:-1] + (bias.shape[0],)
return torch.empty(new_shape, dtype=A.dtype, device=A.device)

# Subclass definition
class BlockSparseTensor(torch.Tensor):
bsr_crow_indices: Optional[torch.Tensor]
bsr_col_indices: Optional[torch.Tensor]
bsr_values: Optional[torch.Tensor]

__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"]

implements = classmethod(_implements)
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
requires_grad: bool = False,
):
if bsr_values is None:
raise ValueError("bsr values must be provided!")
else:
previous_tensor = bsr_values

kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
tensor.bsr_crow_indices = bsr_crow_indices
tensor.bsr_col_indices = bsr_col_indices
tensor.bsr_values = bsr_values
return tensor

def __repr__(self) -> str: # type: ignore[override]
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"

def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (self.shape, self.requires_grad)
return inner_tensors, tensor_meta

@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, requires_grad = tensor_meta
return cls(
shape=shape,
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
bsr_values=inner_tensors.get("bsr_values", None),
requires_grad=requires_grad,
)

@classmethod
def from_dense(cls, dense_tensor, blocksize):
bsr_tensor = dense_tensor.to_sparse_bsr(blocksize)
return cls(
shape=dense_tensor.shape,
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
requires_grad=False,
)

def apply_fn_to_shard(self, func):
return BlockSparseTensor(
shape = self.shape,
bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices),
bsr_values=func(self.bsr_values),
requires_grad=self.requires_grad,
)

# Subclass op dispatch registration
implements = BlockSparseTensor.implements

@implements(aten.detach.default)
def block_sparse_detach(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0].apply_fn_to_shard(torch.detach))

@implements(aten.values.default)
def block_sparse_values(func, types, args, kwargs):
return args[0].bsr_values.detach()

@implements(aten.crow_indices.default)
def block_sparse_crow_indices(func, types, args, kwargs):
return args[0].bsr_crow_indices.detach()

@implements(aten.col_indices.default)
def block_sparse_col_indices(func, types, args, kwargs):
return args[0].bsr_col_indices.detach()

@implements(aten._nnz.default)
def block_sparse__nnz(func, types, args, kwargs):
return args[0].bsr_values.shape[0]

@implements(torch.nn.functional.linear)
def block_sparse_linear(func, types, args, kwargs):
x, w, bias = args
return torch.ops.blocksparse.linear(x,
w.crow_indices(),
w.col_indices(),
w.values(),
w.shape[0], w.shape[1], bias)
51 changes: 9 additions & 42 deletions torchao/sparsity/prototype/superblock/evaluate.py
@@ -15,40 +15,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from supermask import apply_supermask, SupermaskLinear


def apply_sparsity(model):
for name, module in model.named_modules():
if isinstance(module, SupermaskLinear) and "mlp" in name:
module.sparsify_offline()


def apply_bsr(model):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr))
print(f"Converted {name} to bsr format.")
except ValueError as e:
print(f"Unable to convert weight of {name} to bsr format: {e}")


def to_bsr(tensor, blocksize):
if tensor.ndim != 2:
raise ValueError("to_bsr expects 2D tensor")
if tensor.size(0) % blocksize or tensor.size(1) % blocksize:
raise ValueError("Tensor dimensions must be divisible by blocksize")
return tensor.to_sparse_bsr(blocksize)


def verify_sparsity(model):
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
total_weights = module.weight.numel()
sparse_weights = (module.weight == 0).sum().item()
sparsity_percentage = (sparse_weights / total_weights) * 100
print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")

from benchmark import apply_sparsity, apply_bsr, verify_sparsity

def _get_cache_path(filepath):
h = hashlib.sha1(filepath.encode()).hexdigest()
@@ -82,16 +49,16 @@ def load_data(valdir, args):
)

# for META internal
dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
# for OSS
# dataset_test = torchvision.datasets.ImageNet(
# dataset_test = torchvision.datasets.ImageFolder(
# valdir,
# split='val',
# transform=preprocessing
# preprocessing,
# )
#for OSS
dataset_test = torchvision.datasets.ImageNet(
valdir,
split='val',
transform=preprocessing
)
if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
29 changes: 5 additions & 24 deletions torchao/sparsity/prototype/superblock/supermask.py
@@ -7,7 +7,6 @@
import torch.nn.functional as F
import numpy as np


# original supermask
scores_min=None
scores_max=9e9
@@ -21,34 +20,16 @@
def percentile(t, q):
"""Return the value that is larger than q% of t"""
k = 1 + round(.01 * float(q) * (t.numel() - 1))
return t.view(-1).kthvalue(k).values.item()


def to_bsr(tensor, blocksize=256):
if tensor.ndim != 2:
print("Tensor is not 2D, skipping BSR conversion.")
return tensor

if tensor.size(0) % blocksize or tensor.size(1) % blocksize:
print("Tensor dimensions are not divisible by blocksize, skipping BSR conversion.")
return tensor

try:
converted_tensor = tensor.to_sparse_bsr(blocksize=blocksize)
print(f"Converted tensor to BSR format with blocksize: {blocksize}")
return converted_tensor
except ValueError as e:
print(f"Unable to convert tensor to BSR format: {e}")
return tensor
return t.view(-1).kthvalue(k).values


class GetSubnet(torch.autograd.Function):
"""Supermask STE function"""
@staticmethod
def forward(ctx, scores, zeros, ones, sparsity):
scores.clamp_(min=scores_min,max=scores_max)
k_val = percentile(scores, sparsity*100)
return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device))
clamped_scores = scores.clamp(min=scores_min,max=scores_max)
k_val = percentile(clamped_scores, sparsity*100)
return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device))
@staticmethod
def backward(ctx, g):
return g, None, None, None
@@ -130,7 +111,7 @@ def forward(self, x):
subnet = self.get_mask()
w = (self.weight*self.scale+self.shift) * subnet
else:
w = self.weight.data
w = self.weight
return F.linear(x, w, self.bias)

