From 356ab282a736f6531debb85045daeae5b6cefc2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?=
Date: Mon, 9 Sep 2024 21:32:06 +0200
Subject: [PATCH] W4A8 based on CUTLASS

A CUTLASS-based s8s4_linear_cutlass() operator is introduced. It
performs a linear transformation over a quantized 8-bit input tensor
and a quantized 4-bit weight tensor, each with a corresponding
floating-point scale tensor attached.

A benchmark script, comparing the performance of matrix multiplication
based on this operator with matrix multiplication over 16-bit
floating-point tensors, is supplied in
benchmarks/benchmark_s8s4_cutlass.py.

The Llama generator script torchao/_models/llama/generate.py is changed
to add "int8adq-int4w-symm" as a quantization option, which in turn
activates the s8s4_linear_cutlass() operator.  With this quantization
activated, i.e. when the generate.py script is run as follows:

   python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm

the generator achieves around 133 tok/sec on an A100, vs. around 93
tok/sec without quantization, i.e. when the generate.py script is run
as follows:

   python generate.py --compile --precision=torch.float16
---
 .github/workflows/float8_test.yml             |   1 +
 .github/workflows/nightly_smoke_test.yml      |   1 +
 .github/workflows/regression_test.yml         |   2 +
 .gitmodules                                   |   3 +
 benchmarks/benchmark_s8s4_cutlass.py          |  53 ++
 setup.py                                      |  12 +
 test/dtypes/test_affine_quantized.py          |  11 +-
 test/test_s8s4_linear_cutlass.py              |  80 +++
 third_party/cutlass                           |   1 +
 torchao/_models/llama/generate.py             |  14 +-
 .../s8s4_linear_cutlass.cu                    | 536 ++++++++++++++++++
 torchao/dtypes/__init__.py                    |   2 +
 torchao/dtypes/affine_quantized_tensor.py     |   1 +
 torchao/dtypes/affine_quantized_tensor_ops.py |   8 +
 torchao/dtypes/uintx/__init__.py              |   4 +
 .../uintx/cutlass_int4_packed_layout.py       | 160 ++++++
 torchao/dtypes/uintx/plain_layout.py          |   1 +
 torchao/kernel/intmm.py                       |   4 +-
 torchao/ops.py                                | 106 ++++
 torchao/quantization/quant_api.py             |  31 +-
 20 files changed, 1024 insertions(+), 7 deletions(-)
 create mode 100644 .gitmodules
 create mode 100644 benchmarks/benchmark_s8s4_cutlass.py
 create mode 100644 test/test_s8s4_linear_cutlass.py
 create mode 160000 third_party/cutlass
 create mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
 create mode 100644 torchao/dtypes/uintx/cutlass_int4_packed_layout.py

diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml
index 760beb6319..75482c9e24 100644
--- a/.github/workflows/float8_test.yml
+++ b/.github/workflows/float8_test.yml
@@ -35,6 +35,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
     script: |
       conda create -n venv python=3.9 -y
       conda activate venv
diff --git a/.github/workflows/nightly_smoke_test.yml b/.github/workflows/nightly_smoke_test.yml
index 9f3dc3c0fb..d215f22ed2 100644
--- a/.github/workflows/nightly_smoke_test.yml
+++ b/.github/workflows/nightly_smoke_test.yml
@@ -31,6 +31,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
     script: |
       python -m pip install --upgrade pip
       pip install ${{ matrix.torch-spec }}
diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml
index 0488e6d922..74b39d2ef2 100644
--- a/.github/workflows/regression_test.yml
+++ b/.github/workflows/regression_test.yml
@@ -40,6 +40,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive script: | conda create -n venv python=3.9 -y conda activate venv @@ -93,6 +94,7 @@ jobs: runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive script: | conda create -n venv python=3.9 -y conda activate venv diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..3f0af4cd52 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/cutlass"] + path = third_party/cutlass + url = https://github.com/NVIDIA/cutlass diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_s8s4_cutlass.py new file mode 100644 index 0000000000..397544b658 --- /dev/null +++ b/benchmarks/benchmark_s8s4_cutlass.py @@ -0,0 +1,53 @@ +import torch +import pandas as pd +from torchao.utils import benchmark_torch_function_in_microseconds +from torchao.ops import s8s4_linear_cutlass +from tqdm import tqdm + + +def get_problem(m, n, k): + groupsize = k + + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((k, n), dtype=torch.half, device=dev) + + A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) + A_scale = torch.randn((m,), dtype=torch.half, device=dev) + B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev) + B_scale = torch.randn((n,), dtype=torch.half, device=dev) + C = None + + return A_ref, B_ref, A, A_scale, B, B_scale, C + + +def benchmark(m: int, k: int, n: int): + A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k) + + fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref) + s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds( + s8s4_linear_cutlass, A, A_scale, B, B_scale, C + ) + + return { + "m": m, + "k": k, + "n": n, + "fp16_latency (ms)": fp16_time, + "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time, + "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time, + } + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, k, n)) + + df = pd.DataFrame(results) + df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git a/setup.py b/setup.py index 261a620a38..21049d98ff 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,18 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") + use_cutlass = False + if use_cuda and not IS_WINDOWS: + use_cutlass = True + this_dir = os.path.abspath(os.path.curdir) + cutlass_dir = os.path.join(this_dir, "third_party", "cutlass") + cutlass_include_dir = os.path.join(cutlass_dir, "include") + if use_cutlass: + extra_compile_args["nvcc"].extend([ + "-DTORCHAO_USE_CUTLASS", + "-I" + cutlass_include_dir, + ]) + this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 0939c49f5d..f08ba7aa72 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -8,7 +8,7 @@ run_tests, ) -from torchao.dtypes import Int4CPULayout, SemiSparseLayout +from torchao.dtypes import CutlassInt4PackedLayout, 
Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, int4_weight_only, @@ -48,6 +48,15 @@ def get_quantization_functions( ) else: base_functions.append(int4_weight_only(group_size=32)) + if device == "cuda": + base_functions.append( + int8_dynamic_activation_int4_weight( + group_size=None, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ) + ) if do_sparse: base_functions.append( diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py new file mode 100644 index 0000000000..15b4c2673d --- /dev/null +++ b/test/test_s8s4_linear_cutlass.py @@ -0,0 +1,80 @@ +import itertools + +import torch + +import torchao +from torchao.ops import s8s4_linear_cutlass +from torchao.quantization.utils import group_quantize_tensor_symmetric +from torchao.utils import compute_max_diff + +import pytest + + +S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] +S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +S8S4_LINEAR_CUTLASS_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] +S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True] +S8S4_LINEAR_CUTLASS_TEST_PARAMS = list( + itertools.product( + S8S4_LINEAR_CUTLASS_DTYPE, + S8S4_LINEAR_CUTLASS_BATCH_SIZE, + S8S4_LINEAR_CUTLASS_SIZE_MNK, + S8S4_LINEAR_CUTLASS_USE_BIAS, + ) +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS +) +def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): + size_m, size_n, size_k = size_mnk + + input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") + weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") + bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + + input_2d = input.view(-1, input.shape[-1]) + input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( + input_2d, 8, size_k, dtype + ) + assert torch.all(input_2d_zeros == 0) + input_s8 = input_2d_s8.reshape(input.shape) + input_scales = input_2d_scales.reshape(input.shape[:-1]) + + weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( + weight, 4, size_n, dtype + ) + assert torch.all(weight_zeros == 0) + weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF) + + # If torch.nn.functional.linear(input, weight, bias) used as + # reference, the error would be too big. The calculation below is + # approximately what s8s4_linear_cutlass kernel is doing (except + # that matrrix multiplication is over integers there)). 
+ size_m_2d = input_2d.shape[0] + output_ref = ( + (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T) + * input_2d_scales.view(size_m_2d, 1) + * weight_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.reshape(input.shape[:-1] + (size_n,)) + + fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) + try: + output = s8s4_linear_cutlass(*fn_inputs) + except NotImplementedError as e: + pytest.xfail("s8s4_linear_cutlass() op not implemented") + + max_diff = compute_max_diff(output, output_ref) + assert max_diff < 5e-3 diff --git a/third_party/cutlass b/third_party/cutlass new file mode 160000 index 0000000000..bf9da7b76c --- /dev/null +++ b/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit bf9da7b76c766d7ee7d536afc77880a4ef1f1156 diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index c56ee18f92..a111e3e7c8 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -434,6 +434,18 @@ def ffn_or_attn_only(mod, fqn): ] ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size)) + elif "int8adq-int4w-symm" in quantization: + from torchao.dtypes import CutlassInt4PackedLayout + + quantize_( + model, + int8_dynamic_activation_int4_weight( + group_size=None, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ), + ) if "marlin" in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout @@ -1058,7 +1070,7 @@ def callback(x): help=( "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx--, uintx---hqq, sparse-marlin, spinquant, " - + "embed-int8wo, marlin_qqq, gemlite---" + + "embed-int8wo, marlin_qqq, gemlite---, int8adq-int4w-symm" ), ) parser.add_argument( diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu new file mode 100644 index 0000000000..2daefb7773 --- /dev/null +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -0,0 +1,536 @@ +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) +#define BUILD_S8S4_LINEAR_CUTLASS +#endif + +#if defined(BUILD_S8S4_LINEAR_CUTLASS) +#include +#include +#include +#include +#include + +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : Got CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } +#endif + +namespace torchao { + +#if defined(BUILD_S8S4_LINEAR_CUTLASS) +template< + typename ElementA, + typename ElementAScale, + typename ElementB, + typename ElementBScale, + typename ElementC, + typename ElementAccumulator, + typename ElementEpilogue, + typename ElementOutput, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int NumStages, + bool use_tensor_c> +void s8s4_linear_kernel_cutlass( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + const int m = 
tensor_a.size(0); + const int n = tensor_b.size(0); + const int k = tensor_a.size(1); + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentAScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentBScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; + + // Check for current CUTLASS limitations w.r.t. alignments. + TORCH_CHECK(k % AlignmentA == 0, + __func__, " : Number of columns of tensor A must be divisible ", + "by ", AlignmentA); + TORCH_CHECK(k % AlignmentB == 0, + __func__, " : Number of columns of tensor B must be divisible ", + "by ", AlignmentB); + TORCH_CHECK(n % AlignmentC == 0, + __func__, " : Number of columns of tensor C must be divisible ", + "by ", AlignmentC); + + using SmArch = cutlass::arch::Sm80; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + + constexpr auto NumEVTEpilogueStages = 1; + + using TensorAScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementAScale, + AlignmentAScale, + NumEVTEpilogueStages>; + using TensorBScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementBScale, + AlignmentBScale, + NumEVTEpilogueStages>; + using TensorCTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + NumEVTEpilogueStages>; + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using TensorAScale = + cutlass::epilogue::threadblock::VisitorColBroadcast< + TensorAScaleTileThreadMap, + ElementAScale, + cute::Stride>; + using TensorAScaleArguments = typename TensorAScale::Arguments; + + using TensorBScale = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorBScaleTileThreadMap, + ElementBScale, + cute::Stride>; + using TensorBScaleArguments = typename TensorBScale::Arguments; + + using TensorCScalar = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using TensorCTensor = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorCTileThreadMap, + ElementC, + cute::Stride>; + using TensorC = + std::conditional_t; + using TensorCArguments = typename TensorC::Arguments; + + using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyAScale, + Accum, + TensorAScale>; + + using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyBScale, + EVTApplyAScale, + TensorBScale>; + + using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< + ApplySum, + EVTApplyBScale, + TensorC>; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< 
+ OutputTileThreadMap, ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride // StrideMNL + >; + + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplySum>; + + using EVTKernel = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, + ElementAccumulator, + ElementEpilogue, + cutlass::arch::OpClassTensorOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTOutput, + ThreadblockSwizzle, + NumStages, + cutlass::arch::OpMultiplyAddMixedInputUpcast, + NumEVTEpilogueStages + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalBase; + + cutlass::gemm::GemmCoord problem_size(m, n, k); + constexpr auto SplitKFactor = 1; + + TensorAScaleArguments tensor_a_scale_arguments{ + (ElementAScale*)tensor_a_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, problem_size.m()} + }; + TensorBScaleArguments tensor_b_scale_arguments{ + (ElementBScale*)tensor_b_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, problem_size.n()} + }; + TensorCArguments tensor_c_arguments{ + [&]() -> TensorCArguments { + if constexpr (use_tensor_c) { + return {(ElementC*)tensor_c.data_ptr(), + ElementC(0), + {cute::_0{}, cute::_1{}, problem_size.n()}}; + } else { + return {ElementC(0)}; + } + }() + }; + typename Output::Arguments output_arguments{ + (ElementOutput*)tensor_d.data_ptr(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + { + { + {}, // Accum + tensor_a_scale_arguments, // TensorAScale + {} // ApplyAScale + }, // EVTApplyAScale + tensor_b_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, // EVTApplyBScale + tensor_c_arguments, // TensorC + {} // ApplySum + }, // EVTApplySum + output_arguments // Output + }; // EVTOutput + constexpr auto AvailSms = -1; + + typename Gemm::Arguments arguments( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + SplitKFactor, + callback_arguments, // arguments of EVT callbacks + (ElementA*)tensor_a.data_ptr(), + (ElementB*)tensor_b.data_ptr(), + nullptr, // ptr C (unused) + nullptr, // ptr D (unused) + problem_size.mk().product(), // batch stride A + problem_size.nk().product(), // batch stride B + 0, // batch stride C (unused) + 0, // batch stride D (unused) + problem_size.k(), // stride A + problem_size.k(), // stride B + 0, // stride C (unused) + 0, // stride D (unused) + AvailSms); + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + // Perform mixed datatypes GEMM operation. 
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); +} + +template< + typename ElementA, + typename ElementAScale, + typename ElementB, + typename ElementBScale, + typename ElementC, + typename ElementAccumulator, + typename ElementEpilogue, + typename ElementOutput, + bool use_tensor_c> +void +s8s4_linear_cutlass_dispatch_shapes( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + // A minimal heuristic to improve performance for small number of + // inputs cases. + if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; + constexpr auto NumStages = 6; + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + constexpr auto NumStages = 5; + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + constexpr auto NumStages = 4; + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } +} + +#endif + +// Perform linear operation, using corresponding CUTLASS mixed +// data-types GEMM kernel, to given arguments: +// result = (input * input_scale) @ (weight * weight_scale).T + bias +// Notes: The "input_scale" tensor is expected to be a vector, of size +// equal to number of rows of "input" tensor. The "weight_scale" +// tensor is expected to be a vector, of size equal to number of rows +// of "weight" tensor. The "bias" tensor is expected to be a vector, +// of size equal to number of rows of "weight" tensor. +at::Tensor +s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, + const at::Tensor& weight, const at::Tensor& weight_scale, + const at::Tensor& bias) { +#if defined(BUILD_S8S4_LINEAR_CUTLASS) + // For now, only CC 8.x devices are supported. + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + TORCH_CHECK(is_sm8x, + __func__, " : Supported only on GPUs with compute capability " + "8.x"); + + // Validate datatypes of arguments. 
+ TORCH_CHECK(input.dtype() == at::kChar, + __func__, " : The input datatype ", input.dtype(), + " not supported"); + TORCH_CHECK(input_scale.dtype() == at::kHalf || + input_scale.dtype() == at::kBFloat16, + __func__, " : The input scale datatype ", input_scale.dtype(), + " not supported"); + TORCH_CHECK(weight.dtype() == at::kChar, " : The weight datatype ", + weight.dtype(), " not supported"); + TORCH_CHECK(weight_scale.dtype() == input_scale.dtype(), + __func__, " : Expected weight scale datatype ", + input_scale.dtype(), ", got ", weight_scale.dtype()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dtype() == input_scale.dtype(), + __func__, " : Expected bias datatype ", input_scale.dtype(), + ", got ", bias.dtype()); + } + + // Validate layouts of arguments. + TORCH_CHECK(input.dim() >= 2, + __func__, " : Expected input argument to be 2D or " + "higher-dimensional tensor, got ", input.dim(), " dims"); + TORCH_CHECK(input.layout() == at::Layout::Strided, + __func__, " : Expected input argument to be strided, got layout ", + input.layout()); + TORCH_CHECK(input_scale.dim() == input.dim() - 1, + __func__, " : Expected input scale argument to be ", + input.dim() - 1, "D tensor, got ", input_scale.dim(), " dims"); + TORCH_CHECK(input_scale.layout() == at::Layout::Strided, + __func__, " : Expected input scale argument to be strided, got " + "layout ", input_scale.layout()); + TORCH_CHECK(weight.dim() == 2, + __func__, " : Expected weight argument to be 2D tensor, got ", + weight.dim(), " dims"); + TORCH_CHECK(weight.layout() == at::Layout::Strided, + __func__, + " : Expected weight argument to be strided, got layout ", + weight.layout()); + TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2, + __func__, " : Expected weight scale argument to be 1D or 2D ", + "tensor, got ", weight_scale.dim(), " dims"); + TORCH_CHECK(weight_scale.layout() == at::Layout::Strided, + __func__, " : Expected weight scale argument to be strided, got " + "layout ", weight_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, + __func__, " : Expected bias argument to be 1D tensor, got ", + bias.dim(), " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, + __func__, " : Expected bias argument to be strided, got ", + "layout ", bias.layout()); + } + + // Squash the input tensor to 2D tensor. + const auto input_sizes = input.sizes().vec(); + const auto input_2d = input.reshape({-1, input_sizes.back()}); + const auto input_scale_sizes = input_scale.sizes().vec(); + const auto input_scale_1d = input_scale.reshape({-1}); + const auto weight_scale_1d = weight_scale.reshape({-1}); + + // Validate sizes of arguments. + TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1), + __func__, " : Expected input argument to have ", + 2 * weight.size(1), " columns, but got ", input_2d.size(1)); + for (auto i = 0; i < input_scale_sizes.size(); ++i) + TORCH_CHECK(input_scale_sizes[i] == input_sizes[i], + __func__, " : Expected input scale argument size at position ", + i, " to be ", input_sizes[i], ", but got ", + input_scale_sizes[i]); + TORCH_CHECK(weight_scale_1d.numel() == weight.size(0), + __func__, " : Expected weight scale argument to have ", + weight.size(0), " elements, got ", weight_scale_1d.numel(), + " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == weight.size(0), + __func__, " : Expected bias argument to have ", weight.size(0), + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. 
+ const auto input_2d_strides = input_2d.strides(); + TORCH_CHECK(input_2d_strides[0] >= 1 && input_2d_strides[1] == 1, + __func__, " : Expected input argument in row-major layout"); + const auto input_scale_1d_strides = input_scale_1d.strides(); + TORCH_CHECK(input_scale_1d_strides[0] == 1, + __func__, " : Expected input scale argument to be contiguous"); + const auto weight_strides = weight.strides(); + TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1, + __func__, " : Expected weight argument in row-major layout"); + const auto weight_scale_1d_strides = weight_scale_1d.strides(); + TORCH_CHECK(weight_scale_1d_strides[0] == 1, + __func__, " : Expected weight scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, + __func__, " : Expected bias argument to be contiguous"); + } + + // Introduce alias names for arguments, according to the CUTLASS + // naming conventions. + const auto& tensor_a = input_2d; + const auto& tensor_a_scale = input_scale_1d; + const auto& tensor_b = weight; + const auto& tensor_b_scale = weight_scale_1d; + const auto& tensor_c = bias; + + // Create output tensor. + at::Tensor tensor_d = + tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); + + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementAccumulator = int32_t; + AT_DISPATCH_SWITCH( + input_scale.scalar_type(), + "s8s4_linear_cutlass", + AT_DISPATCH_CASE( + at::ScalarType::Half, + [&]() { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementEpilogue = float; + using ElementOutput = cutlass::half_t; + if (bias.numel() > 0) { + s8s4_linear_cutlass_dispatch_shapes< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, true>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + s8s4_linear_cutlass_dispatch_shapes< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, false>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + }) + AT_DISPATCH_CASE( + at::ScalarType::BFloat16, + [&]() { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementC = cutlass::bfloat16_t; + using ElementEpilogue = float; + using ElementOutput = cutlass::bfloat16_t; + if (bias.numel() > 0) { + s8s4_linear_cutlass_dispatch_shapes< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, true>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + s8s4_linear_cutlass_dispatch_shapes< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, false>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + })); + + auto tensor_d_sizes = input_sizes; + tensor_d_sizes.back() = weight.size(0); + return tensor_d.reshape(tensor_d_sizes); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass); +} + +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index c7d98cb56e..9cbd4cd2a0 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py 
@@ -14,6 +14,7 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( BlockSparseLayout, + CutlassInt4PackedLayout, Int4CPULayout, MarlinQQQLayout, MarlinQQQTensor, @@ -50,4 +51,5 @@ "MarlinQQQTensor", "MarlinQQQLayout", "Int4CPULayout", + "CutlassInt4PackedLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index ba06d877f3..e7aca34c5f 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -472,6 +472,7 @@ def _apply_fn_to_data(self, fn): register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor + to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index a1667e8fbb..76df949852 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -20,6 +20,10 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) +from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int8_act_int4_weight_cutlass_check, + _linear_int8_act_int4_weight_cutlass_impl, +) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -143,6 +147,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, ), + ( + _linear_int8_act_int4_weight_cutlass_check, + _linear_int8_act_int4_weight_cutlass_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index a8a4db9420..7cf375feb4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,9 @@ from .block_sparse_layout import ( BlockSparseLayout, ) +from .cutlass_int4_packed_layout import ( + CutlassInt4PackedLayout, +) from .int4_cpu_layout import ( Int4CPULayout, ) @@ -32,4 +35,5 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", + "CutlassInt4PackedLayout", ] diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py new file mode 100644 index 0000000000..3acf66a201 --- /dev/null +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout + +aten = torch.ops.aten + + +def _aqt_is_int4(aqt): + """Check if an AffineQuantizedTensor is int4 quantized Tensor""" + # TODO: use torch.int4 + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -8 + and aqt.quant_max == 7 + ) + + +@dataclass(frozen=True) +class CutlassInt4PackedLayout(Layout): + pass + + +@register_layout(CutlassInt4PackedLayout) +class Int4PackedTensorImpl(AQTTensorImpl): + """ + TensorImpl storage class for int4 packed layout for affine quantized tensor. 
+ """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point"], [ + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + zero_point = tensor_data_dict["zero_point"] + _layout = tensor_attributes + return cls(int_data, scale, zero_point, _layout) + + def get_plain(self): + int_data = torch.stack( + ((self.int_data << 4) >> 4, self.int_data >> 4), dim=2 + ).view((self.int_data.shape[0], 2 * self.int_data.shape[1])) + return int_data, self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + int_data_s4 = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF) + return cls( + int_data_s4, + scale, + zero_point, + _layout, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.scale = fn(self.scale) + self.zero_point = fn(self.zero_point) + return self + + +def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + and (bias is None or bias.dtype == input_tensor.dtype) + and (bias is None or len(bias.shape) == 1) + ) + + +def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import s8s4_linear_cutlass + + weight = weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias) + + return out diff --git a/torchao/dtypes/uintx/plain_layout.py 
b/torchao/dtypes/uintx/plain_layout.py index ed171634cd..502e3c13e9 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -225,6 +225,7 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): isinstance(input_tensor, AffineQuantizedTensor) and _aqt_is_int8_reduced_range(input_tensor) and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int8(weight_tensor) and input_tensor.dtype == weight_tensor.dtype and isinstance(input_tensor._layout, PlainLayout) and isinstance(weight_tensor._layout, PlainLayout) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index afc5bcfa3f..4cea35abe7 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -49,12 +49,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" device_cpu = "cpu" in [mat2.device.type, input.device.type] # with input.shape = [i,j] and mat2.shape = [j,k] - i_is_strictly_greater_than_16 = input.shape[0] > 16 j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) bad_dimensions_for_cublas = not ( - i_is_strictly_greater_than_16 - and j_is_nonzero_multiple_of_8 + j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 ) diff --git a/torchao/ops.py b/torchao/ops.py index 2774deb08a..f4b55c4951 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -19,6 +19,9 @@ lib.define( "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" ) +lib.define( + "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" +) def register_custom_op(name): @@ -509,3 +512,106 @@ def _( ) return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) + + +def s8s4_linear_cutlass( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + """ + CUTLASS-based W4A8 linear operator. + Args: + input: input tensor, quantized to 8-bit integer values. + input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. + weight: weight matrix, quantized to 4-bit integer values, in row-major layout. + weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). + bias: a vector of size equal to number of rows of weight tensor, or None. + Returns: + output: result tensor, in row-major layout. + """ + + return torch.ops.torchao.s8s4_linear_cutlass.default( + input, input_scale, weight, weight_scale, bias + ) + + +@register_custom_op("torchao::s8s4_linear_cutlass") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + # Validate dtypes. 
+ torch._check( + input.dtype == torch.int8, + lambda: f"input dtype {input.dtype} instead of {torch.int8}", + ) + torch._check( + input_scale.dtype in (torch.float16, torch.bfloat16), + lambda: f"input_scale dtype {input_scale.dtype} instead of {torch.float16} or {torch.bfloat16}", + ) + torch._check( + weight.dtype == torch.int8, + lambda: f"weight dtype {weight.dtype} instead of {torch.int8}", + ) + torch._check( + weight_scale.dtype == input_scale.dtype, + lambda: f"weight_scale dtype {weight_scale.dtype} instead of {input_scale.dtype}", + ) + if bias is not None: + torch._check( + bias.dtype == input_scale.dtype, + lambda: f"bias dtype {weight_scale.dtype} instead of {input_scale.dtype}", + ) + + # Validate dims. + torch._check(input.dim() >= 2, lambda: f"input is {input.dim()}D instead of >=2D") + torch._check( + input_scale.dim() == input.dim() - 1, + lambda: f"input_scale is {input_scale.dim()}D instead of {input.dim() - 1}D", + ) + torch._check(weight.dim() == 2, lambda: f"weight is {weight.dim()}D instead of 2D") + torch._check( + weight_scale.dim() == 1 or weight_scale.dim() == 2, + lambda: f"weight_scale is {weight_scale.dim()}D instead of 1D or 2D", + ) + if bias is not None: + torch._check(bias.dim() == 1, lambda: f"bias is {bias.dim()}D instead of 1D") + + # Validate shapes. + torch._check( + input.shape[-1] == 2 * weight.shape[-1], + lambda: "input and weight shapes do not match for matrix product", + ) + for i in range(input_scale.dim()): + torch._check( + input_scale.shape[i] == input.shape[i], + lambda: f"input_scale and input shapes do not match at position {i}", + ) + torch._check( + weight_scale.numel() == weight.shape[0], + lambda: f"weight_scale has {weight_scale.numel()} elements instead of {weight.shape[0]}", + ) + if bias is not None: + torch._check( + bias.numel() == weight.shape[0], + lambda: f"bias has {bias.numel()} elements instead of {weight.shape[0]}", + ) + + # Validate strides (input, input_scales and weight_scales will be + # reshape()-d by the operator, so no need to check strides for + # them). 
+ torch._check(weight.stride(-1) == 1, lambda: "weight is not in row-major layout") + if bias is not None: + torch._check(bias.is_contiguous(), lambda: "bias is not contiguous") + + return torch.empty( + (*input.shape[:-1], weight.shape[0]), + dtype=input_scale.dtype, + device=input.device, + ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c7c0de04de..1474a17523 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -27,6 +27,7 @@ import torchao from torchao.dtypes import ( AffineQuantizedTensor, + CutlassInt4PackedLayout, Float8Layout, Int4CPULayout, MarlinQQQLayout, @@ -599,7 +600,12 @@ def apply_int8_dynamic_activation_int4_weight_quant( if act_mapping_type == MappingType.ASYMMETRIC: input_quant_func = _int8_asymm_per_token_quant elif act_mapping_type == MappingType.SYMMETRIC: - input_quant_func = _int8_symm_per_token_quant + if isinstance(layout, MarlinQQQLayout): + input_quant_func = _int8_symm_per_token_quant + elif isinstance(layout, CutlassInt4PackedLayout): + input_quant_func = _int8_symm_per_token_reduced_range_quant_cutlass + else: + input_quant_func = _int8_symm_per_token_quant else: assert False, f"Unsupported activation mapping type: {act_mapping_type}" @@ -635,7 +641,7 @@ def int8_dynamic_activation_int4_weight( Args: `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained - `layout`: layout type for quantized weight tensor, only supports `PlainLayout()` and `MarlinQQQLayout()` for now + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric """ @@ -826,6 +832,26 @@ def _int8_symm_per_token_reduced_range_quant_noop_decode( ) +def _int8_symm_per_token_reduced_range_quant_cutlass( + x: torch.Tensor, +) -> torch.Tensor: + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = 1e-5 + quant_min = -127 + quant_max = 127 + return to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + eps=eps, + quant_min=quant_min, + quant_max=quant_max, + scale_dtype=torch.float16 if x.dtype == torch.float16 else None, + ) + + def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, @@ -1264,6 +1290,7 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: [ _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, + _int8_symm_per_token_reduced_range_quant_cutlass, _input_activation_quant_func_fp8, ] )
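
Usage note (not part of the patch itself): the quantization option added to
generate.py can be applied to any model through the public torchao API. The
sketch below mirrors the generate.py change above; the toy model is
hypothetical, and it assumes torchao was built with the CUTLASS submodule and
runs on an SM 8.x GPU.

    import torch

    from torchao.dtypes import CutlassInt4PackedLayout
    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
    from torchao.quantization.quant_primitives import MappingType

    # Hypothetical toy model; any float16/bfloat16 model with nn.Linear layers works.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

    quantize_(
        model,
        int8_dynamic_activation_int4_weight(
            group_size=None,                         # row-wise weight scales
            mapping_type=MappingType.SYMMETRIC,      # symmetric weight quantization
            act_mapping_type=MappingType.SYMMETRIC,  # symmetric activation quantization
            layout=CutlassInt4PackedLayout(),        # layout handled by the CUTLASS kernel
        ),
    )

    x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    y = model(x)

After quantize_(), the Linear weights are AffineQuantizedTensor instances with
the new layout, and the dispatch pair registered in
affine_quantized_tensor_ops.py (_linear_int8_act_int4_weight_cutlass_check /
_impl) is what routes eligible matmuls to s8s4_linear_cutlass().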
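
The operator can also be exercised directly, along the lines of
test/test_s8s4_linear_cutlass.py. This is a sketch only: the shapes are
arbitrary, and it assumes the op was compiled in (otherwise the call raises
NotImplementedError).

    import torch

    from torchao.ops import s8s4_linear_cutlass
    from torchao.quantization.utils import group_quantize_tensor_symmetric

    dtype = torch.float16
    m, n, k = 32, 512, 128
    input = torch.randn((m, k), dtype=dtype, device="cuda")
    weight = torch.randn((n, k), dtype=dtype, device="cuda")
    bias = torch.randn((n,), dtype=dtype, device="cuda")

    # Row-wise symmetric quantization: activations to 8 bits, weights to 4 bits.
    input_s8, input_scales, _ = group_quantize_tensor_symmetric(input, 8, k, dtype)
    input_scales = input_scales.reshape(m)
    weight_s8, weight_scales, _ = group_quantize_tensor_symmetric(weight, 4, k, dtype)
    weight_scales = weight_scales.reshape(n)

    # Pack two signed 4-bit weight values per int8 byte (even column in the low nibble).
    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)

    # output ~= (input_s8 * input_scales[:, None]) @ (weight_s8 * weight_scales[:, None]).T + bias
    output = s8s4_linear_cutlass(input_s8, input_scales, weight_s4, weight_scales, bias)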
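
The int4 packing convention used by Int4PackedTensorImpl can be checked in
isolation. The sketch below simply mirrors from_plain()/get_plain() from
cutlass_int4_packed_layout.py: two signed 4-bit values per int8 byte, with the
even column in the low nibble and the odd column in the high nibble.

    import torch

    int_data = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # signed 4-bit values
    # from_plain(): pack pairs of adjacent columns into single bytes.
    packed = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF)
    # get_plain(): sign-extend each nibble back to int8 and re-interleave the columns.
    unpacked = torch.stack(((packed << 4) >> 4, packed >> 4), dim=2).view(int_data.shape)
    assert torch.equal(unpacked, int_data)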