add non-square bsr support #2007

Open
wants to merge 1 commit into base: main
34 changes: 34 additions & 0 deletions test/sparsity/test_sparse_api.py
@@ -24,6 +24,8 @@
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
)
from torchao.sparsity.blocksparse import BlockSparseTensor
from torchao.sparsity.utils import create_block_sparse_tensor

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -167,6 +169,38 @@ def test_sparse(self, compile, input_shape):

torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4,
"pytorch 2.4+ feature due to need for custom op support",
)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("compile", [False])
@common_utils.parametrize("blocksize", [(64, 8), (8, 64)])
def test_non_square_sparse(self, compile, blocksize):
input_shape = 1024
input = torch.rand((input_shape, 1024)).half().cuda()
model = (
nn.Sequential(
nn.Linear(1024, 2048),
nn.Linear(2048, 1024),
)
.half()
.cuda()
.eval()
)

M, N = model[0].weight.shape
model[0].weight.data = create_block_sparse_tensor(M, N, blocksize, 0.5, torch.float16)
M, N = model[1].weight.shape
model[1].weight.data = create_block_sparse_tensor(M, N, blocksize, 0.5, torch.float16)
dense_result = model(input)

from torchao.sparsity import block_sparse_weight
sparsify_(model, block_sparse_weight(blocksize=blocksize))
sparse_result = model(input)

torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
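Note for context: a minimal sketch of the non-square layout this test exercises, using only public PyTorch APIs (shapes mirror the test; illustrative, not part of the PR):

import torch

# Convert a dense fp16 weight to BSR with a non-square (64, 8) blocksize.
dense = torch.rand(1024, 2048, dtype=torch.float16)
bsr = dense.to_sparse_bsr(blocksize=(64, 8))

print(bsr.crow_indices().shape)  # (17,): 1024 / 64 row blocks, plus one
print(bsr.values().shape)        # (4096, 64, 8): one (BM, BK) tile per block, BM != BK

Everything downstream (SPLIT_N selection, kernel tiling) now has to cope with values blocks whose two sides differ.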


class TestQuantBlockSparseWeight(common_utils.TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "pytorch 2.6+ feature")
121 changes: 79 additions & 42 deletions torchao/kernel/bsr_triton_ops.py
@@ -30,6 +30,8 @@

AUTOTUNE = os.getenv("BSR_AUTOTUNE", False)

MIN_BLOCK_SIZE = 16


def tune_bsr_dense_addmm(
input,
@@ -64,7 +66,8 @@ def tune_bsr_dense_addmm(
BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]

# Reference parameters is a set of parameters that leads to a
- # successful kernel call and the corresponding timing is used as a
+ # successful kernel call (might not be optimal but works)
+ # and the corresponding timing is used as a
# reference for computing speedups. Avoid changing the reference
# parameters when possible.
reference_meta = dict(
@@ -83,7 +86,7 @@
else:
version_dtype = (dtype, out_dtype)
version = (0, version_dtype, sparsity)
- key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
+ key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1, N % max(N // BM, 1) == 0)

# For tuning, for an initial state, use parameters from the
# database if available, otherwise, use the reference parameters.
@@ -123,7 +126,7 @@ def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=B
is_log = name in {"SPLIT_N", "num_warps"}
min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
max_value = dict(SPLIT_N=max(N // BM, 1)).get(name)
- value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
+ value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=2)[name]
if is_log:
next_value = (
value * value_step**direction
@@ -136,7 +139,7 @@ def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=B
next_value = max(next_value, min_value)
if max_value is not None:
next_value = min(next_value, max_value)
if name == "SPLIT_N" and N % next_value != 0:
if name == "SPLIT_N" and N % (next_value * BM) != 0:
return value
return next_value
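Worked example of the guard above (illustrative numbers): requiring N % (next_value * BM) == 0 keeps BN = N // SPLIT_N a whole multiple of the row blocksize BM, which square-block code got for free. With N = 2048 and BM = 64:

N, BM = 2048, 64
for candidate in (1, 2, 4, 8, 16, 32, 64):
    # Candidates that fail the divisibility guard are skipped (the tuner
    # keeps the previous SPLIT_N instead).
    accepted = N % (candidate * BM) == 0
    print(f"SPLIT_N={candidate:2d} BN={N // candidate:4d} accepted={accepted}")

Candidates 1 through 32 pass (BN runs from 2048 down to 64, all multiples of 64); 64 fails, since BN would drop to 32, below BM.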

@@ -171,8 +174,7 @@ def bsr_dense_addmm_meta(
M,
K,
N,
- Ms,
- Ks,
+ blocksize: tuple[int, int],
beta,
alpha,
SPLIT_N=None,
@@ -194,9 +196,12 @@
out_dtype = dtype
if sparsity is None:
sparsity = 0.5
- if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
+ BM, BK = blocksize
+ # Calculate a default SPLIT_N that ensures BN is valid
+ default_split_n = max(N // BM, 1)
+ if {num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
device_name = torch.cuda.get_device_name()
- key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
+ key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1, N % default_split_n == 0)
if dtype is out_dtype:
version_dtype = dtype
else:
@@ -219,19 +224,17 @@
"bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
)
if meta is None:
- # find approximate meta such that N % SPLIT_N == 0.
+ # If still no meta found, search for approximate matches considering N divisibility
+ approx_key = (*key[:2], "*", *key[3:-1], True)
matching_meta = get_meta(
"bsr_dense_addmm",
- (*key[:2], "*", *key[3:]),
+ approx_key,
device_name,
version=(_version, version_dtype, 0.5),
)
if matching_meta is None and dtype is not out_dtype:
matching_meta = get_meta(
"bsr_dense_addmm",
(*key[:2], "*", *key[3:]),
device_name,
version=(_version, dtype, 0.5),
"bsr_dense_addmm", approx_key, device_name, version=(_version, dtype, 0.5)
)
for mkey in sorted(matching_meta or {}):
meta_ = matching_meta[mkey]
@@ -241,17 +244,18 @@
if N % c == 0 and n <= N:
meta = dict(meta_)
meta["SPLIT_N"] = N // c
break
if meta is not None:
meta.update(**extra)
return meta
else:
warn_once(
"bsr_dense_addmm uses non-optimal triton kernel parameters"
f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}. "
f" for {M=} {K=} {N=} {BM=}, {BK=} {beta=} {alpha=} {dtype=} {out_dtype=}. "
"To find optimal triton kernel parameters, run with BSR_AUTOTUNE=1"
)

- SPLIT_N = SPLIT_N or max(N // Ms, 1)
+ SPLIT_N = SPLIT_N or max(N // BM, 1)
GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4
num_stages = num_stages or 4
num_warps = num_warps or 4
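Restating the fallback path as a standalone sketch (hypothetical shapes mirroring the test; 50% of blocks kept): SPLIT_N is now derived from the row blocksize BM rather than a square Ms, and the sparsity estimate in bsr_dense_addmm below uses the BM * BK block area:

M, K, N = 2048, 1024, 2048
BM, BK = 64, 8                                    # non-square blocksize

SPLIT_N = max(N // BM, 1)                         # default: 32
BN = N // SPLIT_N                                 # 64, equal to BM here
assert N % SPLIT_N == 0                           # mirrors the runtime check below

nnz = (M // BM) * (K // BK) // 2                  # 32 * 128 // 2 = 2048 blocks
sparsity = round(1 - nnz * BM * BK / (M * K), 2)  # 0.5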
@@ -292,7 +296,7 @@ def bsr_dense_addmm(
col_indices = bsr.col_indices()
batch_ndim = crow_indices.dim() - 1
M, K = bsr.shape[batch_ndim : batch_ndim + 2]
- blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3]
+ BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
N = dense.shape[-1]

original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
@@ -309,7 +313,7 @@
return out

if meta is None:
- sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
+ sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2)
if AUTOTUNE:
meta = tune_bsr_dense_addmm(
input,
@@ -330,8 +334,7 @@
M,
K,
N,
- blocksize[0],
- blocksize[1],
+ (BM, BK),
beta,
alpha,
sparsity=sparsity,
@@ -376,18 +379,21 @@ def bsr_dense_addmm(
out,
) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out)

- BM, BK = blocksize
+ BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
SPLIT_N = meta.get("SPLIT_N", max(N // BM, 1))
BN = N // SPLIT_N

if N % SPLIT_N != 0:
raise ValueError(f"N ({N}) must be divisible by SPLIT_N ({SPLIT_N})")

out_untiled = out
out = tile_to_blocksize(out, (BM, BN))
dense = tile_to_blocksize(dense, (BK, BN))
input = tile_to_blocksize(input, (BM, BN))
left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
right_alpha = tile_to_blocksize(right_alpha, (BM, BN))

- # tl.dot supports float16, float32, int32 as accumulator types.
+ # Determine accumulator type based on output dtype
dot_out_dtype = {
torch.float16: tl.float32,
torch.bfloat16: tl.float32,
@@ -553,17 +559,28 @@ def _bsr_strided_addmm_kernel(
# Compute nnz for the row with number row_block_pid.
row_nnz = nnz_offset_next - nnz_offset

- row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
- inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
+ # --- Set up padding for block sizes < MIN_BLOCK_SIZE ---
+ if BLOCKSIZE_ROW < MIN_BLOCK_SIZE:
+ PADDED_BLOCKSIZE_ROW: tl.constexpr = MIN_BLOCK_SIZE
+ else:
+ PADDED_BLOCKSIZE_ROW: tl.constexpr = BLOCKSIZE_ROW

+ if BLOCKSIZE_INNER < MIN_BLOCK_SIZE:
+ PADDED_BLOCKSIZE_INNER: tl.constexpr = MIN_BLOCK_SIZE
+ else:
+ PADDED_BLOCKSIZE_INNER: tl.constexpr = BLOCKSIZE_INNER

- if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0:
- PADDED_BLOCKSIZE_COL: tl.constexpr = 16
+ if BLOCKSIZE_COL < MIN_BLOCK_SIZE or BLOCKSIZE_COL % MIN_BLOCK_SIZE != 0:
+ PADDED_BLOCKSIZE_COL: tl.constexpr = tl.cdiv(BLOCKSIZE_COL, MIN_BLOCK_SIZE) * MIN_BLOCK_SIZE
else:
PADDED_BLOCKSIZE_COL: tl.constexpr = BLOCKSIZE_COL

+ # --- Create block aranges based on padded sizes ---
+ row_block_arange = tl.arange(0, PADDED_BLOCKSIZE_ROW)
+ inner_block_arange = tl.arange(0, PADDED_BLOCKSIZE_INNER)
col_block_arange = tl.arange(0, PADDED_BLOCKSIZE_COL)

- # Pointers are set to the first block of the current row.
+ # --- Initialize pointers ---
values_block_ptrs = (
values_ptr
+ values_batch_stride * batch_pid
@@ -572,8 +589,10 @@ def _bsr_strided_addmm_kernel(
+ values_col_block_stride * inner_block_arange[None, :]
)

- # NOTE: dense is advanced into all dimensions but the tiled row one.
- # That will be advanced in the loop according to values in col_indices.
+ # Mask for loading values (handle row and inner padding)
+ values_load_mask = (row_block_arange[:, None] < BLOCKSIZE_ROW) & \
+ (inner_block_arange[None, :] < BLOCKSIZE_INNER)
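# Illustrative: for a (64, 8) blocksize, BLOCKSIZE_INNER = 8 is padded to
# PADDED_BLOCKSIZE_INNER = 16; inner_block_arange then covers 0..15 and the
# mask above zero-fills lanes 8..15, so the padding contributes nothing to
# the tl.dot accumulation below.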

dense_block_ptrs = (
dense_ptr
+ dense_batch_stride * batch_pid
@@ -582,7 +601,11 @@ def _bsr_strided_addmm_kernel(
+ dense_col_block_stride * col_block_arange[None, :]
)

- # Pointers are set to exact write-to locations
+ # Mask for loading dense (handle inner and col padding)
+ dense_load_mask = (inner_block_arange[:, None] < BLOCKSIZE_INNER) & \
+ (col_block_arange[None, :] < BLOCKSIZE_COL)

+ # Output pointers set to exact write locations for the current block
output_ptrs = (
output_ptr
+ output_batch_stride * batch_pid
@@ -592,6 +615,10 @@ def _bsr_strided_addmm_kernel(
+ output_col_block_stride * col_block_arange[None, :]
)

# Mask for storing output (handle row and col padding)
output_store_mask = (row_block_arange[:, None] < BLOCKSIZE_ROW) & \
(col_block_arange[None, :] < BLOCKSIZE_COL)

# Set pointer to the first nonzero element in the current row
col_index_nnz_ptr = (
col_indices_ptr
@@ -600,20 +627,23 @@ def _bsr_strided_addmm_kernel(
)

output_acc_block = tl.zeros(
- (BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype
+ (PADDED_BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype
)
for _ in range(row_nnz):
- values_block = tl.load(values_block_ptrs)
+ # Load sparse block values with row and inner padding mask
+ values_block = tl.load(values_block_ptrs, mask=values_load_mask, other=0.0)

# find which row of dense needs to get loaded
# for multiplication with values_block.
dense_row_idx = tl.load(col_index_nnz_ptr)
# Load dense block with inner and col padding mask
dense_block = tl.load(
dense_block_ptrs + dense_tiled_row_stride * dense_row_idx,
- mask=col_block_arange[None, :] < BLOCKSIZE_COL,
+ mask=dense_load_mask,
other=0.0,
)

- # do block mm
+ # do block mm: tl.dot inputs now have logical shapes
+ # (PADDED_BLOCKSIZE_ROW, PADDED_BLOCKSIZE_INNER) and
+ # (PADDED_BLOCKSIZE_INNER, PADDED_BLOCKSIZE_COL), satisfying the assertion.
output_acc_block += tl.dot(
values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
)
Expand All @@ -622,10 +652,12 @@ def _bsr_strided_addmm_kernel(
values_block_ptrs += values_nnz_stride
col_index_nnz_ptr += col_indices_stride

# --- Apply alpha and beta scaling ---
if not alpha_is_one:
output_acc_block *= alpha

if not left_alpha_is_one:
left_alpha_load_mask = row_block_arange[:, None] < BLOCKSIZE_ROW
left_alpha_ptrs = (
left_alpha_ptr
+ left_alpha_batch_stride * batch_pid
@@ -634,9 +666,10 @@ def _bsr_strided_addmm_kernel(
+ left_alpha_row_block_stride * row_block_arange[:, None]
+ left_alpha_col_block_stride * col_block_arange[None, :]
)
- output_acc_block *= tl.load(left_alpha_ptrs)
+ output_acc_block *= tl.load(left_alpha_ptrs, mask=left_alpha_load_mask, other=1.0)

if not right_alpha_is_one:
right_alpha_load_mask = col_block_arange[None, :] < BLOCKSIZE_COL
right_alpha_ptrs = (
right_alpha_ptr
+ right_alpha_batch_stride * batch_pid
@@ -645,9 +678,11 @@ def _bsr_strided_addmm_kernel(
+ right_alpha_row_block_stride * row_block_arange[:, None]
+ right_alpha_col_block_stride * col_block_arange[None, :]
)
- output_acc_block *= tl.load(right_alpha_ptrs)
+ output_acc_block *= tl.load(right_alpha_ptrs, mask=right_alpha_load_mask, other=1.0)

if beta_is_nonzero:
input_load_mask = (row_block_arange[:, None] < BLOCKSIZE_ROW) & \
(col_block_arange[None, :] < BLOCKSIZE_COL)
input_ptrs = (
input_ptr
+ input_batch_stride * batch_pid
@@ -656,16 +691,18 @@ def _bsr_strided_addmm_kernel(
+ input_row_block_stride * row_block_arange[:, None]
+ input_col_block_stride * col_block_arange[None, :]
)
input_block = tl.load(input_ptrs, mask=input_load_mask, other=0.0)
if beta_is_one:
- output_acc_block += tl.load(input_ptrs)
+ output_acc_block += input_block
else:
- output_acc_block += beta * tl.load(input_ptrs)
+ output_acc_block += beta * input_block

- # write back the result
+ # --- Write back the result ---
+ # Use the combined row and col padding mask for storing
tl.store(
output_ptrs,
output_acc_block.to(output_ptr.dtype.element_ty),
- mask=col_block_arange[None, :] < BLOCKSIZE_COL,
+ mask=output_store_mask,
)

else:
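Why the masked padding is safe: a short host-side sketch (plain PyTorch, not the Triton kernel; shapes are illustrative) of the same trick, zero-padding the inner dimension of a (64, 8) block up to MIN_BLOCK_SIZE = 16 and checking the product is unchanged:

import torch
import torch.nn.functional as F

BM, BK, BN = 64, 8, 64
values_block = torch.randn(BM, BK)  # one sparse block; inner dim is only 8
dense_block = torch.randn(BK, BN)

# Zero-pad the inner dimension to 16, as the kernel's masked loads do
# (mask=..., other=0.0 for out-of-range lanes).
values_padded = F.pad(values_block, (0, 16 - BK))      # (64, 16)
dense_padded = F.pad(dense_block, (0, 0, 0, 16 - BK))  # (16, 64)

torch.testing.assert_close(values_block @ dense_block,
                           values_padded @ dense_padded)

The zero lanes produce zero partial products, so the padded product matches the unpadded one (up to benign float reassociation, hence assert_close) while satisfying the 16-element minimum that tl.dot needs, which is what MIN_BLOCK_SIZE encodes above.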