Add BSR subclass +torch.compile and clean up superblock (#680)
This PR adds torch.compile support for block sparsity. In a custom op, we create the `sparse_bsr_tensor` from the explicit `crow_indices`, `col_indices`, and `values` tensors that are passed in to the custom op. I also created a tensor subclass that holds these same tensors. At dispatch, when we see a `torch.nn.functional.linear` call, we dispatch into our custom op `torch.ops.blocksparse.linear`, using the tensors stored in the subclass. This will let us add a public API similar to `semi_sparse_weight()`, which I plan to do in a future PR.

This PR also cleans up the superblock prototype implementation, which contained a lot of repeated code, and adds kernel tuning for BSR. For bfloat16 I see the following numbers, a 1.23x gain over the compiled baseline:

```
New compile baseline:       63.431 ms
New compile + bsr:          53.514 ms
New compile + bsr + tuning: 51.485 ms
```
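As a rough sketch of how this would be used (not part of this PR's public API; the import path, `blocksize=64`, and the assumption of an already block-pruned weight are all illustrative), the subclass can replace a linear layer's weight and the model can then be compiled:

```python
import torch
from torchao.sparsity.prototype.superblock.blocksparse import BlockSparseTensor  # path assumed

# A linear layer whose weight we assume has already been pruned to a
# block-sparse pattern (from_dense simply converts it to BSR components).
lin = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)
lin.weight = torch.nn.Parameter(
    BlockSparseTensor.from_dense(lin.weight.detach(), blocksize=64),
    requires_grad=False,
)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
lin_c = torch.compile(lin)
# F.linear on the subclass is intercepted and routed to
# torch.ops.blocksparse.linear, which rebuilds the sparse_bsr_tensor;
# the register_fake entry is what lets torch.compile trace through it.
out = lin_c(x)
```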
Showing 7 changed files with 195 additions and 103 deletions.
`.gitignore`:

```diff
@@ -1,5 +1,8 @@
 */*.pyc
 
+# Model checkpoints
+*.pth
+
 # Editor temporaries
 *.swa
 *.swb
```
Empty file.
New file (138 lines):

```python
import torch
from typing import Optional, Tuple, List, Dict, Any, Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)

aten = torch.ops.aten


# bsr wrapper custom op
@torch.library.custom_op("blocksparse::linear", mutates_args=())
def blocksparse_linear(
    A: torch.Tensor,
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    M: int,
    K: int,
    bias: torch.Tensor,
) -> torch.Tensor:
    weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
    return torch.nn.functional.linear(A, weight_bsr, bias)


@torch.library.register_fake("blocksparse::linear")
def blocksparse_linear_abstract(
    A: torch.Tensor,
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    M: int,
    K: int,
    bias: torch.Tensor,
) -> torch.Tensor:
    new_shape = A.shape[:-1] + (bias.shape[0],)
    return torch.empty(new_shape, dtype=A.dtype, device=A.device)


# Subclass definition
class BlockSparseTensor(torch.Tensor):
    bsr_crow_indices: Optional[torch.Tensor]
    bsr_col_indices: Optional[torch.Tensor]
    bsr_values: Optional[torch.Tensor]

    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"]

    implements = classmethod(_implements)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
    __torch_function__ = classmethod(_dispatch__torch_function__)

    @staticmethod
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        bsr_crow_indices: Optional[torch.Tensor],
        bsr_col_indices: Optional[torch.Tensor],
        bsr_values: Optional[torch.Tensor],
        requires_grad: bool = False,
    ):
        if bsr_values is None:
            raise ValueError("bsr values must be provided!")
        else:
            previous_tensor = bsr_values

        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
        tensor.bsr_crow_indices = bsr_crow_indices
        tensor.bsr_col_indices = bsr_col_indices
        tensor.bsr_values = bsr_values
        return tensor

    def __repr__(self) -> str:  # type: ignore[override]
        assert hasattr(self, "shape")
        return f"{self.__class__.__name__}(shape={self.shape})"

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]:
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (self.shape, self.requires_grad)
        return inner_tensors, tensor_meta

    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        shape, requires_grad = tensor_meta
        return cls(
            shape=shape,
            bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
            bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
            bsr_values=inner_tensors.get("bsr_values", None),
            requires_grad=requires_grad,
        )

    @classmethod
    def from_dense(cls, dense_tensor, blocksize):
        bsr_tensor = dense_tensor.to_sparse_bsr(blocksize)
        return cls(
            shape=dense_tensor.shape,
            bsr_crow_indices=bsr_tensor.crow_indices(),
            bsr_col_indices=bsr_tensor.col_indices(),
            bsr_values=bsr_tensor.values(),
            requires_grad=False,
        )

    def apply_fn_to_shard(self, func):
        return BlockSparseTensor(
            shape=self.shape,
            bsr_crow_indices=func(self.bsr_crow_indices),
            bsr_col_indices=func(self.bsr_col_indices),
            bsr_values=func(self.bsr_values),
            requires_grad=self.requires_grad,
        )


# Subclass op dispatch registration
implements = BlockSparseTensor.implements


@implements(aten.detach.default)
def block_sparse_detach(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0].apply_fn_to_shard(torch.detach)
    )


@implements(aten.values.default)
def block_sparse_values(func, types, args, kwargs):
    return args[0].bsr_values.detach()


@implements(aten.crow_indices.default)
def block_sparse_crow_indices(func, types, args, kwargs):
    return args[0].bsr_crow_indices.detach()


@implements(aten.col_indices.default)
def block_sparse_col_indices(func, types, args, kwargs):
    return args[0].bsr_col_indices.detach()


@implements(aten._nnz.default)
def block_sparse__nnz(func, types, args, kwargs):
    return args[0].bsr_values.shape[0]


@implements(torch.nn.functional.linear)
def block_sparse_linear(func, types, args, kwargs):
    x, w, bias = args
    return torch.ops.blocksparse.linear(
        x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias
    )
```
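For a quick eager-mode sanity check of the dispatch path (a sketch; the shapes, blocksize, and CUDA/bfloat16 setting are assumptions, and tolerances may need loosening for the sparse kernels), the subclass result can be compared against a dense reference:

```python
import torch
# BlockSparseTensor as defined in the file above.

W = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
W[:, 64:] = 0  # crude block-aligned sparsity so the BSR form is actually sparse
w_sub = BlockSparseTensor.from_dense(W, blocksize=64)

x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)
bias = torch.zeros(128, device="cuda", dtype=torch.bfloat16)

# __torch_function__ intercepts F.linear and routes it to the custom op
# with the stored crow_indices / col_indices / values.
y = torch.nn.functional.linear(x, w_sub, bias)
ref = torch.nn.functional.linear(x, W, bias)
torch.testing.assert_close(y, ref)
```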