W4A8 based on CUTLASS
alexsamardzic committed Oct 18, 2024
parent 3475aed · commit 492a5fa
Showing 9 changed files with 700 additions and 9 deletions.
5 changes: 5 additions & 0 deletions setup.py
@@ -65,6 +65,10 @@ def get_extensions():
     extension = CUDAExtension if use_cuda else CppExtension
 
     if not IS_WINDOWS:
+        import cutlass_library
+        cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
+        cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
+
         extra_link_args = []
         extra_compile_args = {
             "cxx": [
@@ -74,6 +78,7 @@ def get_extensions():
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
"-I" + cutlass_include_dir,
]
}

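
The build wiring above resolves the CUTLASS headers from the installed cutlass_library Python package instead of a vendored checkout. A minimal standalone sketch of the same lookup (assuming cutlass_library is importable, e.g. from the nvidia-cutlass PyPI package; that package name is an assumption, not stated in the diff):

import os

import cutlass_library  # ships the CUTLASS sources as package data

# Same derivation as in the setup.py hunk above: the headers live under
# <package_dir>/source/include and are handed to nvcc via -I.
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
print(cutlass_include_dir)  # e.g. .../site-packages/cutlass_library/source/include
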
60 changes: 60 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -0,0 +1,60 @@
# FIXME: move this test to the appropriate test file, and make sure it passes!!!

import copy

from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

import pytest


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 256)
        self.linear2 = torch.nn.Linear(256, 128, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        return x


class TestS8S4LinearCUTLASS(TestCase):
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_s8s4_linear_cutlass_(self):
        # FIXME: remove this!
        torch.manual_seed(0)

        dtype = torch.float16  # torch.bfloat16

        input = torch.rand((64, 128)).to(dtype).cuda()
        model = ToyModel().to(dtype).cuda()

        output_ref = model(input)

        modelq = copy.deepcopy(model)
        quantize_(
            modelq,
            int8_dynamic_activation_int4_weight(
                group_size=None,
                mapping_type=MappingType.SYMMETRIC,
                input_mapping_type=MappingType.SYMMETRIC,
                pack_bits=True,
            )
        )
        output = modelq(input)

        assert torch.allclose(output, output_ref, rtol=1e-1, atol=0)


if __name__ == "__main__":
    run_tests()
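
In the config above, pack_bits=True asks for the quantized 4-bit weights to be stored two per int8 byte, matching the s4 operand of the s8s4 kernel the test exercises. A rough illustration of such nibble packing follows; pack_int4_pairs is a hypothetical helper written for this note, not the packing code or byte layout the kernel actually uses:

import torch

def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4 holds signed 4-bit values in [-8, 7], one per int8 element.
    # Pack adjacent pairs along the last dim into single bytes, halving storage.
    assert w_int4.dtype == torch.int8 and w_int4.shape[-1] % 2 == 0
    lo = (w_int4[..., 0::2] & 0xF).to(torch.uint8)  # low nibble
    hi = (w_int4[..., 1::2] & 0xF).to(torch.uint8)  # high nibble
    return ((hi << 4) | lo).view(torch.int8)  # reinterpret packed bytes as int8

w = torch.randint(-8, 8, (128, 256), dtype=torch.int8)
assert pack_int4_pairs(w).shape == (128, 128)
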
14 changes: 13 additions & 1 deletion torchao/_models/llama/generate.py
@@ -13,6 +13,7 @@
 import torchao
 import torch._dynamo.config
 import torch._inductor.config
+from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import get_model_size_in_bytes
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
@@ -209,6 +210,7 @@ def main(
         int8_weight_only,
         int8_dynamic_activation_int8_weight,
         int4_weight_only,
+        int8_dynamic_activation_int4_weight,
         fpx_weight_only,
         uintx_weight_only,
         autoquant,
@@ -221,6 +223,16 @@
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
+        if "w4a8-cutlass" in quantization:
+            quantize_(
+                model,
+                int8_dynamic_activation_int4_weight(
+                    group_size=None,
+                    mapping_type=MappingType.SYMMETRIC,
+                    input_mapping_type=MappingType.SYMMETRIC,
+                    pack_bits=True,
+                )
+            )
         if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
@@ -459,7 +471,7 @@ def callback(x):
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str,
         help=(
-            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+            'Which quantization techniques to apply: int8dq, w4a8-cutlass, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
         )
     )
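
With the pieces above in place, the new scheme can be exercised end to end from the Llama benchmark script, for example:

    python torchao/_models/llama/generate.py -q w4a8-cutlass --checkpoint_path <checkpoint_dir>/model.pth

This invocation is assembled from the argparse options shown above, not copied from the commit; the -q string is matched by substring, and a CUDA device is required, as the test's skipif marker indicates.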
