From d14e32081e31ce869322ae81e1f15e076416e628 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 14 Aug 2024 19:23:05 -0700 Subject: [PATCH] cleaned up blocksparse --- .../prototype/superblock/blocksparse.py | 67 +++++++++---------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py index f845ce6f8..b57ed5635 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse.py +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -8,6 +8,7 @@ ) aten = torch.ops.aten +# bsr wrapper custom op @torch.library.custom_op("blocksparse::linear", mutates_args=()) def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor: weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) @@ -18,13 +19,13 @@ def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col new_shape = A.shape[:-1] + (bias.shape[0],) return torch.empty(new_shape, dtype=A.dtype, device=A.device) - +# Subclass definition class BlockSparseTensor(torch.Tensor): - bsr_crow_indicies: Optional[torch.Tensor] - bsr_col_indicies: Optional[torch.Tensor] + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] bsr_values: Optional[torch.Tensor] - __slots__ = ["bsr_crow_indicies", "bsr_col_indicies", "bsr_values"] + __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"] implements = classmethod(_implements) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @@ -34,15 +35,15 @@ class BlockSparseTensor(torch.Tensor): def __new__( # noqa: PYI034 cls, shape: torch.Size, - bsr_crow_indicies: Optional[torch.Tensor], - bsr_col_indicies: Optional[torch.Tensor], + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], bsr_values: Optional[torch.Tensor], requires_grad: bool = False, ): - if bsr_values is not None: - previous_tensor = bsr_values - else: + if bsr_values is None: raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values kwargs = { "device": previous_tensor.device, @@ -51,8 +52,8 @@ def __new__( # noqa: PYI034 "requires_grad": requires_grad, } tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - tensor.bsr_crow_indicies = bsr_crow_indicies - tensor.bsr_col_indicies = bsr_col_indicies + tensor.bsr_crow_indices = bsr_crow_indices + tensor.bsr_col_indices = bsr_col_indices tensor.bsr_values = bsr_values return tensor @@ -60,31 +61,26 @@ def __repr__(self) -> str: # type: ignore[override] assert hasattr(self, "shape") return f"{self.__class__.__name__}(shape={self.shape})" - def __tensor_flatten__( - self, - ) -> Tuple[List[str], Tuple[torch.Size, int, bool]]: + def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]: inner_tensors = list( filter(lambda x: getattr(self, x) is not None, self.__slots__) ) - tensor_meta = ( - self.shape, - self.requires_grad, - ) + tensor_meta = (self.shape, self.requires_grad) return inner_tensors, tensor_meta @classmethod def __tensor_unflatten__( cls, inner_tensors, - tensor_meta: Tuple[torch.Size, int, bool], + tensor_meta: Tuple[torch.Size, bool], outer_size, outer_stride, ) -> torch.Tensor: shape, requires_grad = tensor_meta return cls( shape=shape, - bsr_crow_indicies=inner_tensors.get("bsr_crow_indicies", None), - bsr_col_indicies=inner_tensors.get("bsr_col_indicies", None), + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), bsr_values=inner_tensors.get("bsr_values", None), requires_grad=requires_grad, ) @@ -92,26 +88,24 @@ def __tensor_unflatten__( @classmethod def from_dense(cls, dense_tensor, blocksize): bsr_tensor = dense_tensor.to_sparse_bsr(blocksize) - crow_indicies = bsr_tensor.crow_indices() - col_indicies = bsr_tensor.col_indices() - values = bsr_tensor.values() return cls( shape=dense_tensor.shape, - bsr_crow_indicies=crow_indicies, - bsr_col_indicies=col_indicies, - bsr_values=values, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), requires_grad=False, ) def apply_fn_to_shard(self, func): return BlockSparseTensor( shape = self.shape, - bsr_crow_indicies=func(self.bsr_crow_indicies), - bsr_col_indicies=func(self.bsr_col_indicies), + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), requires_grad=self.requires_grad, ) +# Subclass op dispatch registration implements = BlockSparseTensor.implements @implements(aten.detach.default) @@ -123,12 +117,12 @@ def block_sparse_values(func, types, args, kwargs): return args[0].bsr_values.detach() @implements(aten.crow_indices.default) -def block_sparse_crow_indicies(func, types, args, kwargs): - return args[0].bsr_crow_indicies.detach() +def block_sparse_crow_indices(func, types, args, kwargs): + return args[0].bsr_crow_indices.detach() @implements(aten.col_indices.default) def block_sparse_col_indices(func, types, args, kwargs): - return args[0].bsr_col_indicies.detach() + return args[0].bsr_col_indices.detach() @implements(aten._nnz.default) def block_sparse__nnz(func, types, args, kwargs): @@ -137,7 +131,8 @@ def block_sparse__nnz(func, types, args, kwargs): @implements(torch.nn.functional.linear) def block_sparse_linear(func, types, args, kwargs): x, w, bias = args - crow_indicies = w.crow_indices() - col_indices = w.col_indices() - values = w.values() - return torch.ops.blocksparse.linear(x, crow_indicies, col_indices, values, w.shape[0], w.shape[1], bias) + return torch.ops.blocksparse.linear(x, + w.crow_indices(), + w.col_indices(), + w.values(), + w.shape[0], w.shape[1], bias)