
Commit 590b3d2

cleanups
1 parent f0059cb commit 590b3d2

File tree

8 files changed (+222, -101 lines)

csrc/torch_bindings.cpp

Lines changed: 6 additions & 1 deletion
@@ -377,7 +377,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "bool silu_activation,"
           "int pad_slot_id) -> ()");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-#endif
+
+  ops.def("cublas_gemm_rs() -> Tensor");
+  ops.impl("cublas_gemm_rs", torch::kCUDA, &cublas_gemm_rs);
+  ops.def("cublas_ag_gemm() -> Tensor");
+  ops.impl("cublas_ag_gemm", torch::kCUDA, &cublas_ag_gemm);
+#endif // !USE_ROCM

   // Quantized GEMM for GPTQ.
   // Note: even though the C++ inferred schema is correct for this op, it seems
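For reference (not part of the diff): ops registered through TORCH_LIBRARY_EXPAND become callable from Python under torch.ops.<namespace>. A hedged sketch, assuming the extension namespace resolves to `_C` as elsewhere in vLLM's bindings, and taking the argument-free schemas above at face value:

import torch

# Hedged sketch only: both schemas above take no arguments and return a Tensor,
# so the Python-side calls would simply be:
rs_out = torch.ops._C.cublas_gemm_rs()   # namespace `_C` is an assumption
ag_out = torch.ops._C.cublas_ag_gemm()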

vllm/compilation/backends.py

Lines changed: 124 additions & 12 deletions
@@ -26,8 +26,125 @@
 logger = init_logger(__name__)


-def pprint(x):
-    pass
+class InductorHashCache:
+    """
+    Disk format: a Python list of tuples, each tuple is
+    (runtime_shape, graph_index, hash_str).
+    We use a list of tuples for readability.
+
+    In-memory format: a defaultdict of dicts, where the key is
+    runtime_shape, and the value is a dict of graph_index to hash_str.
+
+    The data is essentially `Dict[Optional[int], Dict[int, str]]`.
+    We don't use json here because json doesn't support int as a key.
+
+    TODO: better off-the-shelf solution to serialize the data?
+    """
+
+    def __init__(self, cache_dir: str, disabled: bool = False):
+        self.cache: defaultdict = defaultdict(dict)
+        self.disabled = disabled
+        self.cache_dir = cache_dir
+        self.cache_file_path = os.path.join(cache_dir,
+                                            "inductor_hash_cache.py")
+        if disabled:
+            return
+        # set flags so that Inductor and Triton store their cache
+        # in the cache_dir, then users only need to copy the cache_dir
+        # to another machine to reuse the cache.
+        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        os.makedirs(inductor_cache, exist_ok=True)
+        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
+        triton_cache = os.path.join(cache_dir, "triton_cache")
+        os.makedirs(triton_cache, exist_ok=True)
+        os.environ["TRITON_CACHE_DIR"] = triton_cache
+        if os.path.exists(self.cache_file_path):
+            with open(self.cache_file_path) as f:
+                self.deserialize(f.read())
+
+    def deserialize(self, data: str):
+        # we use ast.literal_eval to parse the data
+        # because it is a safe way to parse Python literals.
+        # do not use eval(), it is unsafe.
+        try:
+            list_data = ast.literal_eval(data)
+            for runtime_shape, graph_index, hash_str in list_data:
+                self.cache[runtime_shape][graph_index] = hash_str
+        except Exception as ex:
+            logger.warning("Unable to read cache: %s, error: %s", self.cache_file_path, ex)
+            self.cache.clear()
+            self.disabled = True
+
+    def serialize(self) -> str:
+        data = []
+        for runtime_shape, graph_index_to_hash_str in self.cache.items():
+            for graph_index, hash_str in graph_index_to_hash_str.items():
+                data.append((runtime_shape, graph_index, hash_str))
+        printer = pprint.PrettyPrinter(indent=4)
+        return printer.pformat(data)
+
+    def save_to_file(self):
+        if self.disabled:
+            return
+        with open(self.cache_file_path, "w") as f:
+            f.write(self.serialize())
+
+    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
+        if self.disabled:
+            return False
+        runtime_shape, graph_index = key
+        return runtime_shape in self.cache and graph_index in self.cache[
+            runtime_shape]
+
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+        if self.disabled:
+            raise KeyError("cannot read from disabled cache")
+        runtime_shape, graph_index = key
+        return self.cache[runtime_shape][graph_index]
+
+    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+        # setitem for a disabled cache is fine, because we
+        # don't actually write to the disk
+        runtime_shape, graph_index = key
+        self.cache[runtime_shape][graph_index] = value
+
+
+class AlwaysHitShapeEnv:
+    """
+    Why do we need this class:
+
+    For normal `torch.compile` usage, every compilation will have
+    one Dynamo bytecode compilation and one Inductor compilation.
+    The Inductor compilation happens under the context of the
+    Dynamo bytecode compilation, and that context is used to
+    determine the dynamic shape information, etc.
+
+    For our use case, we only run Dynamo bytecode compilation once,
+    and run Inductor compilation multiple times with different shapes
+    plus a general shape. The compilation for specific shapes happens
+    outside of the context of the Dynamo bytecode compilation. At that
+    time, we don't have a shape environment to provide to Inductor, and
+    it will fail the Inductor code cache lookup.
+
+    By providing a dummy shape environment that always hits, we can
+    make the Inductor code cache lookup always hit, and we can
+    compile the graph for different shapes as needed.
+
+    The following dummy methods were obtained by trial and error
+    until it worked.
+    """
+
+    def __init__(self) -> None:
+        self.guards: List[Any] = []
+
+    def evaluate_guards_expression(self, *args, **kwargs):
+        return True
+
+    def get_pruned_guards(self, *args, **kwargs):
+        return []
+
+    def produce_guards_expression(self, *args, **kwargs):
+        return ""


 def wrap_inductor(graph: fx.GraphModule,
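The InductorHashCache docstring above describes the disk format (a list of (runtime_shape, graph_index, hash_str) tuples) and the in-memory Dict[Optional[int], Dict[int, str]]. As a standalone illustration of that round-trip (not vLLM code, just the pprint/ast.literal_eval mechanism the class relies on):

import ast
import pprint
from collections import defaultdict
from typing import Dict, Optional

# In-memory format: runtime_shape (None = general shape) -> {graph_index: hash_str}
cache: Dict[Optional[int], Dict[int, str]] = defaultdict(dict)
cache[None][0] = "hash_for_general_shape"
cache[8][0] = "hash_for_shape_8"

# Serialize: flatten to a list of tuples and pretty-print it (the disk format).
serialized = pprint.PrettyPrinter(indent=4).pformat(
    [(shape, idx, h) for shape, entries in cache.items()
     for idx, h in entries.items()])

# Deserialize: ast.literal_eval safely parses the Python literal (never eval()).
restored: Dict[Optional[int], Dict[int, str]] = defaultdict(dict)
for runtime_shape, graph_index, hash_str in ast.literal_eval(serialized):
    restored[runtime_shape][graph_index] = hash_str

assert restored == cache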
@@ -369,6 +486,7 @@ def configure_post_pass(self):
         inductor_config[PASS_KEY] = self.post_grad_pass_manager

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
@@ -385,16 +503,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.configure_post_pass()

         if ("before_split_graph"
-                in self.compilation_configs.pass_config.dump_graph_stages):
-            dump_graph(self.compilation_configs.pass_config, graph.graph,
+                in self.compilation_config.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_config.pass_config, graph.graph,
                        "before_split_graph")

         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_config.splitting_ops)

         if ("after_split_graph"
-                in self.compilation_configs.pass_config.dump_graph_stages):
-            dump_graph(self.compilation_configs.pass_config,
+                in self.compilation_config.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_config.pass_config,
                        self.split_gm.graph, "after_split_graph")

         compilation_counter.num_piecewise_graphs_seen += len(
@@ -541,13 +659,11 @@ def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
             self.check_for_ending_compilation()
-            pprint(f"RUN GENERAL 1")
             return self.compiled_graph_for_general_shape(*args)

         runtime_shape = args[self.sym_shape_indices[0]]
         if runtime_shape not in self.concrete_size_entries:
             # we don't need to do anything for this shape
-            pprint(f"RUN GENERAL 2 - {runtime_shape}")
             return self.compiled_graph_for_general_shape(*args)

         entry = self.concrete_size_entries[runtime_shape]
@@ -574,7 +690,6 @@ def __call__(self, *args) -> Any:
             self.check_for_ending_compilation()

         if not entry.use_cudagraph:
-            pprint(f"RUN STATIC {runtime_shape}")
             return entry.runnable(*args)

         if entry.cudagraph is None:
@@ -586,7 +701,6 @@ def __call__(self, *args) -> Any:
                         entry.num_finished_warmup,
                         self.compilation_config.cudagraph_num_of_warmups,
                         runtime_shape)
-                pprint(f"RUN STATIC CUDAGRAPH WARMUP 1 {runtime_shape}")
                 return entry.runnable(*args)

             if self.is_first_graph:
@@ -617,7 +731,6 @@ def __call__(self, *args) -> Any:
             # mind-exploding: carefully manage the reference and memory.
             with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's cudagraph pool
-                pprint(f"RUN STATIC CUDAGRAPH WARMUP 2 {runtime_shape}")
                 output = entry.runnable(*args)
                 if self.is_last_graph:
                     # by converting it to weak ref,
@@ -649,6 +762,5 @@ def __call__(self, *args) -> Any:
                 f" Expected {entry.input_addresses}, got {new_input_addresses}"
             )

-        pprint(f"RUN STATIC CUDAGRAPH REPLAY {runtime_shape}")
         entry.cudagraph.replay()
         return entry.output
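The capture/replay path whose debug prints are removed above follows PyTorch's documented CUDA graph pattern. A minimal generic sketch of that pattern (plain PyTorch API, not vLLM code):

import torch

def run(x: torch.Tensor) -> torch.Tensor:
    return (x * 2.0 + 1.0).relu()

x = torch.randn(8, device="cuda")

# warm up on a side stream before capture, as the CUDA graphs docs recommend
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        run(x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = run(x)  # `y` lives in the graph's memory pool, like `output` above

x.copy_(torch.randn(8, device="cuda"))  # refill the captured input buffer in place
g.replay()                              # re-run the captured kernels; `y` is updated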

vllm/compilation/collective_fusion.py

Lines changed: 6 additions & 22 deletions
@@ -7,8 +7,8 @@
                                               fwd_only, register_replacement)

 import vllm.envs as envs
-from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
-                                    find_op, last_node_in_match)
+from vllm.compilation.fx_utils import (find_auto_fn, find_fn, find_getitem,
+                                       find_op, last_node_in_match)
 from vllm.config import CompilationConfig
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
@@ -19,6 +19,8 @@

 from .inductor_pass import get_pass_context
 from .vllm_inductor_pass import VllmInductorPass
+from .utils import use_cc_kernels
+

 logger = init_logger(__name__)

@@ -32,21 +34,11 @@
     logger.info("Attempting to use flux but flux not installed.")
     use_flux = False

-# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h
-# Can be 256 on sm80.
-FLUX_TILE_SIZE: int = 128
-

 def get_world_name() -> str:
     return torch.distributed.group.WORLD.group_name


-def use_cc_kernels(m_shape: int) -> bool:
-    n_slices = get_tensor_model_parallel_world_size()
-    return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0
-            and m_shape >= FLUX_TILE_SIZE * n_slices)
-
-
 def residual_slice_shape(residual: torch.Tensor, rank: int) -> int:
     n_slices = get_tensor_model_parallel_world_size()
     assert residual.size(0) % n_slices == 0
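use_cc_kernels and FLUX_TILE_SIZE are dropped here because the helper now lives in vllm/compilation/utils.py (imported above as `from .utils import use_cc_kernels`). That file is not shown in this diff; presumably the relocated helper mirrors the deleted code, roughly:

from vllm.distributed import get_tensor_model_parallel_world_size

# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h.
# Can be 256 on sm80.
FLUX_TILE_SIZE: int = 128


def use_cc_kernels(m_shape: int) -> bool:
    # the collective-communication (flux) kernels only apply when m is a
    # positive multiple of FLUX_TILE_SIZE * tensor-parallel world size
    n_slices = get_tensor_model_parallel_world_size()
    return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0
            and m_shape >= FLUX_TILE_SIZE * n_slices)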
@@ -79,7 +71,7 @@ def match_gemm_rs_ag_gemm(
     return mm_2, new_residual


-def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype,
+def get_gemm_rs_ag_gemm(max_m: int, gemm_1_type: torch.dtype,
                         gemm_1_weights: torch.Size, gemm_2_type: torch.dtype,
                         gemm_2_weights: torch.Size,
                         tp_group_name: str,
@@ -213,7 +205,6 @@ def gemm_rs_ag_gemm_static(
         rms_norm_weights: torch.Tensor, gemm_2_weights: torch.Tensor,
         first_layer: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        print(f"START STATIC FLUX {residual.shape} {first_layer}")
         if first_layer:
             slice_shape = residual_slice_shape(residual, rank)
             residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape)
@@ -237,8 +228,6 @@

         mm_2 = ag_gemm(output, gemm_2_weights)

-        print(f"END STATIC FLUX {residual.shape} {first_layer}")
-
         return mm_2, new_residual, slice_scatter

     def gemm_rs_ag_gemm_fake(
@@ -304,14 +293,12 @@ def match_final(
 def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   gemm_1_activations: torch.Tensor,
                   rms_norm_weights: torch.Tensor) -> torch.Tensor:
-    # TODO: use ag gemm here?
     mm_1 = torch.ops.aten.mm.default(gemm_1_activations,
                                      gemm_1_weights.transpose(1, 0))

     reduced = tensor_model_parallel_all_reduce(mm_1)

     if use_cc_kernels(reduced.size(0)):
-        print(f"ALL GATHER {my_residual.size()}, {reduced.size()}")
         wait_tensor = tensor_model_parallel_all_gather(my_residual)
     else:
         assert reduced.size() == my_residual.size()
@@ -322,15 +309,12 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   weight=rms_norm_weights,
                   epsilon=1e-05)

-    print(f"DONE FINAL {my_residual.size()}, {reduced.size()}")
-
     return reduced


 def gemm_ag_final_static(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                          gemm_1_activations: torch.Tensor,
                          rms_norm_weights: torch.Tensor) -> torch.Tensor:
-    # TODO: use ag gemm here?
     mm_1 = torch.ops.aten.mm.default(gemm_1_activations,
                                      gemm_1_weights.transpose(1, 0))

@@ -507,7 +491,7 @@ def find_min_index(match: Match) -> int:
             tp_group_name = ar_node.args[1]

             fused_gemm_func, fused_gemm_fake_func = get_gemm_rs_ag_gemm(
-                use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
+                max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
                 gemm_2.shape, tp_group_name, self.is_static_shape())

             fused_node = graph.call_function(fused_gemm_func,

vllm/compilation/fx_utils.py

Lines changed: 25 additions & 0 deletions
@@ -4,12 +4,27 @@
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._ops import OpOverload
+from torch._inductor.pattern_matcher import Match


 def is_func(node: fx.Node, target) -> bool:
     return node.op == "call_function" and node.target == target


+def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
+    for node in nodes:
+        if node.op == "call_function" and node.target == op:
+            return node
+    return None
+
+
+def find_op(nodes: Iterable[fx.Node], op: str) -> Optional[fx.Node]:
+    for node in nodes:
+        if node.op == op:
+            return node
+    return None
+
+
 # Returns the first auto_functionalized node with the given op (if it exists)
 def find_auto_fn_maybe(nodes: Iterable[fx.Node],
                        op: OpOverload) -> Optional[fx.Node]:
@@ -40,3 +55,13 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
     ret = find_getitem_maybe(node, idx)
     assert ret is not None, f"Could not find getitem {idx} in node {node}"
     return ret
+
+
+def last_node_in_match(match: Match) -> fx.Node:
+    if len(match.nodes) > 0:
+        graph = match.nodes[0].graph
+        for n in reversed(graph.nodes):
+            if n in reversed(match.nodes):
+                return n
+    raise ValueError("No nodes in graph")
+
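A small usage sketch for the new helpers (hypothetical example, not part of this commit): find_op matches a node by its opcode string, find_fn by its call_function target.

import torch
from torch import fx

from vllm.compilation.fx_utils import find_fn, find_op


class Toy(torch.nn.Module):

    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Toy())

out_node = find_op(gm.graph.nodes, "output")      # first node with opcode "output"
relu_node = find_fn(gm.graph.nodes, torch.relu)   # first call_function targeting torch.relu
assert out_node is not None and relu_node is not None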
