Commit
Fixed bugs: FP8 quantization error
CNTRYROA committed Sep 19, 2024
1 parent 66dc08b commit e0660f7
Showing 2 changed files with 16 additions and 2 deletions.
vllm/envs.py (7 additions, 1 deletion)
@@ -51,7 +51,7 @@
    CMAKE_BUILD_TYPE: Optional[str] = None
    VERBOSE: bool = False
    VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
-
+    VLLM_TEST_FORCE_FP8_MARLIN: bool = False

def get_default_cache_root():
    return os.getenv(
@@ -341,6 +341,12 @@ def get_default_config_root():
    lambda:
    (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
     ("1", "true")),
+    # If set, forces FP8 Marlin to be used for FP8 quantization regardless
+    # of the hardware support for FP8 compute.
+    "VLLM_TEST_FORCE_FP8_MARLIN":
+    lambda:
+    (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
+     ("1", "true")),
}

# end-env-vars-definition
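Note on the new flag: VLLM_TEST_FORCE_FP8_MARLIN is parsed the same way as the other boolean environment variables above, so only "1" or "true" (case-insensitive, after stripping whitespace) enables it. A minimal standalone sketch of that parsing; the helper name _parse_bool_env is illustrative and not part of this commit:

import os

def _parse_bool_env(name: str) -> bool:
    # Same rule as the lambda added above: "1" or "true" enables the flag,
    # anything else (including unset) is treated as False.
    return os.environ.get(name, "0").strip().lower() in ("1", "true")

# e.g. launching with VLLM_TEST_FORCE_FP8_MARLIN=1 forces the Marlin FP8 kernel
# even on GPUs whose compute capability (>= 8.9) supports native FP8 compute.
force_marlin = _parse_bool_env("VLLM_TEST_FORCE_FP8_MARLIN")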
vllm/model_executor/layers/quantization/fp8.py (9 additions, 1 deletion)
@@ -4,6 +4,7 @@
from torch.nn import Module
from torch.nn.parameter import Parameter

+import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
@@ -118,7 +119,7 @@ def __init__(self, quant_config: Fp8Config):
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
-        self.use_marlin = capability < 89
+        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN

    def create_weights(
        self,
@@ -173,6 +174,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)),
+                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
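Note on the scale handling: when the checkpoint is not FP8-serialized, ops.scaled_fp8_quant returns a single per-tensor scale, while the Marlin w8a16 kernel expects one scale per output channel, so the patch broadcasts the scalar to one scale per logical shard and converts it to channelwise form. A rough sketch of that expansion, assuming convert_to_channelwise repeats each shard's scale across that shard's output channels (the function name expand_to_channelwise, shapes, and values below are illustrative only):

import torch

def expand_to_channelwise(shard_scales: torch.Tensor,
                          logical_widths: list[int]) -> torch.Tensor:
    # One scale per fused shard (e.g. fused q/k/v projections),
    # repeated across that shard's output channels.
    assert shard_scales.numel() == len(logical_widths)
    return torch.cat([
        scale.repeat(width)
        for scale, width in zip(shard_scales, logical_widths)
    ])

# A single per-tensor scale is first broadcast to one value per shard,
# matching weight_scale.expand(len(layer.logical_widths)) in the diff.
per_tensor_scale = torch.tensor([0.02])
channelwise = expand_to_channelwise(per_tensor_scale.expand(3), [128, 128, 256])
assert channelwise.shape == (512,)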
