Merged
4 changes: 1 addition & 3 deletions tests/kernels/quantization/test_gguf.py
@@ -8,7 +8,6 @@
from huggingface_hub import snapshot_download

import vllm._custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
from vllm.platforms import current_platform
@@ -176,12 +175,11 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,

w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
device="cuda").to(dtype)
act = SiluAndMul()

output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data,
device="cuda"), topk_weights,
topk_ids, quant_type, quant_type, act)
topk_ids, quant_type, quant_type, "silu")

ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
topk_ids).reshape(output.shape)
8 changes: 6 additions & 2 deletions tests/models/quantization/test_gguf.py
@@ -78,8 +78,12 @@ def gguf_model(self):
)

MODELS = [
LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG,
DOLPHIN_CONFIG
LLAMA_CONFIG,
QWEN2_CONFIG,
PHI3_CONFIG,
GPT2_CONFIG,
# STABLELM_CONFIG, # enable this when v1 supports head_size=80
DOLPHIN_CONFIG,
# STARCODER_CONFIG, # broken
]

8 changes: 0 additions & 8 deletions vllm/engine/arg_utils.py
@@ -1291,14 +1291,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# Some quantization is not compatible with torch.compile.
V1_UNSUPPORTED_QUANT = ["gguf"]
if model_config.quantization in V1_UNSUPPORTED_QUANT:
_raise_or_fallback(
feature_name=f"--quantization {model_config.quantization}",
recommend_to_remove=False)
return False

# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
4 changes: 0 additions & 4 deletions vllm/model_executor/layers/linear.py
@@ -587,8 +587,6 @@ def weight_loader(self,
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return

param_data = param.data
@@ -982,8 +980,6 @@ def weight_loader(self,
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return

param_data = param.data
223 changes: 181 additions & 42 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -9,7 +9,6 @@

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -19,6 +18,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)

@@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module,
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
# HACK: when doing chunked prefill we don't generate output tokens
# so input to logits generator is empty which causes invalid parameter
if x.shape[0] == 0:
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return y


def _fused_mul_mat_gguf_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
) -> torch.Tensor:
return torch.empty(x.shape[0],
qweight.shape[0],
dtype=x.dtype,
device=x.device)


try:
direct_register_custom_op(
op_name="_fused_mul_mat_gguf",
op_func=_fused_mul_mat_gguf,
mutates_args=[],
fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf

except AttributeError as error:
raise error
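
The fake impl exists so that torch.compile can trace the custom op with fake tensors, using only the declared output shape and dtype rather than launching the real kernel. A minimal usage sketch (the wrapper function below is illustrative, not part of this PR):

import torch

# Illustrative only: once registered, the op is reachable via torch.ops.vllm
# and can sit inside a compiled region; during tracing the fake impl supplies
# the output metadata instead of running the CUDA kernel.
@torch.compile
def compiled_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                     qweight_type: int) -> torch.Tensor:
    return torch.ops.vllm._fused_mul_mat_gguf(x, qweight, qweight_type)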


def _fused_moe_gguf(
Comment (Member):
I think it would be overkill to do it in this PR, but we should aim to migrate this to use the modular kernel approach, inheriting from classes like

class FusedMoEPermuteExpertsUnpermute(ABC):

This will take care of activation dispatch and other features.

Reply (Member, Author):
Sounds great. Let me migrate it in a following PR!
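
For a sense of what that follow-up migration could look like, here is a rough sketch of a GGUF experts class built on that ABC. The import path, constructor, and apply signature below are assumptions for illustration only, not the actual FusedMoEPermuteExpertsUnpermute interface:

# Sketch only: written against an assumed interface; consult the real ABC in
# the fused_moe modular kernel module before any migration.
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute)


class GGUFExperts(FusedMoEPermuteExpertsUnpermute):
    """Hypothetical GGUF expert-GEMM kernel for the modular MoE stack."""

    def __init__(self, qweight_type: int, qweight_type2: int):
        super().__init__()
        self.qweight_type = qweight_type
        self.qweight_type2 = qweight_type2

    def apply(self, hidden_states, w1, w2, topk_ids, activation, **kwargs):
        # Assumed hook: token permutation and activation dispatch would be
        # owned by the surrounding modular kernel, leaving only the two
        # quantized GEMMs (via fused_mul_mat_gguf) to this method.
        raise NotImplementedError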

x: torch.Tensor,
w1: torch.Tensor,
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
act,
activation: str,
) -> torch.Tensor:

def act(x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu":
torch.ops._C.silu_and_mul(out, x)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
return out

# lazy import to avoid triggering triton import in CPU backend
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size)
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
for ww, ii in zip(w, idx):
expert_up = w1[ii]

out = _fuse_mul_mat(inp, expert_up, qweight_type)
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
out = act(out)

expert_down = w2[ii]
current_state = _fuse_mul_mat(out, expert_down,
qweight_type2).mul_(ww)
current_state = fused_mul_mat_gguf(out, expert_down,
qweight_type2).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
@@ -203,6 +240,78 @@
return out_hidden_states


def _fused_moe_gguf_fake(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
return torch.empty_like(x)


try:
direct_register_custom_op(
op_name="_fused_moe_gguf",
op_func=_fused_moe_gguf,
mutates_args=[],
fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf

except AttributeError as error:
raise error


def _apply_gguf_embedding(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if qweight_type in UNQUANTIZED_TYPES:
return torch.embedding(qweight, x)
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
x_flat = x.flatten()
assert (hidden_size == qweight.shape[1] // type_size * block_size)
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0], dtype)
return dequant.view(*x.shape, hidden_size)
else:
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")


def _apply_gguf_embedding_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


try:
direct_register_custom_op(
op_name="_apply_gguf_embedding",
op_func=_apply_gguf_embedding,
mutates_args=[],
fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

except AttributeError as error:
raise error


class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF.

@@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module,
set_weight_attrs(qweight_type, extra_weight_attrs)
layer.register_parameter("qweight_type", qweight_type)

def process_weights_after_loading(self, layer: torch.nn.Module):
qweight_type = layer.qweight_type.weight_type
if not (qweight_type in UNQUANTIZED_TYPES
or qweight_type in DEQUANT_TYPES):
qweight_type = WeightType(qweight_type)
raise ValueError(
f"Unsupported GGUF quantization type {qweight_type} in "
f"layer {layer}.")
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
# materialize the padded weight parameter for CUDA Graph compatibility.
self._create_padded_weight_param(layer)

def _create_padded_weight_param(self, layer: torch.nn.Module):
"""Create padded weight parameter for GGUF MergedLinear layer."""
qweight = layer.qweight
shard_id_map = qweight.shard_id_map
shard_id = qweight.shard_id
if len(data_container := qweight.data_container) > 1:
dtype = {data.dtype for data in data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}")
dtype = next(iter(dtype))
# concat dim0 and pad dim1
padded_side = max(x.size(1) for x in data_container)
concat_side = sum(x.size(0) for x in data_container)
# Pad the quantized weights to dense tensor, and create a map
# with the location of each shard in the padded tensor.
padded_data = torch.zeros((concat_side, padded_side),
dtype=dtype,
device=qweight.device)
# (dim0_start, dim0_end, dim1_size)
shard_offset_map = dict[str, tuple[int, int, int]]()
for idx in shard_id:
id_in_container = shard_id_map[idx]
start = sum(
x.size(0) for x in data_container[:id_in_container])
end = start + data_container[id_in_container].size(0)
size = data_container[id_in_container].size(1)
padded_data[start:end, :size] = data_container[id_in_container]
shard_offset_map[idx] = (start, end, size)
qweight.data_container.clear()
padded_param = Parameter(padded_data, requires_grad=False)
set_weight_attrs(padded_param, vars(qweight))
set_weight_attrs(padded_param,
{"shard_offset_map": shard_offset_map})
layer.register_parameter("qweight", padded_param)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
shard_id = getattr(layer.qweight, "shard_id", None)
shard_id = layer.qweight.shard_id

if shard_id:
# dequantize shard weights respectively
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight.unbind(0)
qweight = layer.qweight
result = []
for idx in shard_id:
q_idx = layer.qweight.shard_id_map[idx]
start, end, offset = layer.qweight.shard_offset_map[idx]
qweight_type = layer.qweight_type.shard_weight_type[idx]
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
result.append(
fused_mul_mat_gguf(
x, qweight[start:end, :offset].contiguous(),
qweight_type))
out = torch.cat(result, axis=1)
else:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
out = _fuse_mul_mat(x, qweight, qweight_type)
out = fused_mul_mat_gguf(x, qweight, qweight_type)
if bias is not None:
out.add_(bias)
return out
@@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
self.act = SiluAndMul()

def apply(
self,
@@ -375,10 +533,10 @@ def apply(
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, self.act)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, activation)


class GGUFEmbeddingMethod(GGUFLinearMethod):
@@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module,
x: torch.Tensor) -> torch.Tensor:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]

block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
hidden_size = qweight.shape[1] // type_size * block_size
if qweight_type < 2:
return torch.embedding(qweight, x)
x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size)
return apply_gguf_embedding(x,
qweight,
qweight_type,
hidden_size,
dtype=self.params_dtype)


class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: list[torch.Tensor]

def materialize_nested(self) -> Parameter:
dtype = {data.dtype for data in self.data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}")
dtype = next(iter(dtype))
nested_data = torch.nested.nested_tensor(self.data_container,
device=self.device,
dtype=dtype)
self.data_container.clear()
param = torch.Tensor._make_subclass(self.cls_to_become,
nested_data,
require_grad=False)
for k, v in self.__dict__.items():
setattr(param, k, v)
return param