Commit 7cd2458

Support FP8 FA from Quark format
1 parent 5c3b97a

3 files changed: 15 additions & 4 deletions

vllm/model_executor/layers/quantization/quark/quark.py
Lines changed: 11 additions & 0 deletions

@@ -153,6 +153,17 @@ def _check_scheme_supported(self,
         else:
             return False
 
+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all quantized layers in model are fp8 w8a8.
+        global_quant_config = self.quant_config.get("global_quant_config")
+        layer_quant_configs = self.quant_config.get("layer_quant_config")
+        for quant_config in (global_quant_config,
+                             *layer_quant_configs.values()):
+            if not self._is_fp8_w8a8(quant_config.get("weight"),
+                                     quant_config.get("input_tensors")):
+                return False
+        return True
+
     def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                      input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
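A minimal, self-contained sketch of the whole-model check that the new is_fp8_w8a8() performs, assuming a Quark-style config dict. The dtype strings and the _looks_like_fp8_w8a8 predicate below are illustrative stand-ins for the real _is_fp8_w8a8() field checks, not Quark's actual schema:

from typing import Any, Dict, Optional

def _looks_like_fp8_w8a8(weight: Optional[Dict[str, Any]],
                         inputs: Optional[Dict[str, Any]]) -> bool:
    # Stand-in predicate: both weights and activations quantized to fp8.
    return (weight is not None and weight.get("dtype") == "fp8_e4m3"
            and inputs is not None and inputs.get("dtype") == "fp8_e4m3")

def is_fp8_w8a8(quant_config: Dict[str, Any]) -> bool:
    # Mirrors the diff above: the global config and every per-layer
    # override must all pass the fp8 w8a8 check.
    global_cfg = quant_config.get("global_quant_config")
    layer_cfgs = quant_config.get("layer_quant_config")
    for cfg in (global_cfg, *layer_cfgs.values()):
        if not _looks_like_fp8_w8a8(cfg.get("weight"),
                                    cfg.get("input_tensors")):
            return False
    return True

# A single per-layer override that is not fp8 w8a8 fails the whole model.
config = {
    "global_quant_config": {"weight": {"dtype": "fp8_e4m3"},
                            "input_tensors": {"dtype": "fp8_e4m3"}},
    "layer_quant_config": {
        "model.layers.0.mlp": {"weight": {"dtype": "int8"},
                               "input_tensors": {"dtype": "int8"}},
    },
}
assert not is_fp8_w8a8(config)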

vllm/model_executor/models/grok1.py
Lines changed: 1 addition & 1 deletion

@@ -200,7 +200,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.use_fp8 = isinstance(
             quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
-                                         and quant_config._is_fp8_w8a8)
+                                         and quant_config.is_fp8_w8a8())
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
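The substance of this one-line fix: the old expression referenced quant_config._is_fp8_w8a8 without calling it, so the or-branch evaluated a bound method object, which is always truthy, and use_fp8 came out True for every QuarkConfig. Calling the new public is_fp8_w8a8() makes the check actually run. A small illustration with a hypothetical class:

class Cfg:
    def is_fp8_w8a8(self) -> bool:
        return False

cfg = Cfg()
print(bool(cfg.is_fp8_w8a8))    # True: a bound method object is truthy
print(bool(cfg.is_fp8_w8a8()))  # False: the result of actually running the check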

vllm/model_executor/models/llama.py
Lines changed: 3 additions & 3 deletions

@@ -87,7 +87,7 @@ def __init__(
         )
         self.use_fp8 = (isinstance(quant_config, Fp8Config) or
                         (isinstance(quant_config, QuarkConfig)
-                         and quant_config._is_fp8_w8a8)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":
@@ -201,7 +201,7 @@ def __init__(self,
         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
         use_fp8 = isinstance(
             quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
-                                         and quant_config._is_fp8_w8a8)
+                                         and quant_config.is_fp8_w8a8())
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
             and current_platform.is_rocm() \
             and not is_navi() \
@@ -248,7 +248,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.use_fp8 = (isinstance(quant_config, Fp8Config) or
                         (isinstance(quant_config, QuarkConfig)
-                         and quant_config._is_fp8_w8a8)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
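One reading note on these hunks: inside the outer parentheses, Python's conditional expression binds looser than or, so A or B if C else False parses as (A or B) if C else False. The ROCm/non-Navi condition therefore gates the entire disjunction, and use_fp8 is forced to False on other platforms. A quick illustration:

# The conditional gates the whole `or` expression, not just its right side.
value = True or True if False else "gated"
print(value)  # prints "gated"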
