Skip to content

Commit

Permalink
cleaned up blocksparse
Browse files (browse the repository at this point in the history)
  • Loading branch information
jcaip committed Aug 15, 2024
1 parent c68822f commit d14e320
Showing 1 changed file with 31 additions and 36 deletions.
67 changes: 31 additions & 36 deletions torchao/sparsity/prototype/superblock/blocksparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
aten = torch.ops.aten

# bsr wrapper custom op
@torch.library.custom_op("blocksparse::linear", mutates_args=())
def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor:
weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
Expand All @@ -18,13 +19,13 @@ def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col
new_shape = A.shape[:-1] + (bias.shape[0],)
return torch.empty(new_shape, dtype=A.dtype, device=A.device)


# Subclass definition
class BlockSparseTensor(torch.Tensor):
bsr_crow_indicies: Optional[torch.Tensor]
bsr_col_indicies: Optional[torch.Tensor]
bsr_crow_indices: Optional[torch.Tensor]
bsr_col_indices: Optional[torch.Tensor]
bsr_values: Optional[torch.Tensor]

__slots__ = ["bsr_crow_indicies", "bsr_col_indicies", "bsr_values"]
__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"]

implements = classmethod(_implements)
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
Expand All @@ -34,15 +35,15 @@ class BlockSparseTensor(torch.Tensor):
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
bsr_crow_indicies: Optional[torch.Tensor],
bsr_col_indicies: Optional[torch.Tensor],
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
requires_grad: bool = False,
):
if bsr_values is not None:
previous_tensor = bsr_values
else:
if bsr_values is None:
raise ValueError("bsr values must be provided!")
else:
previous_tensor = bsr_values

kwargs = {
"device": previous_tensor.device,
Expand All @@ -51,67 +52,60 @@ def __new__( # noqa: PYI034
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
tensor.bsr_crow_indicies = bsr_crow_indicies
tensor.bsr_col_indicies = bsr_col_indicies
tensor.bsr_crow_indices = bsr_crow_indices
tensor.bsr_col_indices = bsr_col_indices
tensor.bsr_values = bsr_values
return tensor

# Debug representation: the runtime class name followed by the tensor's shape.
def __repr__(self) -> str: # type: ignore[override]
# Sanity check that the instance exposes a `shape` attribute before formatting it.
assert hasattr(self, "shape")
# Uses self.__class__.__name__ so a subclass prints its own name.
return f"{self.__class__.__name__}(shape={self.shape})"

def __tensor_flatten__(
self,
) -> Tuple[List[str], Tuple[torch.Size, int, bool]]:
def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (
self.shape,
self.requires_grad,
)
tensor_meta = (self.shape, self.requires_grad)
return inner_tensors, tensor_meta

@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, int, bool],
tensor_meta: Tuple[torch.Size, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, requires_grad = tensor_meta
return cls(
shape=shape,
bsr_crow_indicies=inner_tensors.get("bsr_crow_indicies", None),
bsr_col_indicies=inner_tensors.get("bsr_col_indicies", None),
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
bsr_values=inner_tensors.get("bsr_values", None),
requires_grad=requires_grad,
)

@classmethod
def from_dense(cls, dense_tensor, blocksize):
bsr_tensor = dense_tensor.to_sparse_bsr(blocksize)
crow_indicies = bsr_tensor.crow_indices()
col_indicies = bsr_tensor.col_indices()
values = bsr_tensor.values()
return cls(
shape=dense_tensor.shape,
bsr_crow_indicies=crow_indicies,
bsr_col_indicies=col_indicies,
bsr_values=values,
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
requires_grad=False,
)

def apply_fn_to_shard(self, func):
return BlockSparseTensor(
shape = self.shape,
bsr_crow_indicies=func(self.bsr_crow_indicies),
bsr_col_indicies=func(self.bsr_col_indicies),
bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices),
bsr_values=func(self.bsr_values),
requires_grad=self.requires_grad,
)

# Subclass op dispatch registration
implements = BlockSparseTensor.implements

@implements(aten.detach.default)
Expand All @@ -123,12 +117,12 @@ def block_sparse_values(func, types, args, kwargs):
return args[0].bsr_values.detach()

@implements(aten.crow_indices.default)
def block_sparse_crow_indicies(func, types, args, kwargs):
return args[0].bsr_crow_indicies.detach()
def block_sparse_crow_indices(func, types, args, kwargs):
return args[0].bsr_crow_indices.detach()

@implements(aten.col_indices.default)
def block_sparse_col_indices(func, types, args, kwargs):
return args[0].bsr_col_indicies.detach()
return args[0].bsr_col_indices.detach()

@implements(aten._nnz.default)
def block_sparse__nnz(func, types, args, kwargs):
Expand All @@ -137,7 +131,8 @@ def block_sparse__nnz(func, types, args, kwargs):
@implements(torch.nn.functional.linear)
def block_sparse_linear(func, types, args, kwargs):
x, w, bias = args
crow_indicies = w.crow_indices()
col_indices = w.col_indices()
values = w.values()
return torch.ops.blocksparse.linear(x, crow_indicies, col_indices, values, w.shape[0], w.shape[1], bias)
return torch.ops.blocksparse.linear(x,
w.crow_indices(),
w.col_indices(),
w.values(),
w.shape[0], w.shape[1], bias)

0 comments on commit d14e320

Please sign in to comment.