
Commit 8ffb474

Remove/fix TODOs
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent db16ee1 commit 8ffb474

File tree: 4 files changed (+18 / -14 lines)


tests/compile/test_fusion_attn.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
         num_blocks = batch_size * max_blocks
         backend = self.attn.backend
 
-        # TODO use get_kv_cache_stride_order
+        # TODO(luka) use get_kv_cache_stride_order
         # Create dummy KV cache for the selected backend
         if backend == _Backend.ROCM_ATTN:
             # k/v as 1st dimention

tests/compile/test_fusions_e2e.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ class ModelBackendTestCase(NamedTuple):
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.ROCM_AITER_FA,  # TODO ROCM_AITER_UNIFIED_ATTN
+        backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
         attention_fusions=32,
     ),
 ]
@@ -187,7 +187,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
         flat_product(
             MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
         )
-    )  # TODO
+    )
     # Toggle RMSNorm for FP4 models and unquant models
     + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
 )
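
For context, the parametrization helpers referenced in this hunk compose model test cases with custom-op toggle strings. A hedged sketch of how they might be implemented; only the custom_ops_product signature appears in the diff, so flat_product and the ops-string format are assumptions, not the actual vLLM code:

from itertools import product
from typing import Iterable


def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    # Take one toggle from each list and join them, e.g.
    # ("+rms_norm", "+quant_fp8") -> "+rms_norm,+quant_fp8" (assumed format).
    for combo in product(*custom_ops_lists):
        yield ",".join(combo)


def flat_product(models: list, custom_ops: Iterable[str]) -> Iterable[tuple]:
    # Pair every model test case with every custom-ops string, flattening
    # the NamedTuple so pytest.mark.parametrize can unpack it (assumed).
    for model, ops in product(models, custom_ops):
        yield (*model, ops)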

vllm/compilation/fusion.py

Lines changed: 9 additions & 7 deletions
@@ -9,7 +9,7 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch._ops import OpOverload
 
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
@@ -93,6 +93,8 @@ class RMSNormQuantPattern:
     def __init__(self, epsilon: float, key: FusedRMSQuantKey):
         self.epsilon = epsilon
         self.quant_dtype = key.quant.dtype
+        config = get_current_vllm_config()
+        self.model_dtype = config.model_config.dtype if config.model_config else None
 
         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]
@@ -124,7 +126,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
         def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
             # In case we're matching native rms-norm, conversions might be
             # optimized out. We convert here just to be safe.
-            input = input.to(dtype=torch.float16)  # TODO model dtype
+            input = input.to(dtype=self.model_dtype)
 
             result = torch.empty_like(input, dtype=self.quant_dtype)
             at = auto_functionalized(
@@ -179,8 +181,8 @@ def replacement(
         ):
             # In case we're matching native rms-norm, conversions might be
             # optimized out. We convert here just to be safe.
-            input = input.to(dtype=torch.float16)  # TODO model dtype
-            residual = residual.to(dtype=torch.float16)
+            input = input.to(dtype=self.model_dtype)
+            residual = residual.to(dtype=self.model_dtype)
 
             result = torch.empty_like(input, dtype=self.quant_dtype)
             at = auto_functionalized(
@@ -235,7 +237,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor):
         def replacement(input: torch.Tensor, weight: torch.Tensor):
             # In case we're matching native rms-norm, conversions might be
             # optimized out. We convert here just to be safe.
-            input = input.to(dtype=torch.float16)  # TODO model dtype
+            input = input.to(dtype=self.model_dtype)
 
             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(input)
@@ -289,8 +291,8 @@ def replacement(
         ):
             # In case we're matching native rms-norm, conversions might be
             # optimized out. We convert here just to be safe.
-            input = input.to(dtype=torch.float16)  # TODO model dtype
-            residual = residual.to(dtype=torch.float16)
+            input = input.to(dtype=self.model_dtype)
+            residual = residual.to(dtype=self.model_dtype)
 
             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(input)
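
Across these hunks the hardcoded torch.float16 is replaced by a model dtype resolved once from the ambient config, with a None fallback when no model_config is set. A minimal standalone sketch of that pattern; DtypeAwarePattern and normalize are hypothetical stand-ins for the RMSNormQuantPattern replacement logic, not part of the diff:

import torch

from vllm.config import get_current_vllm_config


class DtypeAwarePattern:
    def __init__(self) -> None:
        config = get_current_vllm_config()
        # model_config may be absent (e.g. a bare VllmConfig in tests),
        # so fall back to None instead of raising AttributeError.
        self.model_dtype = config.model_config.dtype if config.model_config else None

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # Tensor.to(dtype=None) is a no-op, so the replacement graph stays
        # valid even when no model dtype is configured.
        return x.to(dtype=self.model_dtype)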

vllm/compilation/matcher_utils.py

Lines changed: 6 additions & 4 deletions
@@ -34,7 +34,9 @@
 
 class MatcherCustomOp(ABC):
     def __init__(self, enabled: bool):
-        self.model_dtype = get_current_vllm_config().model_config.dtype
+        config = get_current_vllm_config()
+        self.model_dtype = config.model_config.dtype if config.model_config else None
+        self.device = config.device_config.device if config.device_config else None
 
         self.enabled = enabled
         self.forward = self.forward_custom if enabled else self.forward_native
@@ -51,10 +53,10 @@ def __call__(self, *args, **kws):
         return self.forward(*args, **kws)
 
     def empty(self, *args, **kws):
-        return torch.empty(*args, dtype=self.model_dtype, device="cuda", **kws)
+        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
 
     def empty_f32(self, *args, **kws):
-        return torch.empty(*args, dtype=torch.float32, device="cuda", **kws)
+        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
 
     def inputs(self) -> list[torch.Tensor]:
         """Utility for inputs to the pattern"""
@@ -166,7 +168,7 @@ def forward_custom(
         input: torch.Tensor,
         scale: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        # TODO: why does empty_like produce a permute but
+        # TODO(luka): why does empty_like produce a permute but
        # empty via shape doesn't?
         result = torch.empty(
             input.shape, device=input.device, dtype=self.quant_key.dtype
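
The same guarded-config pattern also supplies the device here, replacing the hardcoded device="cuda" so pattern tracing works on non-CUDA platforms. A minimal sketch under the same assumptions; MatcherDemo is a hypothetical stand-in for MatcherCustomOp:

import torch

from vllm.config import get_current_vllm_config


class MatcherDemo:
    def __init__(self) -> None:
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None

    def empty(self, *args, **kws):
        # torch.empty(..., device=None) falls back to the default device,
        # so this still works when no device_config is available.
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)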
