Commit 7e2c490

cleanup
Signed-off-by: Bill Nell <bill@neuralmagic.com>
1 parent b75cbba commit 7e2c490

File tree: 2 files changed (+27, -35 lines)


vllm/compilation/backends.py

Lines changed: 6 additions & 4 deletions
@@ -253,16 +253,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.compilation_configs.init_during_runtime()
         self.configure_post_pass()

-        if "before_split_graph" in self.compilation_configs.pass_config.dump_graph_stages:
+        if ("before_split_graph"
+                in self.compilation_configs.pass_config.dump_graph_stages):
             dump_graph(self.compilation_configs.pass_config, graph.graph,
                        "before_split_graph")

         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_configs.splitting_ops)

-        if "after_split_graph" in self.compilation_configs.pass_config.dump_graph_stages:
-            dump_graph(self.compilation_configs.pass_config, self.split_gm.graph,
-                       "after_split_graph")
+        if ("after_split_graph"
+                in self.compilation_configs.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_configs.pass_config,
+                       self.split_gm.graph, "after_split_graph")

         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
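Note: the change above only reflows the existing guards over two lines; the underlying idea is that dump_graph runs only when the stage name appears in pass_config.dump_graph_stages. A minimal standalone sketch of that gating follows. The field names mirror the diff, but the file naming and dump format here are assumptions, not vLLM's actual dump_graph.

# Standalone sketch of stage-gated FX graph dumping (illustrative only;
# not the dump_graph implementation from vllm/compilation).
from dataclasses import dataclass, field
from typing import List

import torch
import torch.fx as fx


@dataclass
class PassConfig:
    # Stage names whose graphs should be written out, e.g.
    # ["before_split_graph", "after_split_graph"].
    dump_graph_stages: List[str] = field(default_factory=list)
    dump_graph_dir: str = "."


def maybe_dump_graph(config: PassConfig, graph: fx.Graph, stage: str) -> None:
    # Only dump when the caller asked for this particular stage.
    if stage not in config.dump_graph_stages:
        return
    with open(f"{config.dump_graph_dir}/{stage}.py", "w") as f:
        # Write the Python source generated for the FX graph.
        f.write(graph.python_code(root_module="self").src)


if __name__ == "__main__":
    gm = fx.symbolic_trace(torch.nn.Linear(4, 4))
    maybe_dump_graph(PassConfig(dump_graph_stages=["before_split_graph"]),
                     gm.graph, "before_split_graph")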

vllm/compilation/collective_fusion.py

Lines changed: 21 additions & 31 deletions
@@ -31,7 +31,6 @@
     logger.info("Attempting to use flux but flux not installed.")
     use_flux = False

-
 # Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h
 # Can be 256 on sm80.
 FLUX_TILE_SIZE: int = 128
@@ -60,11 +59,11 @@ def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int:


 def match_gemm_rs_ag_gemm(
-    residual: torch.Tensor,
-    gemm_1_weights: torch.Tensor,
-    gemm_1_activations: torch.Tensor,
-    rms_norm_weights: torch.Tensor,
-    gemm_2_weights: torch.Tensor,
+        residual: torch.Tensor,
+        gemm_1_weights: torch.Tensor,
+        gemm_1_activations: torch.Tensor,
+        rms_norm_weights: torch.Tensor,
+        gemm_2_weights: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
     mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
@@ -239,10 +238,10 @@ def gemm_rs_ag_gemm_fake(


 def match_final(
-    my_residual: torch.Tensor,
-    gemm_1_weights: torch.Tensor,
-    gemm_1_activations: torch.Tensor,
-    rms_norm_weights: torch.Tensor,
+        my_residual: torch.Tensor,
+        gemm_1_weights: torch.Tensor,
+        gemm_1_activations: torch.Tensor,
+        rms_norm_weights: torch.Tensor,
 ) -> torch.Tensor:
     gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
     mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
@@ -260,7 +259,7 @@ def match_final(
     return normalized


-# Register this as a custom op since all reduce cannot be torch.compiled yet.
+# Register this as a custom op since all gather cannot be torch.compiled yet.
 def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   gemm_1_activations: torch.Tensor,
                   rms_norm_weights: torch.Tensor) -> torch.Tensor:
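The corrected comment matches what the surrounding code does: gemm_ag_final ends the fused pattern with an all-gather, and it is wrapped in a custom op so torch.compile treats the collective as a single opaque node rather than tracing into it. As a hedged illustration of that wrapping mechanism using recent PyTorch's torch.library.custom_op (the op name, body, and fake impl below are assumptions, not how vLLM registers torch.ops.vllm.gemm_ag_final):

import torch


# Illustrative only: wrap a collective-style call in a custom op so the
# compiler sees one opaque node instead of tracing the collective.
@torch.library.custom_op("sketch::opaque_all_gather", mutates_args=())
def opaque_all_gather(x: torch.Tensor, world_size: int) -> torch.Tensor:
    # A real implementation would call the process group's all-gather;
    # a local stand-in keeps this sketch runnable on one process.
    return x.repeat(world_size, *([1] * (x.dim() - 1)))


@opaque_all_gather.register_fake
def _(x: torch.Tensor, world_size: int) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing/compiling.
    return x.new_empty(x.shape[0] * world_size, *x.shape[1:])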
@@ -333,17 +332,14 @@ def __init__(self, config: CompilationConfig):
         inputs = [resid, x, w, resid_w, x2]
         final_inputs = [x, w, resid, resid_w]

-        register_replacement(
-            match_gemm_rs_ag_gemm,
-            match_gemm_rs_ag_gemm,
-            inputs,
-            fwd_only, [self.gemm_rs_ag_gemm_pattern],
-            extra_check=lambda m: self.record_match(m))
+        register_replacement(match_gemm_rs_ag_gemm,
+                             match_gemm_rs_ag_gemm,
+                             inputs,
+                             fwd_only, [self.gemm_rs_ag_gemm_pattern],
+                             extra_check=lambda m: self.record_match(m))

-        register_replacement(match_final,
-                             torch.ops.vllm.gemm_ag_final,
-                             final_inputs, fwd_only,
-                             [self.final_pattern])
+        register_replacement(match_final, torch.ops.vllm.gemm_ag_final,
+                             final_inputs, fwd_only, [self.final_pattern])

     def record_match(self, match: Match) -> bool:
         # Hijack the extra_check to record the match and
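For readers unfamiliar with the API being reformatted here: register_replacement comes from torch's inductor pattern matcher and takes a search function, a replacement, example inputs used to trace both, a tracing mode, the pass(es) to register into, and an optional extra_check callback over each Match (which this class hijacks via record_match). A minimal hedged sketch of the same call shape with a toy pattern; the module path is a private PyTorch API and may shift between releases.

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)


def search_pattern(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Subgraph shape to look for: matmul followed by relu.
    return torch.relu(torch.mm(x, w))


def replacement(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # What matched subgraphs are rewritten to (a toy stand-in here).
    return torch.mm(torch.relu(x), w)


patterns = PatternMatcherPass()
example_inputs = [torch.empty(8, 8), torch.empty(8, 8)]

# Same argument order as the diff: search fn, replace fn, example inputs,
# tracing mode, target pass list, optional extra_check over each Match.
register_replacement(search_pattern,
                     replacement,
                     example_inputs,
                     fwd_only,
                     [patterns],
                     extra_check=lambda m: True)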
@@ -394,16 +390,10 @@ def find_min_index(match: Match) -> int:

            # Extract group_name from matched code. Use to
            # generate proper replacement code.
-           #ar_node = find_auto_fn(match.nodes, torch.ops.vllm.inplace_all_reduce.default)
-           ar_node = None
-           if ar_node is not None:
-               tp_group_name = ar_node.kwargs["group_name"]
-           else:
-               ar_node = find_fn(
-                   match.nodes,
-                   torch.ops.vllm.all_reduce.default)
-               assert ar_node is not None
-               tp_group_name = ar_node.args[1]
+           ar_node = find_fn(match.nodes,
+                             torch.ops.vllm.all_reduce.default)
+           assert ar_node is not None
+           tp_group_name = ar_node.args[1]

            fused_gemm_func = get_gemm_rs_ag_gemm(
                use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
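With the dead inplace_all_reduce branch removed, the match handling reduces to a single find_fn lookup plus an assertion, and the tensor-parallel group name is read from the second positional argument of the matched all_reduce node. find_fn itself lives elsewhere in vLLM; a sketch of the usual shape of such a helper (names and typing here are assumptions):

from typing import Iterable, Optional

import torch.fx as fx


def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
    # Return the first call_function node whose target is the given op.
    for node in nodes:
        if node.op == "call_function" and node.target == op:
            return node
    return None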
