W4A8 based on CUTLASS #880

Merged · 1 commit · Jan 5, 2025
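This PR adds a CUTLASS-based W4A8 linear kernel, s8s4_linear_cutlass (int8 dynamically quantized activations times packed int4 weights), along with a CutlassInt4PackedLayout that routes the existing int8_dynamic_activation_int4_weight config to it. A minimal sketch of the intended usage, mirroring the generate.py and test changes in the diff below (the model here is only a placeholder module):

# Sketch only; assumes torchao was built with CUDA and the CUTLASS submodule checked out.
import torch
from torchao.dtypes import CutlassInt4PackedLayout
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.quantization.quant_primitives import MappingType

# Placeholder model; any module with nn.Linear layers on a CUDA device should work the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=None,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=CutlassInt4PackedLayout(),
    ),
)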
1 change: 1 addition & 0 deletions .github/workflows/float8_test.yml
@@ -35,6 +35,7 @@ jobs:
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      submodules: recursive
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
1 change: 1 addition & 0 deletions .github/workflows/nightly_smoke_test.yml
@@ -31,6 +31,7 @@ jobs:
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      submodules: recursive
      script: |
        python -m pip install --upgrade pip
        pip install ${{ matrix.torch-spec }}
2 changes: 2 additions & 0 deletions .github/workflows/regression_test.yml
@@ -40,6 +40,7 @@ jobs:
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      submodules: recursive
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
@@ -93,6 +94,7 @@ jobs:
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      submodules: recursive
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
	path = third_party/cutlass
	url = https://github.com/NVIDIA/cutlass
53 changes: 53 additions & 0 deletions benchmarks/benchmark_s8s4_cutlass.py
@@ -0,0 +1,53 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import s8s4_linear_cutlass
from tqdm import tqdm


def get_problem(m, n, k):
    groupsize = k

    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
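    # B stands in for row-wise packed int4 weights: two signed 4-bit values per int8 byte, hence k // 2 columns.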
    B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A_ref, B_ref, A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
    A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
        s8s4_linear_cutlass, A, A_scale, B, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "s8s4_linear_cutlass latency (us)": s8s4_linear_cutlass_time,
        "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
12 changes: 12 additions & 0 deletions setup.py
@@ -93,6 +93,18 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

use_cutlass = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True
this_dir = os.path.abspath(os.path.curdir)
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
if use_cutlass:
extra_compile_args["nvcc"].extend([
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
])

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
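Note: with the block above, the extension build assumes third_party/cutlass is populated. The new .gitmodules entry registers the submodule and the submodules: recursive additions to the CI workflows check it out there; a local build from a fresh clone would presumably need git submodule update --init --recursive first.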
11 changes: 10 additions & 1 deletion test/dtypes/test_affine_quantized.py
@@ -8,7 +8,7 @@
    run_tests,
)

from torchao.dtypes import Int4CPULayout, SemiSparseLayout
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
    float8_weight_only,
    int4_weight_only,
@@ -48,6 +48,15 @@ def get_quantization_functions(
            )
        else:
            base_functions.append(int4_weight_only(group_size=32))
        if device == "cuda":
            base_functions.append(
                int8_dynamic_activation_int4_weight(
                    group_size=None,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=CutlassInt4PackedLayout(),
                )
            )

    if do_sparse:
        base_functions.append(
80 changes: 80 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -0,0 +1,80 @@
import itertools

import torch

import torchao
from torchao.ops import s8s4_linear_cutlass
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import compute_max_diff

import pytest


S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
S8S4_LINEAR_CUTLASS_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
    itertools.product(
        S8S4_LINEAR_CUTLASS_DTYPE,
        S8S4_LINEAR_CUTLASS_BATCH_SIZE,
        S8S4_LINEAR_CUTLASS_SIZE_MNK,
        S8S4_LINEAR_CUTLASS_USE_BIAS,
    )
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
)
def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
    size_m, size_n, size_k = size_mnk

    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

    input_2d = input.view(-1, input.shape[-1])
    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
        input_2d, 8, size_k, dtype
    )
    assert torch.all(input_2d_zeros == 0)
    input_s8 = input_2d_s8.reshape(input.shape)
    input_scales = input_2d_scales.reshape(input.shape[:-1])

    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
        weight, 4, size_n, dtype
    )
    assert torch.all(weight_zeros == 0)
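    # Pack pairs of int4 values into bytes: even columns go to the low nibble, odd columns to the high nibble.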
    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)

    # If torch.nn.functional.linear(input, weight, bias) were used as the
    # reference, the error would be too big.  The calculation below is
    # approximately what the s8s4_linear_cutlass kernel does (except that
    # the matrix multiplication there is carried out over integers).
    size_m_2d = input_2d.shape[0]
    output_ref = (
        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
        * input_2d_scales.view(size_m_2d, 1)
        * weight_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))

    fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
    try:
        output = s8s4_linear_cutlass(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("s8s4_linear_cutlass() op not implemented")

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 5e-3
1 change: 1 addition & 0 deletions third_party/cutlass
Submodule cutlass added at bf9da7
14 changes: 13 additions & 1 deletion torchao/_models/llama/generate.py
@@ -434,6 +434,18 @@ def ffn_or_attn_only(mod, fqn):
                ]
            ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
            quantize_(model, int4_weight_only(group_size=group_size))
        elif "int8adq-int4w-symm" in quantization:
            from torchao.dtypes import CutlassInt4PackedLayout

            quantize_(
                model,
                int8_dynamic_activation_int4_weight(
                    group_size=None,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=CutlassInt4PackedLayout(),
                ),
            )
        if "marlin" in quantization:
            if "qqq" in quantization:
                from torchao.dtypes import MarlinQQQLayout
@@ -1058,7 +1070,7 @@ def callback(x):
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
            + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
        ),
    )
    parser.add_argument(