Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torchao quant for mixtral and qwen_moe #1418

Merged
merged 8 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion python/sglang/srt/layers/torchao_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
Common utilities for torchao.
"""

from typing import Dict, Set

import torch


def torchao_quantize_param_data(param, torchao_config):
def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
"""Quantize a Tensor with torchao quantization specified by torchao_config

Args:
`param`: weight parameter of the linear module
`torchao_config`: type of quantization and their arguments we want to use to
quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
128
"""
# Lazy import to suppress some warnings
from torchao.quantization import (
int4_weight_only,
Expand Down Expand Up @@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config):
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(dummy_linear, float8_weight_only())
return dummy_linear.weight


def apply_torchao_config_(
self: torch.nn.Module,
params_dict: Dict[str, torch.Tensor],
param_suffixes: Set[str],
) -> None:
"""A util function used for quantizing the weight parameters after they are loaded if
self.torchao_config is specified

Args:
`self`: the model we want to quantize
`params_dict`: dictionary mapping from param_name to the parameter Tensor
`param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes

Returns:
None, the `params_dict` is modified inplace and the weights of `self` model are quantized
"""
if self.torchao_config:
for param_suffix in param_suffixes:
for name in params_dict:
param = params_dict[name]
if param_suffix in name and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
self.load_state_dict(params_dict, assign=True)
21 changes: 2 additions & 19 deletions python/sglang/srt/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata

Expand Down Expand Up @@ -405,24 +405,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

if self.torchao_config:
if name.endswith("proj.weight") and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)

if self.torchao_config:
# quantizing the loaded, stacked params, e.g. "...qkv_proj"
stacked_params = set(entry[0] for entry in stacked_params_mapping)
for param_suffix in stacked_params:
for name in params_dict:
if param_suffix in name:
param = params_dict[name]
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)

self.load_state_dict(params_dict, assign=True)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))


class Phi3ForCausalLM(LlamaForCausalLM):
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata


Expand Down Expand Up @@ -296,6 +298,7 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
Expand Down Expand Up @@ -376,5 +379,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = MixtralForCausalLM
5 changes: 5 additions & 0 deletions python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata


Expand Down Expand Up @@ -359,6 +361,7 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
Expand Down Expand Up @@ -451,5 +454,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = Qwen2MoeForCausalLM
Loading