Skip to content

Commit 24f1298

Browse files
committed
PR comments: cleanup fusion passes, & matching
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 7e6f5b3 commit 24f1298

File tree

3 files changed

+17
-31
lines changed

3 files changed

+17
-31
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,10 @@ def __init__(
673673
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
674674

675675
def get_inputs(self):
676-
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
677-
weight = torch.empty([4], device=self.device, dtype=self.dtype)
676+
input, weight = self.rmsnorm_matcher.inputs()
678677

679-
return [input, weight]
678+
# input goes through allreduce first, always 16-bit
679+
return [input.to(self.dtype), weight]
680680

681681
def register(self, pm_pass: PatternMatcherPass):
682682
def pattern(input: torch.Tensor, weight: torch.Tensor):
@@ -728,14 +728,10 @@ def __init__(
728728
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
729729

730730
def get_inputs(self):
731-
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
732-
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
733-
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
734-
return [
735-
residual,
736-
input,
737-
weight,
738-
]
731+
input, residual, weight = self.rmsnorm_matcher.inputs()
732+
733+
# input goes through allreduce first, always 16-bit
734+
return [residual, input.to(self.dtype), weight]
739735

740736
def register(self, pm_pass: PatternMatcherPass):
741737
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
@@ -802,10 +798,11 @@ def __init__(
802798

803799
def register(self, pm_pass: PatternMatcherPass):
804800
def get_inputs():
805-
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
806-
weight = torch.empty([4], device=self.device, dtype=self.dtype)
807-
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
808-
return [input, weight, scale]
801+
input, weight = self.rmsnorm_matcher.inputs()
802+
_, scale = self.quant_matcher.inputs()
803+
804+
# input goes through allreduce first, always 16-bit
805+
return [input.to(self.dtype), weight, scale]
809806

810807
def pattern(
811808
input: torch.Tensor,
@@ -871,18 +868,11 @@ def __init__(
871868

872869
def register(self, pm_pass: PatternMatcherPass):
873870
def get_inputs():
874-
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
871+
input, residual, weight = self.rmsnorm_matcher.inputs()
872+
_, scale = self.quant_matcher.inputs()
875873

876-
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
877-
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
878-
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
879-
880-
return [
881-
residual,
882-
input,
883-
weight,
884-
scale,
885-
]
874+
# input goes through allreduce first, always 16-bit
875+
return [residual, input.to(self.dtype), weight, scale]
886876

887877
def pattern(
888878
residual: torch.Tensor,

vllm/compilation/fusion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ def replacement(
182182
# In case we're matching native rms-norm, conversions might be
183183
# optimized out. We convert here just to be safe.
184184
input = input.to(dtype=self.model_dtype)
185-
residual = residual.to(dtype=self.model_dtype)
186185

187186
result = torch.empty_like(input, dtype=self.quant_dtype)
188187
at = auto_functionalized(
@@ -292,7 +291,6 @@ def replacement(
292291
# In case we're matching native rms-norm, conversions might be
293292
# optimized out. We convert here just to be safe.
294293
input = input.to(dtype=self.model_dtype)
295-
residual = residual.to(dtype=self.model_dtype)
296294

297295
result = torch.empty_like(input, dtype=self.quant_dtype)
298296
scale = self.quant_matcher.make_scale(input)

vllm/compilation/matcher_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def __init__(self, epsilon: float, enabled: bool | None = None):
7373

7474
def inputs(self):
7575
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
76-
weight = self.empty(
77-
16,
78-
)
76+
weight = self.empty(16)
7977
return [input, weight]
8078

8179
def forward_custom(

0 commit comments

Comments (0)