
Commit 3e4e159

Author: ilmarkov (committed)

Fix QuantFP8 matching

Update range based compilation

Signed-off-by: ilmarkov <imarkov@redhat.com>

1 parent b85e752 · commit 3e4e159

File tree

9 files changed (+569 -358 lines)


vllm/compilation/backends.py

Lines changed: 144 additions & 81 deletions
Large diffs are not rendered by default.

vllm/compilation/collective_fusion.py

Lines changed: 327 additions & 163 deletions
Large diffs are not rendered by default.

vllm/compilation/compiler_interface.py

Lines changed: 18 additions & 17 deletions
@@ -64,16 +64,17 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: Optional[int] = None,
+        compile_range: Optional[tuple[int, int]] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
-        with a runtime shape. If the `runtime_shape` is None, it means
+        with a runtime shape. If the `compile_range` is None, it means
         the `example_inputs` have a dynamic shape. Otherwise, the
-        `runtime_shape` specifies the shape of the inputs. Right now we only
-        support one variable shape for all inputs, which is the batchsize
-        (number of tokens) during inference.
+        `compile_range` specifies the range of the inputs,
+        it could be concrete size, e.g. (4, 4).
+        Right now we only support one variable shape for all inputs,
+        which is the batchsize (number of tokens) during inference.

         Dynamo will make sure `graph(*example_inputs)` is valid.

@@ -98,7 +99,7 @@ def load(self,
              graph: fx.GraphModule,
              example_inputs: list[Any],
              graph_index: int,
-             runtime_shape: Optional[int] = None) -> Callable:
+             compile_range: Optional[tuple[int, int]] = None) -> Callable:
         """
         Load the compiled function from the handle.
         Raises an error if the handle is invalid.
@@ -188,22 +189,22 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: Optional[int] = None,
+        compile_range: Optional[tuple[int, int]] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
         compilation_counter.num_inductor_compiles += 1
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
-        set_inductor_config(current_config, runtime_shape)
+        set_inductor_config(current_config, compile_range)

-        if isinstance(runtime_shape, int):
+        if isinstance(compile_range, tuple):
             dynamic_shapes = "from_example_inputs"
         else:
             dynamic_shapes = "from_tracing_context"

         from torch._inductor import standalone_compile
-        with pass_context(runtime_shape):
+        with pass_context(compile_range):
             compiled_graph = standalone_compile(
                 graph,
                 example_inputs,
@@ -223,7 +224,7 @@ def load(self,
              graph: fx.GraphModule,
              example_inputs: list[Any],
              graph_index: int,
-             runtime_shape: Optional[int] = None) -> Callable:
+             compile_range: Optional[tuple[int, int]] = None) -> Callable:
         assert isinstance(handle, tuple)
         assert isinstance(handle[0], str)
         assert isinstance(handle[1], str)
@@ -283,7 +284,7 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: Optional[int] = None,
+        compile_range: Optional[tuple[int, int]] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
         compilation_counter.num_inductor_compiles += 1
@@ -296,7 +297,7 @@ def compile(
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False

-        set_inductor_config(current_config, runtime_shape)
+        set_inductor_config(current_config, compile_range)

         # inductor can inplace modify the graph, so we need to copy it
         # see https://github.com/pytorch/pytorch/issues/138980
@@ -433,7 +434,7 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
                 torch._functorch.config.patch(
                     enable_remote_autograd_cache=False))

-        with pass_context(runtime_shape):
+        with pass_context(compile_range):
             compiled_graph = compile_fx(
                 graph,
                 example_inputs,
@@ -547,9 +548,9 @@ def metrics_context(self) -> contextlib.AbstractContextManager:
         return contextlib.nullcontext()


-def set_inductor_config(config, runtime_shape):
-    if isinstance(runtime_shape, int):
-        # for a specific batchsize, tuning triton kernel parameters
+def set_inductor_config(config, compile_range):
+    if isinstance(compile_range, tuple):
+        # for a specific range of batchsizes, tuning triton kernel parameters
         # can be beneficial
         config["max_autotune"] = True
         config["coordinate_descent_tuning"] = True
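
Editorial note: to make the new parameter concrete, a compile_range of None keeps the graph fully dynamic, a degenerate tuple such as (4, 4) pins a single batch size, and a wider tuple covers a span of token counts. The sketch below is illustrative only and not part of the commit; describe_compile_range and make_inductor_config are made-up helper names, and the half-open reading of (lo, hi) follows the is_in_range helper added in cuda_piecewise_backend.py further down.

from typing import Optional

def describe_compile_range(compile_range: Optional[tuple[int, int]]) -> str:
    # None -> fully dynamic shape, (s, s) -> one concrete size,
    # (lo, hi) -> half-open range of num_tokens (see is_in_range below).
    if compile_range is None:
        return "dynamic: one graph for all num_tokens"
    lo, hi = compile_range
    if lo == hi:
        return f"concrete: specialize for num_tokens == {lo}"
    return f"range: specialize for {lo} <= num_tokens < {hi}"

def make_inductor_config(compile_range: Optional[tuple[int, int]]) -> dict:
    # Mirrors set_inductor_config above: any tuple (range or concrete size)
    # turns on Inductor autotuning for that compilation.
    config: dict = {}
    if isinstance(compile_range, tuple):
        config["max_autotune"] = True
        config["coordinate_descent_tuning"] = True
    return config

print(describe_compile_range((4, 4)))   # the docstring's concrete-size example
print(describe_compile_range((8, 32)))  # a proper range
print(make_inductor_config(None))       # {} -> no autotuning for dynamic shape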

vllm/compilation/cuda_piecewise_backend.py

Lines changed: 38 additions & 41 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import dataclasses
-from typing import Any, Callable
+from typing import Any, Callable, Optional

 import torch.fx as fx

@@ -11,17 +11,15 @@
 from vllm.compilation.monitor import end_monitoring_torch_compile
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from typing import Optional

 logger = init_logger(__name__)

+
 @dataclasses.dataclass
-class ConditionalEntry:
-    runtime_shape: int
+class RangeEntry:
+    compile_range: Optional[tuple[int, int]]
     compiled: bool = False
     runnable: Callable = None  # type: ignore
-    runtime_range: Optional[tuple[int,
-                                  int]] = None  # only used for range entries


 class PiecewiseBackend:
@@ -55,9 +53,25 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,

         self.compile_sizes: set[int] = set(
             self.compilation_config.compile_sizes)
-        self.compile_ranges: tuple[
-            int, int] = self.compilation_config.compile_ranges
-        self.is_in_range = lambda x, range: range[0] <= x <= range[1]
+        self.compile_ranges_split_points: list[
+            int] = self.compilation_config.compile_ranges_split_points
+        self.compile_ranges = []
+        split_points = sorted(
+            set(self.compile_sizes).union(set(
+                self.compile_ranges_split_points)))
+        for i, s in enumerate(split_points):
+            if i == 0:
+                self.compile_ranges.append((1, s))
+            else:
+                self.compile_ranges.append((split_points[i - 1], s))
+            if s in self.compile_sizes:
+                self.compile_ranges.append((s, s))
+        self.compile_ranges = sorted(self.compile_ranges)
+        logger.debug("PiecewiseBackend: compile_ranges: %s",
+                     self.compile_ranges)
+
+        self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[
+            0] < range[1] else x == range[0]

         self.first_run_finished = False

@@ -68,28 +82,26 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

         # the entries for different shapes that we need to compile
-        self.concrete_size_entries: dict[int, ConditionalEntry] = {}
+        # self.concrete_size_entries: dict[int, RangeEntry] = {}

         # the entries for ranges that we need to either
         # TODO: we should merge with concrete_size_entries
-        self.range_entries: dict[tuple[int, int], ConditionalEntry] = {}
+        self.range_entries: dict[tuple[int, int], RangeEntry] = {}

-        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # to_be_compiled_ranges tracks the remaining ranges to compile,
         # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
         self.to_be_compiled_ranges: set[tuple[int,
                                               int]] = set(self.compile_ranges)

         # We only keep compilation management inside this class directly.
-        for shape in self.compile_sizes:
-            self.concrete_size_entries[shape] = ConditionalEntry(
-                runtime_shape=shape,
+        for range in self.compile_ranges:
+            self.range_entries[range] = RangeEntry(
+                compile_range=range,
                 runnable=self.compiled_graph_for_general_shape,
             )

     def check_for_ending_compilation(self):
-        if (self.is_last_graph and not self.to_be_compiled_sizes
-                and not self.to_be_compiled_ranges):
+        if (self.is_last_graph and not self.to_be_compiled_ranges):
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
             self.vllm_backend.compiler_manager.save_to_file()
@@ -103,47 +115,32 @@ def __call__(self, *args) -> Any:

         runtime_shape = args[self.sym_shape_indices[0]]

-
         range_entry = None
         for range in self.compile_ranges:
             if self.is_in_range(runtime_shape, range):
-                if range not in self.range_entries:
-                    self.range_entries[range] = ConditionalEntry(
-                        runtime_shape=runtime_shape,
-                        runtime_range=range,
-                    )
                 range_entry = self.range_entries[range]
                 break

-        if (runtime_shape not in self.concrete_size_entries
-                and range_entry is None):
+        if (range_entry is None):
             # we don't need to do anything for this shape
             return self.compiled_graph_for_general_shape(*args)

-        if range_entry is not None:
-            entry = range_entry
-        else:
-            entry = self.concrete_size_entries[runtime_shape]
+        if not range_entry.compiled:
+            range_entry.compiled = True
+            self.to_be_compiled_ranges.remove(range_entry.compile_range)

-        if not entry.compiled:
-            entry.compiled = True
-            if range_entry is not None:
-                self.to_be_compiled_ranges.remove(range_entry.runtime_range)
-            else:
-                self.to_be_compiled_sizes.remove(runtime_shape)
             # args are real arguments
-            entry.runnable = self.vllm_backend.compiler_manager.compile(
+            range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                 self.graph,
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
-                runtime_shape=runtime_shape)
+                compile_range=range_entry.compile_range)

             # finished compilations for all required shapes
-            if (self.is_last_graph and not self.to_be_compiled_sizes
-                    and not self.to_be_compiled_ranges):
+            if (self.is_last_graph and not self.to_be_compiled_ranges):
                 self.check_for_ending_compilation()

-        return entry.runnable(*args)
+        return range_entry.runnable(*args)
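
Editorial note: the range construction above is easiest to follow with concrete numbers. The standalone sketch below re-runs the same loop with hypothetical inputs (compile_sizes = {8} and split points [32, 128] are made up for illustration) and shows which range a given runtime batch size lands in; in the real backend a miss falls back to compiled_graph_for_general_shape. Because the list is sorted, a degenerate (s, s) entry always sits before the (s, hi) range that also contains s, so concrete compile sizes take precedence.

compile_sizes = {8}                      # hypothetical value
compile_ranges_split_points = [32, 128]  # hypothetical value

# Same construction as PiecewiseBackend.__init__ above.
compile_ranges: list[tuple[int, int]] = []
split_points = sorted(set(compile_sizes).union(compile_ranges_split_points))
for i, s in enumerate(split_points):
    if i == 0:
        compile_ranges.append((1, s))
    else:
        compile_ranges.append((split_points[i - 1], s))
    if s in compile_sizes:
        compile_ranges.append((s, s))  # a compile size becomes a (s, s) range
compile_ranges = sorted(compile_ranges)
assert compile_ranges == [(1, 8), (8, 8), (8, 32), (32, 128)]

def is_in_range(x: int, r: tuple[int, int]) -> bool:
    # half-open [lo, hi) for real ranges, exact match for degenerate (s, s)
    return r[0] <= x < r[1] if r[0] < r[1] else x == r[0]

def pick_range(x: int):
    # mirrors the lookup loop in __call__: first match in sorted order wins
    for r in compile_ranges:
        if is_in_range(x, r):
            return r
    return None  # the real backend runs the general-shape graph here

assert pick_range(8) == (8, 8)      # concrete entry wins over (8, 32)
assert pick_range(20) == (8, 32)
assert pick_range(100) == (32, 128)
assert pick_range(500) is None      # outside every range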

vllm/compilation/inductor_pass.py

Lines changed: 6 additions & 5 deletions
@@ -25,8 +25,8 @@

 class PassContext:

-    def __init__(self, runtime_shape: Optional[int]):
-        self.runtime_shape = runtime_shape
+    def __init__(self, compile_range: Optional[tuple[int, int]]):
+        self.compile_range = compile_range


 def get_pass_context() -> PassContext:
@@ -36,13 +36,13 @@ def get_pass_context() -> PassContext:


 @contextmanager
-def pass_context(runtime_shape: Optional[int]):
+def pass_context(compile_range: Optional[tuple[int, int]]):
     """A context manager that stores the current pass context,
     usually it is a list of sizes to specialize.
     """
     global _pass_context
     prev_context = _pass_context
-    _pass_context = PassContext(runtime_shape)
+    _pass_context = PassContext(compile_range)
     try:
         yield
     finally:
@@ -93,7 +93,8 @@ def hash_dict(dict_: dict[Any, Any]):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()

-    def is_applicable_for_shape(self, shape: Optional[int]):
+    def is_applicable_for_range(self, compile_range: Optional[tuple[int,
+                                                                    int]]):
         return True


vllm/compilation/pass_manager.py

Lines changed: 2 additions & 2 deletions
@@ -43,9 +43,9 @@ def __init__(self):
         self.passes: list[VllmInductorPass] = []

     def __call__(self, graph: fx.Graph):
-        shape = get_pass_context().runtime_shape
+        compile_range = get_pass_context().compile_range
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable_for_range(compile_range):
                 pass_(graph)

         # always run fix_functionalization last
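
Editorial note: taken together with the inductor_pass.py change above, the dispatch now works as follows: the compiler stashes the active range via pass_context(), and the pass manager asks each pass is_applicable_for_range() before running it. The sketch below is a simplified, self-contained restatement, not the vLLM code itself; the PassContext plumbing is reduced to the essentials and SkipOnRangesPass is a made-up pass used only for illustration.

from contextlib import contextmanager
from typing import Optional

class PassContext:
    def __init__(self, compile_range: Optional[tuple[int, int]]):
        self.compile_range = compile_range

_pass_context = PassContext(None)

def get_pass_context() -> PassContext:
    return _pass_context

@contextmanager
def pass_context(compile_range: Optional[tuple[int, int]]):
    # stores the current range for the duration of one compilation
    global _pass_context
    prev = _pass_context
    _pass_context = PassContext(compile_range)
    try:
        yield
    finally:
        _pass_context = prev

class SkipOnRangesPass:
    # made-up pass: only applicable when the range is one concrete size
    def is_applicable_for_range(
            self, compile_range: Optional[tuple[int, int]]) -> bool:
        return (compile_range is not None
                and compile_range[0] == compile_range[1])

    def __call__(self, graph) -> None:
        print("rewriting", graph)

def run_passes(graph, passes) -> None:
    # mirrors the pass manager's __call__ in the diff above
    compile_range = get_pass_context().compile_range
    for pass_ in passes:
        if pass_.is_applicable_for_range(compile_range):
            pass_(graph)

with pass_context((8, 8)):
    run_passes("graph-for-size-8", [SkipOnRangesPass()])    # pass runs
with pass_context((8, 32)):
    run_passes("graph-for-a-range", [SkipOnRangesPass()])   # pass skipped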

vllm/compilation/sequence_parallelism.py

Lines changed: 5 additions & 2 deletions
@@ -469,9 +469,12 @@ def __init__(self, config: VllmConfig):
         # and allow multiple values of epsilon.
         torch._inductor.pattern_matcher._seen_patterns.clear()

-    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
+    def is_applicable_for_range(
+            self, compile_range: Optional[tuple[int, int]]) -> bool:
         tp_size = get_tensor_model_parallel_world_size()
-        return shape is not None and shape % tp_size == 0
+        return compile_range is not None and (
+            compile_range[0]
+            == compile_range[1]) and (compile_range[1] % tp_size == 0)

     def __call__(self, graph: fx.Graph):
         self.begin()
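
Editorial note: the net effect of the new predicate is that the sequence-parallelism pass only fires when the compile range is a single concrete size divisible by the tensor-parallel world size; a genuine range or a fully dynamic shape disables it, whereas previously any concrete shape divisible by tp_size qualified. A small sketch of the rule, with sp_pass_applicable as a hypothetical free-function restatement of the method above:

from typing import Optional

def sp_pass_applicable(compile_range: Optional[tuple[int, int]],
                       tp_size: int) -> bool:
    # restates is_applicable_for_range from the diff above
    return (compile_range is not None
            and compile_range[0] == compile_range[1]
            and compile_range[1] % tp_size == 0)

assert sp_pass_applicable((8, 8), tp_size=4)       # concrete and divisible
assert not sp_pass_applicable((6, 6), tp_size=4)   # concrete, not divisible
assert not sp_pass_applicable((8, 32), tp_size=4)  # a real range: skipped
assert not sp_pass_applicable(None, tp_size=4)     # dynamic shape: skipped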
