Commit a0b2670

[Transform] Deterministic Hadacore Transforms (#24106)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent c4afdb6 commit a0b2670

File tree

10 files changed: +979 −43 lines changed

CMakeLists.txt
Lines changed: 11 additions & 0 deletions

@@ -783,6 +783,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

+  # Hadacore kernels
+  cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
+  if(HADACORE_ARCHS)
+    set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${HADACORE_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    message(STATUS "Building hadacore")
+  endif()
+
   # if CUDA endif
 endif()

csrc/ops.h
Lines changed: 2 additions & 0 deletions

@@ -347,6 +347,8 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
 int64_t open_mem_handle(torch::Tensor& mem_handle);
 void free_shared_buffer(int64_t buffer);

+torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);
+
 #ifdef USE_ROCM
 fptr_t init_custom_qr(int64_t rank, int64_t world_size,
                       std::optional<int64_t> qr_max_size = std::nullopt);

csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu

Lines changed: 817 additions & 0 deletions
Large diffs are not rendered by default.

csrc/torch_bindings.cpp
Lines changed: 3 additions & 0 deletions

@@ -613,6 +613,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int pad_slot_id) -> ()");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

+  // Hadamard transforms
+  ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
+
 #ifndef USE_ROCM
   // Compute per-token-group FP8 quantized tensor and scaling factor.
   ops.def(
(new test file; filename not rendered)
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+
+import pytest
+import torch
+from compressed_tensors.transform import deterministic_hadamard_matrix
+
+from vllm import _custom_ops as ops
+
+
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
+def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
+    x = torch.eye(hidden_dim, dtype=dtype, device=device)
+    hadamard = deterministic_hadamard_matrix(
+        hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim)
+
+    y = ops.hadacore_transform(x.clone())
+    y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
+    assert torch.allclose(y, y_true)
+
+    y = ops.hadacore_transform(y)
+    assert torch.allclose(y, x)

vllm/_custom_ops.py
Lines changed: 24 additions & 0 deletions

@@ -2011,3 +2011,27 @@ def onednn_scaled_mm(
                      input_zp_adj, bias, dnnl_handler.handler)

     return output
+
+
+def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
+    """
+    Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)
+    kernels. These kernels exploit the recursive structure of Sylvester
+    Hadamard matrices and therefore do not require transform weight data.
+
+    Note that Sylvester Hadamard transforms are also symmetric, so this
+    function also applies the transpose (equivalently, the inverse) transform.
+
+    :param x: value to be transformed
+    :param inplace: modify value in place
+    :return: value after transformation
+    """
+    return torch.ops._C.hadacore_transform(x, inplace)
+
+
+if hasattr(torch.ops._C, "hadacore_transform"):
+
+    @register_fake("_C::hadacore_transform")
+    def _hadacore_transform_fake(x: torch.Tensor,
+                                 inplace: bool) -> torch.Tensor:
+        return torch.empty_like(x) if not inplace else x
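
A minimal usage sketch of the new op, mirroring the round-trip checked in the test above (assumes a CUDA build with the hadacore kernels compiled; the hidden size and tolerance are illustrative):

import torch
from vllm import _custom_ops as ops

# power-of-two hidden dim, matching the sizes exercised by the test
x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")

# inplace=False returns a new tensor; the default (inplace=True) mutates x
y = ops.hadacore_transform(x, inplace=False)

# Sylvester Hadamard transforms are involutions: applying twice recovers x
z = ops.hadacore_transform(y, inplace=False)
assert torch.allclose(z, x, atol=1e-2)  # loose tolerance for a bf16 round-trip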

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ def get_quant_method(
         # choose transform method
         if any((input_tfms, output_tfms)):
             return CompressedTensorsLinearTransformMethod.from_schemes(
-                quant_method, input_tfms, output_tfms)
+                quant_method, quant_scheme, input_tfms, output_tfms)

         else:
             return quant_method

vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py
Lines changed: 18 additions & 7 deletions

@@ -12,6 +12,8 @@
 from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
                                                LinearMethodBase,
                                                QKVCrossParallelLinear)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import (  # noqa: E501
     HadamardTransform)
 from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import (  # noqa: E501

@@ -26,14 +28,22 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):

     @classmethod
     def from_schemes(
-        cls, quant_method: LinearMethodBase, input_tfms: dict[int,
-                                                              TransformTuple],
-        output_tfms: dict[int, TransformTuple]
+        cls,
+        quant_method: LinearMethodBase,
+        quant_scheme: Optional[CompressedTensorsScheme],
+        input_tfms: dict[int, TransformTuple],
+        output_tfms: dict[int, TransformTuple],
     ) -> "CompressedTensorsLinearTransformMethod":
+        from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import (  # noqa: E501
+            QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme)
+
         assert input_tfms or output_tfms

-        # TODO (@ksayers): implement QutlassLinearMethodNvFP4
-        # hadacore and fwht can be selected by Transform module
+        if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
+            return QutlassNvFP4LinearMethod(quant_method, input_tfms,
+                                            output_tfms)
+
+        # hadacore or dense gemm is selected by Transform module

         return cls(quant_method, input_tfms, output_tfms)

@@ -129,11 +139,12 @@ def apply(self,
         assert bias is None
         x = self.quant_method.apply(layer, x, bias)

-        # TODO (@ksayers): Write a triton kernel to do this in parallel
+        # In most cases, input transforms are preferred over output transforms
+        # (@ksayers): confirm that this is done concurrently
         if self.output_transform is not None:
             for part_id, (start, length) in enumerate(self.partition_ranges):
                 x[:, start:start + length] = self.output_transform(
-                    x[:, start:start + length], part_id=part_id)
+                    x[:, start:start + length].contiguous(), part_id=part_id)

         return x
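
For context on the loop above, a hedged sketch of how partition_ranges slices a fused output; the sizes and the stand-in transform are assumptions for illustration, not taken from the diff:

import torch

def output_transform(v: torch.Tensor, part_id: int) -> torch.Tensor:
    return v  # stand-in for HadamardTransform.forward

# hypothetical fused QKV projection widths
output_partition_sizes = [1024, 256, 256]
starts = [0, 1024, 1280]
partition_ranges = list(zip(starts, output_partition_sizes))

x = torch.randn(8, sum(output_partition_sizes))
for part_id, (start, length) in enumerate(partition_ranges):
    # each partition is transformed independently; .contiguous() matters
    # because the column slice is strided and the kernel expects dense input
    chunk = x[:, start:start + length].contiguous()
    x[:, start:start + length] = output_transform(chunk, part_id=part_id)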

vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
Lines changed: 47 additions & 29 deletions

@@ -2,12 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Hashable
-from typing import Callable, Optional
+from typing import Callable

 import torch
-from compressed_tensors.transform import TransformLocation, TransformScheme
+from compressed_tensors.transform import (TransformArgs, TransformLocation,
+                                          TransformScheme)
 from torch import Tensor

+import vllm._custom_ops as ops
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.linear import LinearBase

@@ -28,16 +30,12 @@ class HadamardTransform(torch.nn.Module):
     transforms: dict[int, TransformTuple]  # info parsed from transforms config
     weight: SharedWeightParameter  # container for shared tensors

-    kernel: Callable  # function used during application
     scales: dict[int, float]  # hadamard scale, usually sqrt(matrix.size(0))

-    def __init__(self,
-                 transforms: dict[int, TransformTuple],
-                 layer: torch.nn.Module,
-                 weight_loader: Callable,
+    def __init__(self, transforms: dict[int, TransformTuple],
+                 layer: torch.nn.Module, weight_loader: Callable,
                  input_size_per_partition: int,
-                 output_partition_sizes: list[int],
-                 kernel: Optional[Callable] = None):
+                 output_partition_sizes: list[int]):
         super().__init__()
         self.transforms = transforms
         self.scales = {}

@@ -55,7 +53,7 @@ def __init__(self,
         for part_index, (_scheme_name, scheme,
                          args) in self.transforms.items():
             output_size = output_partition_sizes[part_index]
-            weight_size = self._get_weight_size(layer, args.location,
+            weight_size = self._get_weight_size(layer, scheme, args,
                                                 input_size, output_size)

             data_key = self._get_data_key(scheme, weight_size)

@@ -69,9 +67,6 @@ def __init__(self,
         # validate that shared tensors and schemes are correct
         self._validate_input_transforms()

-        # select kernel based on transform schemes
-        self.kernel = self._infer_kernel() if kernel is None else kernel
-
     def process_weights_after_loading(self):
         for part_id in self.weight.partitions:
             data = self.weight.partitions[part_id].data

@@ -90,32 +85,59 @@ def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
         if part_id not in self.weight.partitions:
             return value

-        weight = self.weight.partitions[part_id]
-        weight = weight if self.transforms[
-            part_id].args.inverse else weight.T  # linear := x(W.T)
-        scale = self.scales[part_id]
-        return self.kernel(self, value.to(weight.dtype), weight, None).to(
-            value.dtype) * scale
+        # use hadacore if possible
+        if self.transforms[part_id].scheme.type == "hadamard":
+            if self.transforms[part_id].scheme.head_dim is not None:
+                weight_size = self.transforms[part_id].scheme.head_dim
+                value = value.unflatten(-1, (-1, weight_size))
+                value = ops.hadacore_transform(value)
+                value = value.flatten(-2, -1)
+
+                return value
+
+            # sylvester transforms are symmetric, inv => transpose => original
+            return ops.hadacore_transform(value)
+
+        # fall back to dense
+        else:
+            weight = self.weight.partitions[part_id]
+            weight = weight if self.transforms[
+                part_id].args.inverse else weight.T  # linear := x(W.T)
+            scale = self.scales[part_id]
+
+            if self.transforms[part_id].scheme.head_dim is not None:
+                value = value.unflatten(-1, (-1, weight.size(0)))
+                value = dispatch_unquantized_gemm()(self, value.to(
+                    weight.dtype), weight, None).to(value.dtype) * scale
+                value = value.flatten(-2, -1)
+
+                return value
+
+            return dispatch_unquantized_gemm()(self, value.to(
+                weight.dtype), weight, None).to(value.dtype) * scale

     def _get_data_key(self, scheme: TransformScheme,
                       weight_size: int) -> Hashable:
         return (id(scheme), weight_size)

-    def _get_weight_size(self, layer: torch.nn.Module,
-                         location: TransformLocation, input_size: int,
+    def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme,
+                         args: TransformArgs, input_size: int,
                          output_size: int) -> int:
+        if scheme.head_dim is not None:
+            return scheme.head_dim
+
         if isinstance(layer, LinearBase):
-            if location == TransformLocation.INPUT:
+            if args.location == TransformLocation.INPUT:
                 return input_size

-            elif location == TransformLocation.OUTPUT:
+            elif args.location == TransformLocation.OUTPUT:
                 return output_size

         elif isinstance(layer, VocabParallelEmbedding):
-            if location == TransformLocation.INPUT:
+            if args.location == TransformLocation.INPUT:
                 return output_size

-            elif location == TransformLocation.OUTPUT:
+            elif args.location == TransformLocation.OUTPUT:
                 return input_size

         raise ValueError()

@@ -129,7 +151,3 @@ def _validate_input_transforms(self):
         for partition in self.weight.partitions.values():
             if partition.data.data_ptr() != first_data.data_ptr():
                 raise ValueError("")
-
-    def _infer_kernel(self) -> Callable:
-        # TODO (@ksayers): use fwht, hadacore
-        return dispatch_unquantized_gemm()
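
For readers tracing the head_dim branch in forward above, a hedged pure-PyTorch reference of the blocked transform; the helper name is illustrative, and it mirrors the unflatten -> transform -> flatten pattern using the same normalized matrix construction as the test:

import math
import torch
from compressed_tensors.transform import deterministic_hadamard_matrix

def blocked_hadamard_reference(value: torch.Tensor,
                               head_dim: int) -> torch.Tensor:
    # normalized (head_dim x head_dim) Sylvester Hadamard, as in the test
    hadamard = deterministic_hadamard_matrix(
        head_dim, dtype=torch.float64,
        device=value.device) / math.sqrt(head_dim)
    blocks = value.unflatten(-1, (-1, head_dim))     # [..., n_blocks, head_dim]
    blocks = blocks.to(hadamard.dtype) @ hadamard.T  # transform each block
    return blocks.flatten(-2, -1).to(value.dtype)    # back to [..., hidden_dim]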

vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
Lines changed: 31 additions & 6 deletions

@@ -4,18 +4,43 @@

 import torch

+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsScheme, CompressedTensorsW4A4Fp4)
 from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import (  # noqa: E501
-    CompressedTensorsLinearTransformMethod)
+    CompressedTensorsLinearTransformMethod, TransformTuple)

+__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]

-# Because qutlass fuses hadamard with quantization, it cannot automatically be
-# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
-# Therefore, a separate scheme must be created for each quantized dtype
-class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):
+
+def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme],
+                          input_tfms: dict[int, TransformTuple]) -> bool:
+    return isinstance(
+        quant_scheme,
+        (CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[
+            0].scheme.head_dim == quant_scheme.group_size
+
+
+class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
+
+    def create_weights(self, layer, input_size_per_partition,
+                       output_partition_sizes, input_size, output_size,
+                       params_dtype, **extra_weight_attrs):
+        # initializes fp4 qparams
+        assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, ))
+        ret = super().create_weights(layer, input_size_per_partition,
+                                     output_partition_sizes, input_size,
+                                     output_size, params_dtype,
+                                     **extra_weight_attrs)
+
+        assert self.input_transform is not None
+        assert len(self.input_transform.weight) == 1
+        assert self.input_transform.weight[0].size(
+            0) == layer.scheme.group_size
+
+        return ret

     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        # fused hadamard quant linear method
         raise NotImplementedError()
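
In plain terms, is_qutlass_fp4_scheme gates the fused qutlass path on three conditions. An equivalent spelled-out sketch follows; the function name here is hypothetical, while the import matches the diff:

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsW4A4Fp4)

def uses_fused_qutlass_path(quant_scheme, input_tfms) -> bool:
    # 1) the layer quantizes weights/activations to fp4
    is_fp4 = isinstance(quant_scheme, (CompressedTensorsW4A4Fp4, ))
    # 2) exactly one input transform is configured
    single_input_tfm = len(input_tfms) == 1
    # 3) the fused hadamard block size (head_dim) equals the fp4 group size
    sizes_match = (single_input_tfm and
                   input_tfms[0].scheme.head_dim == quant_scheme.group_size)
    return is_fp4 and sizes_match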
