22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44import dataclasses
5- from typing import Any , Callable
5+ from typing import Any , Callable , Optional
66
77import torch .fx as fx
88
1111from vllm .compilation .monitor import end_monitoring_torch_compile
1212from vllm .config import VllmConfig
1313from vllm .logger import init_logger
14- from typing import Optional
1514
1615logger = init_logger (__name__ )
1716
17+
1818@dataclasses .dataclass
19- class ConditionalEntry :
20- runtime_shape : int
19+ class RangeEntry :
20+ compile_range : Optional [ tuple [ int , int ]]
2121 compiled : bool = False
2222 runnable : Callable = None # type: ignore
23- runtime_range : Optional [tuple [int ,
24- int ]] = None # only used for range entries
2523
2624
2725class PiecewiseBackend :
@@ -55,9 +53,25 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
5553
5654 self .compile_sizes : set [int ] = set (
5755 self .compilation_config .compile_sizes )
58- self .compile_ranges : tuple [
59- int , int ] = self .compilation_config .compile_ranges
60- self .is_in_range = lambda x , range : range [0 ] <= x <= range [1 ]
56+ self .compile_ranges_split_points : list [
57+ int ] = self .compilation_config .compile_ranges_split_points
58+ self .compile_ranges = []
59+ split_points = sorted (
60+ set (self .compile_sizes ).union (set (
61+ self .compile_ranges_split_points )))
62+ for i , s in enumerate (split_points ):
63+ if i == 0 :
64+ self .compile_ranges .append ((1 , s ))
65+ else :
66+ self .compile_ranges .append ((split_points [i - 1 ], s ))
67+ if s in self .compile_sizes :
68+ self .compile_ranges .append ((s , s ))
69+ self .compile_ranges = sorted (self .compile_ranges )
70+ logger .debug ("PiecewiseBackend: compile_ranges: %s" ,
71+ self .compile_ranges )
72+
73+ self .is_in_range = lambda x , range : range [0 ] <= x < range [1 ] if range [
74+ 0 ] < range [1 ] else x == range [0 ]
6175
6276 self .first_run_finished = False
6377
@@ -68,28 +82,26 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
6882 self .is_debugging_mode = envs .VLLM_LOGGING_LEVEL == "DEBUG"
6983
7084 # the entries for different shapes that we need to compile
71- self .concrete_size_entries : dict [int , ConditionalEntry ] = {}
85+ # self.concrete_size_entries: dict[int, RangeEntry ] = {}
7286
7387 # the entries for ranges that we need to either
7488 # TODO: we should merge with concrete_size_entries
75- self .range_entries : dict [tuple [int , int ], ConditionalEntry ] = {}
89+ self .range_entries : dict [tuple [int , int ], RangeEntry ] = {}
7690
77- # to_be_compiled_sizes tracks the remaining sizes to compile,
91+ # to_be_compiled_ranges tracks the remaining ranges to compile,
7892 # and updates during the compilation process, so we need to copy it
79- self .to_be_compiled_sizes : set [int ] = self .compile_sizes .copy ()
8093 self .to_be_compiled_ranges : set [tuple [int ,
8194 int ]] = set (self .compile_ranges )
8295
8396 # We only keep compilation management inside this class directly.
84- for shape in self .compile_sizes :
85- self .concrete_size_entries [ shape ] = ConditionalEntry (
86- runtime_shape = shape ,
97+ for range in self .compile_ranges :
98+ self .range_entries [ range ] = RangeEntry (
99+ compile_range = range ,
87100 runnable = self .compiled_graph_for_general_shape ,
88101 )
89102
90103 def check_for_ending_compilation (self ):
91- if (self .is_last_graph and not self .to_be_compiled_sizes
92- and not self .to_be_compiled_ranges ):
104+ if (self .is_last_graph and not self .to_be_compiled_ranges ):
93105 # no specific sizes to compile
94106 # save the hash of the inductor graph for the next run
95107 self .vllm_backend .compiler_manager .save_to_file ()
@@ -103,47 +115,32 @@ def __call__(self, *args) -> Any:
103115
104116 runtime_shape = args [self .sym_shape_indices [0 ]]
105117
106-
107118 range_entry = None
108119 for range in self .compile_ranges :
109120 if self .is_in_range (runtime_shape , range ):
110- if range not in self .range_entries :
111- self .range_entries [range ] = ConditionalEntry (
112- runtime_shape = runtime_shape ,
113- runtime_range = range ,
114- )
115121 range_entry = self .range_entries [range ]
116122 break
117123
118- if (runtime_shape not in self .concrete_size_entries
119- and range_entry is None ):
124+ if (range_entry is None ):
120125 # we don't need to do anything for this shape
121126 return self .compiled_graph_for_general_shape (* args )
122127
123- if range_entry is not None :
124- entry = range_entry
125- else :
126- entry = self .concrete_size_entries [runtime_shape ]
128+ if not range_entry .compiled :
129+ range_entry .compiled = True
130+ self .to_be_compiled_ranges .remove (range_entry .compile_range )
127131
128- if not entry .compiled :
129- entry .compiled = True
130- if range_entry is not None :
131- self .to_be_compiled_ranges .remove (range_entry .runtime_range )
132- else :
133- self .to_be_compiled_sizes .remove (runtime_shape )
134132 # args are real arguments
135- entry .runnable = self .vllm_backend .compiler_manager .compile (
133+ range_entry .runnable = self .vllm_backend .compiler_manager .compile (
136134 self .graph ,
137135 args ,
138136 self .compilation_config .inductor_compile_config ,
139137 self .compilation_config ,
140138 graph_index = self .piecewise_compile_index ,
141139 num_graphs = self .total_piecewise_compiles ,
142- runtime_shape = runtime_shape )
140+ compile_range = range_entry . compile_range )
143141
144142 # finished compilations for all required shapes
145- if (self .is_last_graph and not self .to_be_compiled_sizes
146- and not self .to_be_compiled_ranges ):
143+ if (self .is_last_graph and not self .to_be_compiled_ranges ):
147144 self .check_for_ending_compilation ()
148145
149- return entry .runnable (* args )
146+ return range_entry .runnable (* args )
0 commit comments