Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse marlin 2:4 gemm op #733

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
run_tests,
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.prototype.quant_llm import from_scaled_tc_fpx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
import pytest

if is_fbcode():
Expand Down Expand Up @@ -302,5 +303,119 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
test_utils=test_utils,
)


MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
]
MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
orig_device = w.device
size_k, size_n = w.shape

assert w.is_floating_point(), "w must be float"

if group_size == -1:
group_size = size_k
assert group_size <= size_k

max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2

# Reshape to [groupsize, -1]
if group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))

# Compute scale for each group
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / max_q_val # 2 => symmetric

# Quantize
q_w = torch.round(w / s).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)

# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s

# Restore original shapes
if group_size < size_k:

def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w

q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)

s = s.reshape((-1, size_n)).contiguous()

return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity
w_24, _ = inject_24(b_weight, size_k, size_n)

# Symmetric quantize
w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

# Obtains reference output
output_ref = torch.matmul(a_input, w_24_ref)

# Packs to marlin 2:4
marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04

# Performs opcheck
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
opcheck(
torch.ops.torchao.marlin_24_gemm,
fn_inputs,
test_utils=test_utils,
)


if __name__ == "__main__":
run_tests()
51 changes: 51 additions & 0 deletions torchao/csrc/cuda/sparse_marlin/base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace torchao {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales

} // namespace torchao
Loading
Loading