
Commit e151e6d

quant works except (torch,torch)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 8e4a56f commit e151e6d

2 files changed: 30 additions, 11 deletions

vllm/compilation/fusion.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor,
                     scale: torch.Tensor):
             result_rms = self.rmsnorm_matcher(input, weight)
-            return self.quant_matcher(result_rms, scale)
+            return self.quant_matcher(result_rms, scale)[0]
 
         def replacement(input: torch.Tensor, weight: torch.Tensor,
                         scale: torch.Tensor):

@@ -161,7 +161,7 @@ def pattern(input: torch.Tensor, residual: torch.Tensor,
                     weight: torch.Tensor, scale: torch.Tensor):
             result_rms, residual = self.rmsnorm_matcher(
                 input, weight, residual)
-            result = self.quant_matcher(result_rms, scale)
+            result, _ = self.quant_matcher(result_rms, scale)
 
             return result, residual
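Both patterns adapt to the quant matcher now always returning a (result, scale) pair: the static-scale pattern indexes [0], and the fused-add pattern unpacks and discards the scale. A minimal standalone sketch of that calling convention (ToyQuantMatcher, its scale formula, and the tensor shapes below are illustrative assumptions, not vllm code):

from typing import Optional

import torch


class ToyQuantMatcher:
    """Illustrative stand-in that always returns a (result, scale) tuple."""

    def __call__(
        self,
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if scale is None:
            # Dynamic quantization: derive a per-tensor scale from the input.
            scale = input.abs().amax().to(torch.float32) / 448.0
        # Stand-in for the real fp8 quant op: just divide by the scale.
        return input / scale, scale


quant_matcher = ToyQuantMatcher()
x = torch.randn(4, 8)

# Static-scale pattern: keep only the quantized result (cf. the `[0]` above).
result = quant_matcher(x, torch.tensor(0.5))[0]

# Dynamic-scale pattern: unpack and discard the scale (cf. `result, _ = ...`).
result, _ = quant_matcher(x)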

vllm/compilation/matcher_utils.py

Lines changed: 28 additions & 9 deletions
@@ -8,6 +8,7 @@
 
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym,
     kFp8DynamicTokenSym, kFp8StaticTensorSym)

@@ -100,17 +101,29 @@ def __call__(
 
 class MatcherQuant:
 
-    def __init__(self, quant_key: QuantKey):
+    def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None):
+
         self.quant_key = quant_key
         assert quant_key in QUANT_OPS, \
             f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]
 
-    def forward(
+        assert quant_key.scale2 is None
+        self.quant_fp8 = QuantFP8(quant_key.scale.static,
+                                  quant_key.scale.group_shape)
+
+        if enabled is None:
+            # TODO either pass config to enabled or set it globally
+            # (global during pass init seems reasonable)
+            enabled = self.quant_fp8.enabled()
+
+        self.forward = self.forward_custom if enabled else self.forward_native
+
+    def forward_custom(
         self,
         input: torch.Tensor,
         scale: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # TODO: why does empty_like produce a permute but
         # empty via shape doesn't?
         result = torch.empty(input.shape,

@@ -123,7 +136,7 @@ def forward(
                 result=result,
                 input=input,
                 scale=scale)
-            return result
+            return result, scale
         else:
             assert scale is None
             scale = self.make_scale(input)

@@ -134,6 +147,13 @@ def forward(
                 scale_ub=None)
             return result, scale
 
+    def forward_native(
+        self,
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.quant_fp8(input, scale)
+
     def make_scale(self, input: torch.Tensor):
         normalized_group_shape = _normalize_quant_group_shape(
             input, self.quant_key.scale.group_shape)

@@ -146,9 +166,8 @@ def make_scale(self, input: torch.Tensor):
             device=input.device,
             dtype=torch.float32)
 
-    def __call__(
-        self,
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    def __call__(self,
+                 input: torch.Tensor,
+                 scale: Optional[torch.Tensor] = None
+                 ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.forward(input, scale)
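The constructor now wraps the quant key in a QuantFP8 helper and binds self.forward once, to either the custom-op path or the native path, so every caller sees a uniform (result, scale) interface. A minimal sketch of that bind-at-init dispatch pattern under stated assumptions (_custom_quant, _native_quant, and the enabled heuristic below are hypothetical stand-ins, not the vllm implementations):

from typing import Callable, Optional

import torch


def _custom_quant(x: torch.Tensor,
                  scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for a fused custom op (e.g. a C++/CUDA kernel).
    return x / scale, scale


def _native_quant(x: torch.Tensor,
                  scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for the pure-PyTorch fallback path.
    return (x / scale).contiguous(), scale


class ToyMatcherQuant:

    def __init__(self, enabled: Optional[bool] = None):
        if enabled is None:
            # Mirror the diff's idea: decide once at construction time.
            enabled = torch.cuda.is_available()
        # Bind the chosen implementation so __call__ stays branch-free.
        self.forward: Callable[..., tuple[torch.Tensor, torch.Tensor]] = (
            _custom_quant if enabled else _native_quant)

    def __call__(self, x: torch.Tensor,
                 scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return self.forward(x, scale)


m = ToyMatcherQuant(enabled=False)
out, s = m(torch.randn(4, 8), torch.tensor(0.5))

Binding the method in __init__ rather than branching on every call keeps the traced pattern free of control flow, which is what a pattern matcher needs to see a single stable graph.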
