Add BSR subclass +torch.compile and clean up superblock #680

Merged: 14 commits, Aug 15, 2024
3 changes: 3 additions & 0 deletions torchao/sparsity/prototype/superblock/.gitignore
@@ -1,5 +1,8 @@
*/*.pyc

# Model checkpoints
*.pth

# Editor temporaries
*.swa
*.swb
Empty file.
58 changes: 21 additions & 37 deletions torchao/sparsity/prototype/superblock/benchmark.py
@@ -12,9 +12,12 @@
import torch.utils.data
import utils
from torch import nn
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from supermask import apply_supermask, SupermaskLinear
from blocksparse import BlockSparseTensor
from utils import benchmark_inference


def apply_sparsity(model):
@@ -25,20 +28,12 @@ def apply_sparsity(model):

def apply_bsr(model, blocksize):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
module.weight = torch.nn.Parameter(to_bsr(module.weight.data, blocksize))
print(f"Converted {name} to bsr format.")
except ValueError as e:
print(f"Unable to convert weight of {name} to bsr format: {e}")


def to_bsr(tensor, blocksize):
if tensor.ndim != 2:
raise ValueError("to_bsr expects 2D tensor")
if tensor.size(0) % blocksize or tensor.size(1) % blocksize:
raise ValueError("Tensor dimensions must be divisible by blocksize")
return tensor.to_sparse_bsr(blocksize)
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
module.weight = torch.nn.Parameter(BlockSparseTensor.from_dense(module.weight.data, blocksize))
print(f"Converted {name} to bsr format.")
except ValueError as e:
print(f"Unable to convert weight of {name} to bsr format: {e}")


def verify_sparsity(model):
@@ -49,23 +44,7 @@ def verify_sparsity(model):
sparsity_percentage = (sparse_weights / total_weights) * 100
print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")


def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
for _ in range(warmup):
f(*args, **kwargs)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

for _ in range(iters):
f(*args, **kwargs)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / float(iters)


@torch.inference_mode
def main(args):
print(args)
device = torch.device(args.device)
@@ -83,8 +62,11 @@ def main(args):
print("Using float16")
dtype = torch.float16

# Sample input
# input = torch.rand(32, 3, 224, 224, dtype=dtype).to(device)
if args.bsr and args.tune_kernel_params:
print("Tuning kernel params")
assert args.model == "vit_b_16", "--tune-kernel-params only supported for vit-b-16!"
optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)

print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
@@ -112,7 +94,6 @@ def main(args):
raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.")

model.to(device)
# output0 = model(input)

if args.sparsify_weights:
apply_sparsity(model)
@@ -134,9 +115,11 @@ def main(args):
# output2 = model(input)
# assert torch.allclose(output2, output1), "Output of model before and after changing format to BSR should be equal"

model = torch.compile(model, mode='max-autotune')

image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device)
# model = torch.compile(model, mode='max-autotune')
return benchmark_in_ms(10, 100, model, image)

return benchmark_inference(10, 100, model, image)


def get_args_parser(add_help=True):
@@ -169,11 +152,12 @@ def get_args_parser(add_help=True):
parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16")
parser.add_argument("--float16", action="store_true", help="Use float16")
parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params")

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
result = main(args)
print(f"{result} ms", file=sys.stderr)
print(f"{result:.3f} ms", file=sys.stderr)
138 changes: 138 additions & 0 deletions torchao/sparsity/prototype/superblock/blocksparse.py
@@ -0,0 +1,138 @@
import torch
from typing import Optional, Tuple, List, Dict, Any, Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import (
_implements,
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
)
aten = torch.ops.aten

# bsr wrapper custom op
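# (wrapping the sparse matmul in a custom op keeps it opaque to torch.compile;
# the BSR weight is rebuilt from its flattened components on each call)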
@torch.library.custom_op("blocksparse::linear", mutates_args=())
def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor:
weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
return torch.nn.functional.linear(A, weight_bsr, bias)

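# the fake (meta) registration below gives the op a shape/dtype propagation
# rule, which torch.compile needs in order to trace through it without data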
@torch.library.register_fake("blocksparse::linear")
def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K:int , bias: torch.Tensor) -> torch.Tensor:
new_shape = A.shape[:-1] + (bias.shape[0],)
return torch.empty(new_shape, dtype=A.dtype, device=A.device)

# Subclass definition
class BlockSparseTensor(torch.Tensor):
bsr_crow_indices: Optional[torch.Tensor]
bsr_col_indices: Optional[torch.Tensor]
bsr_values: Optional[torch.Tensor]

__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"]

implements = classmethod(_implements)
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
requires_grad: bool = False,
):
if bsr_values is None:
raise ValueError("bsr values must be provided!")
else:
previous_tensor = bsr_values

kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
tensor.bsr_crow_indices = bsr_crow_indices
tensor.bsr_col_indices = bsr_col_indices
tensor.bsr_values = bsr_values
return tensor

def __repr__(self) -> str: # type: ignore[override]
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"

def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (self.shape, self.requires_grad)
return inner_tensors, tensor_meta

@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, requires_grad = tensor_meta
return cls(
shape=shape,
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
bsr_values=inner_tensors.get("bsr_values", None),
requires_grad=requires_grad,
)

@classmethod
def from_dense(cls, dense_tensor, blocksize):
bsr_tensor = dense_tensor.to_sparse_bsr(blocksize)
return cls(
shape=dense_tensor.shape,
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
requires_grad=False,
)

def apply_fn_to_shard(self, func):
return BlockSparseTensor(
shape = self.shape,
bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices),
bsr_values=func(self.bsr_values),
requires_grad=self.requires_grad,
)

# Subclass op dispatch registration
implements = BlockSparseTensor.implements

@implements(aten.detach.default)
def block_sparse_detach(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0].apply_fn_to_shard(torch.detach))

@implements(aten.values.default)
def block_sparse_values(func, types, args, kwargs):
return args[0].bsr_values.detach()

@implements(aten.crow_indices.default)
def block_sparse_crow_indices(func, types, args, kwargs):
return args[0].bsr_crow_indices.detach()

@implements(aten.col_indices.default)
def block_sparse_col_indices(func, types, args, kwargs):
return args[0].bsr_col_indices.detach()

@implements(aten._nnz.default)
def block_sparse__nnz(func, types, args, kwargs):
return args[0].bsr_values.shape[0]

@implements(torch.nn.functional.linear)
def block_sparse_linear(func, types, args, kwargs):
x, w, bias = args
return torch.ops.blocksparse.linear(x,
w.crow_indices(),
w.col_indices(),
w.values(),
w.shape[0], w.shape[1], bias)
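
Not part of the PR, but a minimal usage sketch of the subclass (mirroring apply_bsr in benchmark.py); it assumes a CUDA build of PyTorch recent enough to ship the Triton BSR linear kernels:

import torch
from blocksparse import BlockSparseTensor

lin = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
# the block size must divide both weight dimensions (1024 % 64 == 0)
lin.weight = torch.nn.Parameter(
    BlockSparseTensor.from_dense(lin.weight.data, 64), requires_grad=False
)
lin = torch.compile(lin, mode="max-autotune")
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y = lin(x)  # routed to torch.ops.blocksparse.linear via __torch_function__

In practice the weights would be sparsified first (as apply_sparsity does); from_dense only reformats the storage, so a dense weight simply stores every block.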
51 changes: 9 additions & 42 deletions torchao/sparsity/prototype/superblock/evaluate.py
@@ -15,40 +15,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from supermask import apply_supermask, SupermaskLinear


def apply_sparsity(model):
for name, module in model.named_modules():
if isinstance(module, SupermaskLinear) and "mlp" in name:
module.sparsify_offline()


def apply_bsr(model):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr))
print(f"Converted {name} to bsr format.")
except ValueError as e:
print(f"Unable to convert weight of {name} to bsr format: {e}")


def to_bsr(tensor, blocksize):
if tensor.ndim != 2:
raise ValueError("to_bsr expects 2D tensor")
if tensor.size(0) % blocksize or tensor.size(1) % blocksize:
raise ValueError("Tensor dimensions must be divisible by blocksize")
return tensor.to_sparse_bsr(blocksize)


def verify_sparsity(model):
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
total_weights = module.weight.numel()
sparse_weights = (module.weight == 0).sum().item()
sparsity_percentage = (sparse_weights / total_weights) * 100
print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")

from benchmark import apply_sparsity, apply_bsr, verify_sparsity

def _get_cache_path(filepath):
h = hashlib.sha1(filepath.encode()).hexdigest()
@@ -82,16 +49,16 @@ def load_data(valdir, args):
)

# for META internal
dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
# for OSS
# dataset_test = torchvision.datasets.ImageNet(
# dataset_test = torchvision.datasets.ImageFolder(
# valdir,
# split='val',
# transform=preprocessing
# preprocessing,
# )
#for OSS
dataset_test = torchvision.datasets.ImageNet(
valdir,
split='val',
transform=preprocessing
)
if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
29 changes: 5 additions & 24 deletions torchao/sparsity/prototype/superblock/supermask.py
@@ -7,7 +7,6 @@
import torch.nn.functional as F
import numpy as np


# original supermask
scores_min=None
scores_max=9e9
@@ -21,34 +20,16 @@
def percentile(t, q):
"""Return the value that is larger than q% of t"""
k = 1 + round(.01 * float(q) * (t.numel() - 1))
return t.view(-1).kthvalue(k).values.item()


def to_bsr(tensor, blocksize=256):
if tensor.ndim != 2:
print("Tensor is not 2D, skipping BSR conversion.")
return tensor

if tensor.size(0) % blocksize or tensor.size(1) % blocksize:
print("Tensor dimensions are not divisible by blocksize, skipping BSR conversion.")
return tensor

try:
converted_tensor = tensor.to_sparse_bsr(blocksize=blocksize)
print(f"Converted tensor to BSR format with blocksize: {blocksize}")
return converted_tensor
except ValueError as e:
print(f"Unable to convert tensor to BSR format: {e}")
return tensor
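# returning the 0-dim tensor instead of calling .item() avoids a device sync
# and a graph break under torch.compile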
return t.view(-1).kthvalue(k).values


class GetSubnet(torch.autograd.Function):
"""Supermask STE function"""
@staticmethod
def forward(ctx, scores, zeros, ones, sparsity):
scores.clamp_(min=scores_min,max=scores_max)
k_val = percentile(scores, sparsity*100)
return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device))
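# clamp out of place: the autograd input is no longer mutated in forward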
clamped_scores = scores.clamp(min=scores_min,max=scores_max)
k_val = percentile(clamped_scores, sparsity*100)
return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device))
@staticmethod
def backward(ctx, g):
return g, None, None, None
@@ -130,7 +111,7 @@ def forward(self, x):
subnet = self.get_mask()
w = (self.weight*self.scale+self.shift) * subnet
else:
w = self.weight.data
w = self.weight
return F.linear(x, w, self.bias)

