Merged
11 changes: 11 additions & 0 deletions CMakeLists.txt
@@ -783,6 +783,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
if(HADACORE_ARCHS)
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${HADACORE_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building hadacore")
endif()

# if CUDA endif
endif()

2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -347,6 +347,8 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);

torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);

#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt);
817 changes: 817 additions & 0 deletions csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions csrc/torch_bindings.cpp
@@ -613,6 +613,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

// Hadamard transforms
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");

#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
ops.def(
25 changes: 25 additions & 0 deletions tests/kernels/quantization/test_hadacore.py
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

import pytest
import torch
from compressed_tensors.transform import deterministic_hadamard_matrix

from vllm import _custom_ops as ops


@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
x = torch.eye(hidden_dim, dtype=dtype, device=device)
hadamard = deterministic_hadamard_matrix(
hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim)

y = ops.hadacore_transform(x.clone())
y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
assert torch.allclose(y, y_true)

y = ops.hadacore_transform(y)
assert torch.allclose(y, x)
24 changes: 24 additions & 0 deletions vllm/_custom_ops.py
@@ -2010,3 +2010,27 @@ def onednn_scaled_mm(
input_zp_adj, bias, dnnl_handler.handler)

return output


def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)
kernels. These kernels exploit the recursive structure of Sylvester
Hadamard matrices and therefore do not require transform weight data.

Because Sylvester Hadamard transforms are symmetric, this function also
applies the transpose (equivalently, the inverse) transform.

:param x: value to be transformed
:param inplace: if True, modify the value in place
:return: value after transformation
"""
return torch.ops._C.hadacore_transform(x, inplace)


if hasattr(torch.ops._C, "hadacore_transform"):

@register_fake("_C::hadacore_transform")
def _hadacore_transform_fake(x: torch.Tensor,
inplace: bool) -> torch.Tensor:
return torch.empty_like(x) if not inplace else x
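The docstring above relies on two properties of Sylvester Hadamard matrices: the transform can be applied with a log2(n)-stage butterfly, so no weight matrix needs to be materialized, and the 1/sqrt(n)-normalized transform is symmetric and orthogonal, so applying it twice recovers the input (this is also what the test above checks against `deterministic_hadamard_matrix`). The following is a minimal pure-PyTorch sketch of that recursion for illustration only; it is not the Hadacore CUDA kernel, and the helper names are ours.

```python
import math

import torch


def reference_fwht(x: torch.Tensor) -> torch.Tensor:
    """Normalized fast Walsh-Hadamard transform over the last dimension."""
    n = x.size(-1)
    assert n & (n - 1) == 0, "last dim must be a power of two"
    y, h = x.clone(), 1
    while h < n:
        y = y.unflatten(-1, (-1, 2, h))          # pair up blocks of size h
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2)  # Sylvester butterfly stage
        y = y.flatten(-3, -1)
        h *= 2
    return y / math.sqrt(n)


def sylvester_hadamard(n: int) -> torch.Tensor:
    """Explicit Sylvester Hadamard matrix, H_{2n} = [[H_n, H_n], [H_n, -H_n]]."""
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.size(0) < n:
        H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)
        H = torch.kron(H2, H)
    return H


x = torch.randn(4, 64, dtype=torch.float64)
H = sylvester_hadamard(64)
assert torch.allclose(reference_fwht(x), x @ H.T / math.sqrt(64))
assert torch.allclose(reference_fwht(reference_fwht(x)), x)  # transpose == inverse
```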
@@ -129,7 +129,7 @@ def get_quant_method(
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)
quant_method, quant_scheme, input_tfms, output_tfms)

else:
return quant_method
vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py
@@ -12,6 +12,8 @@
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
@@ -26,14 +28,22 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):

@classmethod
def from_schemes(
cls, quant_method: LinearMethodBase, input_tfms: dict[int,
TransformTuple],
output_tfms: dict[int, TransformTuple]
cls,
quant_method: LinearMethodBase,
quant_scheme: Optional[CompressedTensorsScheme],
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple],
) -> "CompressedTensorsLinearTransformMethod":
from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501
QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme)

assert input_tfms or output_tfms

# TODO (@ksayers): implement QutlassLinearMethodNvFP4
# hadacore and fwht can be selected by Transform module
if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
return QutlassNvFP4LinearMethod(quant_method, input_tfms,
output_tfms)

# hadacore or dense gemm is selected by Transform module

return cls(quant_method, input_tfms, output_tfms)

@@ -129,11 +139,12 @@ def apply(self,
assert bias is None
x = self.quant_method.apply(layer, x, bias)

# TODO (@ksayers): Write a triton kernel to do this in parallel
# In most cases, input transforms are preferred over output transforms
# (@ksayers): confirm that this is done concurrently
if self.output_transform is not None:
for part_id, (start, length) in enumerate(self.partition_ranges):
x[:, start:start + length] = self.output_transform(
x[:, start:start + length], part_id=part_id)
x[:, start:start + length].contiguous(), part_id=part_id)

return x

vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable, Optional
from typing import Callable

import torch
from compressed_tensors.transform import TransformLocation, TransformScheme
from compressed_tensors.transform import (TransformArgs, TransformLocation,
TransformScheme)
from torch import Tensor

import vllm._custom_ops as ops
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
@@ -28,16 +30,12 @@ class HadamardTransform(torch.nn.Module):
transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors

kernel: Callable # function used during application
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))

def __init__(self,
transforms: dict[int, TransformTuple],
layer: torch.nn.Module,
weight_loader: Callable,
def __init__(self, transforms: dict[int, TransformTuple],
layer: torch.nn.Module, weight_loader: Callable,
input_size_per_partition: int,
output_partition_sizes: list[int],
kernel: Optional[Callable] = None):
output_partition_sizes: list[int]):
super().__init__()
self.transforms = transforms
self.scales = {}
@@ -55,7 +53,7 @@ def __init__(self,
for part_index, (_scheme_name, scheme,
args) in self.transforms.items():
output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(layer, args.location,
weight_size = self._get_weight_size(layer, scheme, args,
input_size, output_size)

data_key = self._get_data_key(scheme, weight_size)
@@ -69,9 +67,6 @@
# validate that shared tensors and schemes are correct
self._validate_input_transforms()

# select kernel based on transform schemes
self.kernel = self._infer_kernel() if kernel is None else kernel

def process_weights_after_loading(self):
for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data
@@ -90,32 +85,59 @@ def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
if part_id not in self.weight.partitions:
return value

weight = self.weight.partitions[part_id]
weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id]
return self.kernel(self, value.to(weight.dtype), weight, None).to(
value.dtype) * scale
# use hadacore if possible
if self.transforms[part_id].scheme.type == "hadamard":
if self.transforms[part_id].scheme.head_dim is not None:
weight_size = self.transforms[part_id].scheme.head_dim
value = value.unflatten(-1, (-1, weight_size))
value = ops.hadacore_transform(value)
value = value.flatten(-2, -1)

return value

# sylvester transforms are symmetric, inv => transpose => original
return ops.hadacore_transform(value)

# fall back to dense
else:
weight = self.weight.partitions[part_id]
weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id]

if self.transforms[part_id].scheme.head_dim is not None:
value = value.unflatten(-1, (-1, weight.size(0)))
value = dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
value = value.flatten(-2, -1)

return value

return dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale

def _get_data_key(self, scheme: TransformScheme,
weight_size: int) -> Hashable:
return (id(scheme), weight_size)

def _get_weight_size(self, layer: torch.nn.Module,
location: TransformLocation, input_size: int,
def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme,
args: TransformArgs, input_size: int,
output_size: int) -> int:
if scheme.head_dim is not None:
return scheme.head_dim

if isinstance(layer, LinearBase):
if location == TransformLocation.INPUT:
if args.location == TransformLocation.INPUT:
return input_size

elif location == TransformLocation.OUTPUT:
elif args.location == TransformLocation.OUTPUT:
return output_size

elif isinstance(layer, VocabParallelEmbedding):
if location == TransformLocation.INPUT:
if args.location == TransformLocation.INPUT:
return output_size

elif location == TransformLocation.OUTPUT:
elif args.location == TransformLocation.OUTPUT:
return input_size

raise ValueError()
@@ -129,7 +151,3 @@ def _validate_input_transforms(self):
for partition in self.weight.partitions.values():
if partition.data.data_ptr() != first_data.data_ptr():
raise ValueError("")

def _infer_kernel(self) -> Callable:
# TODO (@ksayers): use fwht, hadacore
return dispatch_unquantized_gemm()
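The head_dim branch in forward() above unflattens the last dimension and transforms each head_dim-sized block separately. A small pure-PyTorch sketch (not vLLM code; block count and sizes chosen arbitrarily) of why that is equivalent to multiplying the full hidden dimension by a block-diagonal transform kron(I, H_d) / sqrt(d):

```python
import math

import torch


def sylvester_hadamard(n: int) -> torch.Tensor:
    """Explicit Sylvester Hadamard matrix H_n (n a power of two)."""
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.size(0) < n:
        H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)
        H = torch.kron(H2, H)
    return H


num_blocks, head_dim = 4, 16
x = torch.randn(8, num_blocks * head_dim, dtype=torch.float64)
H = sylvester_hadamard(head_dim) / math.sqrt(head_dim)

# per-block application, mirroring value.unflatten(-1, (-1, head_dim)) above
per_block = (x.unflatten(-1, (-1, head_dim)) @ H.T).flatten(-2, -1)

# same result as one matmul against a block-diagonal transform
block_diag = torch.kron(torch.eye(num_blocks, dtype=torch.float64), H)
assert torch.allclose(per_block, x @ block_diag.T)
```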
vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
@@ -4,18 +4,43 @@

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme, CompressedTensorsW4A4Fp4)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod)
CompressedTensorsLinearTransformMethod, TransformTuple)

__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]

# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):

def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme],
input_tfms: dict[int, TransformTuple]) -> bool:
return isinstance(
quant_scheme,
(CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[
0].scheme.head_dim == quant_scheme.group_size


class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):

def create_weights(self, layer, input_size_per_partition,
output_partition_sizes, input_size, output_size,
params_dtype, **extra_weight_attrs):
# initializes fp4 qparams
assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, ))
ret = super().create_weights(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)

assert self.input_transform is not None
assert len(self.input_transform.weight) == 1
assert self.input_transform.weight[0].size(
0) == layer.scheme.group_size

return ret

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# fused hadamard quant linear method
raise NotImplementedError()