|
9 | 9 | from torch._inductor.pattern_matcher import PatternMatcherPass |
10 | 10 | from torch._ops import OpOverload |
11 | 11 |
|
12 | | -from vllm.config import VllmConfig |
| 12 | +from vllm.config import VllmConfig, get_current_vllm_config |
13 | 13 | from vllm.logger import init_logger |
14 | 14 | from vllm.model_executor.layers.quantization.utils.quant_utils import ( |
15 | 15 | GroupShape, |
@@ -93,6 +93,8 @@ class RMSNormQuantPattern: |
93 | 93 | def __init__(self, epsilon: float, key: FusedRMSQuantKey): |
94 | 94 | self.epsilon = epsilon |
95 | 95 | self.quant_dtype = key.quant.dtype |
| 96 | + config = get_current_vllm_config() |
| 97 | + self.model_dtype = config.model_config.dtype if config.model_config else None |
96 | 98 |
|
97 | 99 | assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" |
98 | 100 | self.FUSED_OP = FUSED_OPS[key] |
@@ -124,7 +126,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): |
124 | 126 | def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): |
125 | 127 | # In case we're matching native rms-norm, conversions might be |
126 | 128 | # optimized out. We convert here just to be safe. |
127 | | - input = input.to(dtype=torch.float16) # TODO model dtype |
| 129 | + input = input.to(dtype=self.model_dtype) |
128 | 130 |
|
129 | 131 | result = torch.empty_like(input, dtype=self.quant_dtype) |
130 | 132 | at = auto_functionalized( |
@@ -179,8 +181,8 @@ def replacement( |
179 | 181 | ): |
180 | 182 | # In case we're matching native rms-norm, conversions might be |
181 | 183 | # optimized out. We convert here just to be safe. |
182 | | - input = input.to(dtype=torch.float16) # TODO model dtype |
183 | | - residual = residual.to(dtype=torch.float16) |
| 184 | + input = input.to(dtype=self.model_dtype) |
| 185 | + residual = residual.to(dtype=self.model_dtype) |
184 | 186 |
|
185 | 187 | result = torch.empty_like(input, dtype=self.quant_dtype) |
186 | 188 | at = auto_functionalized( |
@@ -235,7 +237,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): |
235 | 237 | def replacement(input: torch.Tensor, weight: torch.Tensor): |
236 | 238 | # In case we're matching native rms-norm, conversions might be |
237 | 239 | # optimized out. We convert here just to be safe. |
238 | | - input = input.to(dtype=torch.float16) # TODO model dtype |
| 240 | + input = input.to(dtype=self.model_dtype) |
239 | 241 |
|
240 | 242 | result = torch.empty_like(input, dtype=self.quant_dtype) |
241 | 243 | scale = self.quant_matcher.make_scale(input) |
@@ -289,8 +291,8 @@ def replacement( |
289 | 291 | ): |
290 | 292 | # In case we're matching native rms-norm, conversions might be |
291 | 293 | # optimized out. We convert here just to be safe. |
292 | | - input = input.to(dtype=torch.float16) # TODO model dtype |
293 | | - residual = residual.to(dtype=torch.float16) |
| 294 | + input = input.to(dtype=self.model_dtype) |
| 295 | + residual = residual.to(dtype=self.model_dtype) |
294 | 296 |
|
295 | 297 | result = torch.empty_like(input, dtype=self.quant_dtype) |
296 | 298 | scale = self.quant_matcher.make_scale(input) |
|
0 commit comments