Skip to content

Commit 5eeb376

Browse files
committed
TEMP collective fusion hack to enable custom op, matching rms_norm and fused_add_rms_norm
Signed-off-by: Luka Govedic <lgovedic@redhat.com>
1 parent bc5dfaf commit 5eeb376

File tree

1 file changed: +13 additions, -6 deletions

vllm/compilation/collective_fusion.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -719,6 +719,8 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
719719
self.quant_dtype = torch.float8_e4m3fn
720720
self.quant_fp8 = QuantFP8(static=True,
721721
group_shape=GroupShape.PER_TENSOR)
722+
# TODO HACK
723+
self.quant_fp8._forward_method = self.quant_fp8.forward_native
722724

723725
def register(self, pm_pass: PatternMatcherPass):
724726

@@ -729,9 +731,9 @@ def get_inputs():
729731
rmsnorm_result = torch.empty([1, 8, 4],
730732
device=self.device,
731733
dtype=self.dtype)
732-
quant_result = torch.empty([1, 8, 4],
733-
device=self.device,
734-
dtype=self.quant_dtype)
734+
# quant_result = torch.empty([1, 8, 4],
735+
# device=self.device,
736+
# dtype=self.quant_dtype)
735737
weight = torch.empty([4], device=self.device, dtype=self.dtype)
736738
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
737739
return [
@@ -807,6 +809,8 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
807809
self.quant_dtype = torch.float8_e4m3fn
808810
self.quant_fp8 = QuantFP8(static=True,
809811
group_shape=GroupShape.PER_TENSOR)
812+
# TODO HACK
813+
self.quant_fp8._forward_method = self.quant_fp8.forward_native
810814

811815
def register(self, pm_pass: PatternMatcherPass):
812816

@@ -817,9 +821,9 @@ def get_inputs():
817821
device=self.device,
818822
dtype=self.dtype)
819823
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
820-
quant_result = torch.empty([4, 4],
821-
device=self.device,
822-
dtype=self.quant_dtype)
824+
# quant_result = torch.empty([4, 4],
825+
# device=self.device,
826+
# dtype=self.quant_dtype)
823827
scale = torch.empty([1, 1],
824828
device=self.device,
825829
dtype=torch.float32)
@@ -1166,6 +1170,9 @@ def __init__(self, config: VllmConfig):
11661170
# and allow multiple values of epsilon.
11671171
torch._inductor.pattern_matcher._seen_patterns.clear()
11681172

1173+
if path := config.compilation_config.debug_dump_path:
1174+
with open(f"{path}/patterns.txt", 'w') as f:
1175+
print(self.patterns.patterns, file=f)
11691176
self.disabled = False
11701177

11711178
def __call__(self, graph: fx.Graph):

Comments (0)