diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index f0dd31fdfc..b115608365 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -24,6 +24,8 @@ TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
 )
+from torchao.sparsity.blocksparse import BlockSparseTensor
+from torchao.sparsity.utils import create_block_sparse_tensor
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
@@ -167,6 +169,38 @@ def test_sparse(self, compile, input_shape):
 
         torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4,
+        "pytorch 2.4+ feature due to need for custom op support",
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [False])
+    @common_utils.parametrize("blocksize", [(64, 8), (8, 64)])
+    def test_non_square_sparse(self, compile, blocksize):
+        input_shape = 1024
+        input = torch.rand((input_shape, 1024)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(1024, 2048),
+                nn.Linear(2048, 1024),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = create_block_sparse_tensor(M, N, blocksize, 0.5, torch.float16)
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, blocksize, 0.5, torch.float16)
+        dense_result = model(input)
+
+        from torchao.sparsity import block_sparse_weight
+        sparsify_(model, block_sparse_weight(blocksize=blocksize))
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
+
 
 class TestQuantBlockSparseWeight(common_utils.TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "pytorch 2.6+ feature")
diff --git a/torchao/kernel/bsr_triton_ops.py b/torchao/kernel/bsr_triton_ops.py
index 18cfba9ad9..a7368ab1b6 100644
--- a/torchao/kernel/bsr_triton_ops.py
+++ b/torchao/kernel/bsr_triton_ops.py
@@ -30,6 +30,8 @@
 
 AUTOTUNE = os.getenv("BSR_AUTOTUNE", False)
 
+MIN_BLOCK_SIZE = 16
+
 
 def tune_bsr_dense_addmm(
     input,
@@ -64,7 +66,8 @@
     BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
 
     # Reference parameters is a set of parameters that leads to a
-    # successful kernel call and the corresponding timing is used as a
+    # successful kernel call (might not be optimal but works)
+    # and the corresponding timing is used as a
     # reference for computing speedups. Avoid changing the reference
     # parameters when possible.
     reference_meta = dict(
@@ -83,7 +86,7 @@
     else:
         version_dtype = (dtype, out_dtype)
     version = (0, version_dtype, sparsity)
-    key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
+    key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1, N % max(N // BM, 1) == 0)
 
     # For tuning, for an initial state, use parameters from the
     # database if available, otherwise, use the reference parameters.
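Aside (illustrative, not part of the patch): the tuning key's new ninth element records whether N divides evenly by the default split count max(N // BM, 1), so cached parameters are not reused across shapes that disagree on divisibility. A minimal Python sketch of the key for a non-square blocksize, using only names from the hunk above with illustrative values:

    M, K, N = 1024, 1024, 2048
    BM, BK = 64, 8  # non-square block: 64 rows x 8 cols
    beta, alpha = 0.0, 1.0
    default_split_n = max(N // BM, 1)  # 32
    key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1, N % default_split_n == 0)
    # -> (1024, 1024, 2048, 64, 8, True, False, True, True)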
@@ -123,7 +126,7 @@ def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=B
         is_log = name in {"SPLIT_N", "num_warps"}
         min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
         max_value = dict(SPLIT_N=max(N // BM, 1)).get(name)
-        value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
+        value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=2)[name]
         if is_log:
             next_value = (
                 value * value_step**direction
@@ -136,7 +139,7 @@ def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=B
             next_value = max(next_value, min_value)
         if max_value is not None:
             next_value = min(next_value, max_value)
-        if name == "SPLIT_N" and N % next_value != 0:
+        if name == "SPLIT_N" and N % (next_value * BM) != 0:
             return value
         return next_value
@@ -171,8 +174,7 @@ def bsr_dense_addmm_meta(
     M,
     K,
     N,
-    Ms,
-    Ks,
+    blocksize: tuple[int, int],
     beta,
     alpha,
     SPLIT_N=None,
@@ -194,9 +196,12 @@
         out_dtype = dtype
     if sparsity is None:
         sparsity = 0.5
-    if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
+    BM, BK = blocksize
+    # Calculate a default SPLIT_N that ensures BN is valid
+    default_split_n = max(N // BM, 1)
+    if {num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
         device_name = torch.cuda.get_device_name()
-        key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
+        key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1, N % default_split_n == 0)
         if dtype is out_dtype:
             version_dtype = dtype
         else:
@@ -219,19 +224,17 @@
                 "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
             )
         if meta is None:
-            # find approximate meta such that N % SPLIT_N == 0.
+            # If still no meta found, search for approximate matches considering N divisibility
+            approx_key = (*key[:2], "*", *key[3:-1], True)
             matching_meta = get_meta(
                 "bsr_dense_addmm",
-                (*key[:2], "*", *key[3:]),
+                approx_key,
                 device_name,
                 version=(_version, version_dtype, 0.5),
             )
             if matching_meta is None and dtype is not out_dtype:
                 matching_meta = get_meta(
-                    "bsr_dense_addmm",
-                    (*key[:2], "*", *key[3:]),
-                    device_name,
-                    version=(_version, dtype, 0.5),
+                    "bsr_dense_addmm", approx_key, device_name, version=(_version, dtype, 0.5)
                 )
             for mkey in sorted(matching_meta or {}):
                 meta_ = matching_meta[mkey]
@@ -241,17 +244,18 @@
                 if N % c == 0 and n <= N:
                     meta = dict(meta_)
                     meta["SPLIT_N"] = N // c
+                    break
         if meta is not None:
             meta.update(**extra)
             return meta
         else:
             warn_once(
                 "bsr_dense_addmm uses non-optimal triton kernel parameters"
-                f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}. "
+                f" for {M=} {K=} {N=} {BM=}, {BK=} {beta=} {alpha=} {dtype=} {out_dtype=}. "
                 "To find optimal triton kernel parameters, run with BSR_AUTOTUNE=1"
             )
-    SPLIT_N = SPLIT_N or max(N // Ms, 1)
+    SPLIT_N = SPLIT_N or max(N // BM, 1)
     GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4
     num_stages = num_stages or 4
     num_warps = num_warps or 4
@@ -292,7 +296,7 @@ def bsr_dense_addmm(
     col_indices = bsr.col_indices()
     batch_ndim = crow_indices.dim() - 1
     M, K = bsr.shape[batch_ndim : batch_ndim + 2]
-    blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3]
+    BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
     N = dense.shape[-1]
 
     original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
@@ -309,7 +313,7 @@
         return out
 
     if meta is None:
-        sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
+        sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2)
         if AUTOTUNE:
             meta = tune_bsr_dense_addmm(
                 input,
@@ -330,8 +334,7 @@
                 M,
                 K,
                 N,
-                blocksize[0],
-                blocksize[1],
+                (BM, BK),
                 beta,
                 alpha,
                 sparsity=sparsity,
@@ -376,10 +379,13 @@
         out,
     ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out)
 
-    BM, BK = blocksize
+    BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
     SPLIT_N = meta.get("SPLIT_N", max(N // BM, 1))
     BN = N // SPLIT_N
 
+    if N % SPLIT_N != 0:
+        raise ValueError(f"N ({N}) must be divisible by SPLIT_N ({SPLIT_N})")
+
     out_untiled = out
     out = tile_to_blocksize(out, (BM, BN))
     dense = tile_to_blocksize(dense, (BK, BN))
@@ -387,7 +393,7 @@
     left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
     right_alpha = tile_to_blocksize(right_alpha, (BM, BN))
 
-    # tl.dot supports float16, float32, int32 as accumulator types.
+    # Determine accumulator type based on output dtype
    dot_out_dtype = {
         torch.float16: tl.float32,
         torch.bfloat16: tl.float32,
@@ -553,17 +559,28 @@ def _bsr_strided_addmm_kernel(
         # Compute nnz for the row with number row_block_pid.
         row_nnz = nnz_offset_next - nnz_offset
 
-        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
-        inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
+        # --- Set up padding for block sizes < MIN_BLOCK_SIZE ---
+        if BLOCKSIZE_ROW < MIN_BLOCK_SIZE:
+            PADDED_BLOCKSIZE_ROW: tl.constexpr = MIN_BLOCK_SIZE
+        else:
+            PADDED_BLOCKSIZE_ROW: tl.constexpr = BLOCKSIZE_ROW
+
+        if BLOCKSIZE_INNER < MIN_BLOCK_SIZE:
+            PADDED_BLOCKSIZE_INNER: tl.constexpr = MIN_BLOCK_SIZE
+        else:
+            PADDED_BLOCKSIZE_INNER: tl.constexpr = BLOCKSIZE_INNER
 
-        if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0:
-            PADDED_BLOCKSIZE_COL: tl.constexpr = 16
+        if BLOCKSIZE_COL < MIN_BLOCK_SIZE or BLOCKSIZE_COL % MIN_BLOCK_SIZE != 0:
+            PADDED_BLOCKSIZE_COL: tl.constexpr = tl.cdiv(BLOCKSIZE_COL, MIN_BLOCK_SIZE) * MIN_BLOCK_SIZE
         else:
             PADDED_BLOCKSIZE_COL: tl.constexpr = BLOCKSIZE_COL
 
+        # --- Create block aranges based on padded sizes ---
+        row_block_arange = tl.arange(0, PADDED_BLOCKSIZE_ROW)
+        inner_block_arange = tl.arange(0, PADDED_BLOCKSIZE_INNER)
         col_block_arange = tl.arange(0, PADDED_BLOCKSIZE_COL)
 
-        # Pointers are set to the first block of the current row.
+        # --- Initialize pointers ---
         values_block_ptrs = (
             values_ptr
             + values_batch_stride * batch_pid
@@ -572,8 +589,10 @@
             + values_col_block_stride * inner_block_arange[None, :]
         )
 
-        # NOTE: dense is advanced into all dimensions but the tiled row one.
-        # That will be advanced in the loop according to values in col_indices.
+        # Mask for loading values (handle row and inner padding)
+        values_load_mask = (row_block_arange[:, None] < BLOCKSIZE_ROW) & \
+            (inner_block_arange[None, :] < BLOCKSIZE_INNER)
+
         dense_block_ptrs = (
             dense_ptr
             + dense_batch_stride * batch_pid
@@ -582,7 +601,11 @@
             + dense_col_block_stride * col_block_arange[None, :]
         )
 
-        # Pointers are set to exact write-to locations
+        # Mask for loading dense (handle inner and col padding)
+        dense_load_mask = (inner_block_arange[:, None] < BLOCKSIZE_INNER) & \
+            (col_block_arange[None, :] < BLOCKSIZE_COL)
+
+        # Output pointers set to exact write locations for the current block
         output_ptrs = (
             output_ptr
             + output_batch_stride * batch_pid
@@ -592,6 +615,10 @@
             + output_col_block_stride * col_block_arange[None, :]
         )
 
+        # Mask for storing output (handle row and col padding)
+        output_store_mask = (row_block_arange[:, None] < BLOCKSIZE_ROW) & \
+            (col_block_arange[None, :] < BLOCKSIZE_COL)
+
         # Set pointer to the first nonzero element in the current row
         col_index_nnz_ptr = (
             col_indices_ptr
@@ -600,20 +627,23 @@
         )
 
         output_acc_block = tl.zeros(
-            (BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype
+            (PADDED_BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype
         )
 
         for _ in range(row_nnz):
-            values_block = tl.load(values_block_ptrs)
+            # Load sparse block values with row and inner padding mask
+            values_block = tl.load(values_block_ptrs, mask=values_load_mask, other=0.0)
-            # find which row of dense needs to get loaded
-            # for multiplication with values_block.
             dense_row_idx = tl.load(col_index_nnz_ptr)
+            # Load dense block with inner and col padding mask
             dense_block = tl.load(
                 dense_block_ptrs + dense_tiled_row_stride * dense_row_idx,
-                mask=col_block_arange[None, :] < BLOCKSIZE_COL,
+                mask=dense_load_mask,
+                other=0.0,
             )
 
-            # do block mm
+            # do block mm: tl.dot inputs now have logical shapes
+            # (PADDED_BLOCKSIZE_ROW, PADDED_BLOCKSIZE_INNER) and
+            # (PADDED_BLOCKSIZE_INNER, PADDED_BLOCKSIZE_COL), satisfying the assertion.
             output_acc_block += tl.dot(
                 values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
             )
@@ -622,10 +652,12 @@
             values_block_ptrs += values_nnz_stride
             col_index_nnz_ptr += col_indices_stride
 
+        # --- Apply alpha and beta scaling ---
         if not alpha_is_one:
             output_acc_block *= alpha
 
         if not left_alpha_is_one:
+            left_alpha_load_mask = row_block_arange[:, None] < BLOCKSIZE_ROW
             left_alpha_ptrs = (
                 left_alpha_ptr
                 + left_alpha_batch_stride * batch_pid
@@ -634,9 +666,10 @@
                 + left_alpha_row_block_stride * row_block_arange[:, None]
                 + left_alpha_col_block_stride * col_block_arange[None, :]
             )
-            output_acc_block *= tl.load(left_alpha_ptrs)
+            output_acc_block *= tl.load(left_alpha_ptrs, mask=left_alpha_load_mask, other=1.0)
 
         if not right_alpha_is_one:
+            right_alpha_load_mask = col_block_arange[None, :] < BLOCKSIZE_COL
             right_alpha_ptrs = (
                 right_alpha_ptr
                 + right_alpha_batch_stride * batch_pid
@@ -645,9 +678,11 @@
                 + right_alpha_row_block_stride * row_block_arange[:, None]
                 + right_alpha_col_block_stride * col_block_arange[None, :]
             )
-            output_acc_block *= tl.load(right_alpha_ptrs)
+            output_acc_block *= tl.load(right_alpha_ptrs, mask=right_alpha_load_mask, other=1.0)
 
         if beta_is_nonzero:
+            input_load_mask = (row_block_arange[:, None] < BLOCKSIZE_ROW) & \
+                (col_block_arange[None, :] < BLOCKSIZE_COL)
             input_ptrs = (
                 input_ptr
                 + input_batch_stride * batch_pid
@@ -656,16 +691,18 @@
                 + input_row_block_stride * row_block_arange[:, None]
                 + input_col_block_stride * col_block_arange[None, :]
             )
+            input_block = tl.load(input_ptrs, mask=input_load_mask, other=0.0)
             if beta_is_one:
-                output_acc_block += tl.load(input_ptrs)
+                output_acc_block += input_block
             else:
-                output_acc_block += beta * tl.load(input_ptrs)
+                output_acc_block += beta * input_block
 
-        # write back the result
+        # --- Write back the result ---
+        # Use the combined row and col padding mask for storing
         tl.store(
             output_ptrs,
             output_acc_block.to(output_ptr.dtype.element_ty),
-            mask=col_block_arange[None, :] < BLOCKSIZE_COL,
+            mask=output_store_mask,
         )
 
     else:
diff --git a/torchao/sparsity/blocksparse.py b/torchao/sparsity/blocksparse.py
index 6a33736130..b694a3a457 100644
--- a/torchao/sparsity/blocksparse.py
+++ b/torchao/sparsity/blocksparse.py
@@ -129,7 +129,7 @@ class BlockSparseTensor(TorchAOBaseTensor):
     bsr_crow_indices: Optional[torch.Tensor]
     bsr_col_indices: Optional[torch.Tensor]
     bsr_values: Optional[torch.Tensor]
-    blocksize: int
+    blocksize: Tuple[int, int]
 
     __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"]
 
@@ -137,7 +137,7 @@
     def __new__(  # noqa: PYI034
         cls,
         shape: torch.Size,
-        blocksize: int,
+        blocksize: Tuple[int, int],
         bsr_crow_indices: Optional[torch.Tensor],
         bsr_col_indices: Optional[torch.Tensor],
         bsr_values: Optional[torch.Tensor],
@@ -165,7 +165,7 @@ def __new__(  # noqa: PYI034
 
     def __repr__(self) -> str:  # type: ignore[override]
         assert hasattr(self, "shape")
-        return f"{self.__class__.__name__}(shape={self.shape})"
+        return f"{self.__class__.__name__}(shape={self.shape}, blocksize={self.blocksize})"
 
     def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool, int]]:
         inner_tensors = list(
@@ -178,7 +178,7 @@
     def __tensor_unflatten__(
         cls,
         inner_tensors,
-        tensor_meta: Tuple[torch.Size, bool, int],
+        tensor_meta: Tuple[torch.Size, bool, Tuple[int, int]],
         outer_size,
         outer_stride,
     ) -> torch.Tensor:
@@ -259,7 +259,8 @@ def my_mul(bsr, t):
     assert t.dim() == 3
     assert not bsr.requires_grad
     assert t.size(0) == 1
-    t_blocked = t.view(t.size(0), t.size(1) // bsr.blocksize, bsr.blocksize, 1)
+    BM, BK = bsr.blocksize
+    t_blocked = t.view(t.size(0), t.size(1) // BM, BM, 1)
     masked_t = t_blocked.transpose(0, 1).index_select(0, bsr.col_indices())
     new_values = bsr.values() * masked_t
     return BlockSparseTensor(
@@ -307,6 +308,7 @@ def block_sparse_linear(func, types, args, kwargs):
     x = x_orig.reshape(-1, x_orig.size(-1)).t()
     M = w.shape[0]
     K = w.shape[1]
+    BM, BK = w.blocksize
     out = torch.ops.blocksparse.addmm(
         x,
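End to end, the non-square path added above is exercised the same way the new test does it. A minimal usage sketch, assuming (as the test does) that sparsify_ and block_sparse_weight are importable from torchao.sparsity; the shapes and the (8, 64) blocksize are illustrative:

    import torch
    from torch import nn
    from torchao.sparsity import block_sparse_weight, sparsify_
    from torchao.sparsity.utils import create_block_sparse_tensor

    blocksize = (8, 64)  # rectangular: 8-row by 64-column blocks

    model = nn.Sequential(nn.Linear(1024, 2048), nn.Linear(2048, 1024)).half().cuda().eval()

    # Install a 50%-sparse block pattern so the dense and sparse outputs match.
    for mod in model:
        M, N = mod.weight.shape
        mod.weight.data = create_block_sparse_tensor(M, N, blocksize, 0.5, torch.float16)

    x = torch.rand(1024, 1024).half().cuda()
    dense_out = model(x)

    # Convert the dense weights to BlockSparseTensor with the same blocksize.
    sparsify_(model, block_sparse_weight(blocksize=blocksize))
    sparse_out = model(x)

    torch.testing.assert_close(dense_out, sparse_out, rtol=1e-2, atol=1e-2)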