77import pprint
88import time
99from collections .abc import Sequence
10+ from contextlib import contextmanager
1011from typing import Any , Callable , Optional
1112
1213import torch
@@ -66,7 +67,25 @@ def __init__(self, compilation_config: CompilationConfig):
def compute_hash(self, vllm_config: VllmConfig) -> str:
    """Compute a cache key for *vllm_config* by delegating to the compiler.

    The returned string is whatever hash the active compiler backend
    derives from the configuration; this wrapper adds no logic of its own.
    """
    compiler = self.compiler
    return compiler.compute_hash(vllm_config)
6869
69- def initialize_cache (self , cache_dir : str , disable_cache : bool = False ):
70+ def initialize_cache (self ,
71+ cache_dir : str ,
72+ disable_cache : bool = False ,
73+ prefix : str = "" ):
74+ """
75+ Initialize the cache directory for the compiler.
76+
77+ The organization of the cache directory is as follows:
78+ cache_dir=/path/to/hash_str/rank_i_j/prefix/
79+ inside cache_dir, there will be:
80+ - vllm_compile_cache.py
81+ - computation_graph.py
82+ - transformed_code.py
83+
84+ for multiple prefixes, they can share the same
85+ base cache dir of /path/to/hash_str/rank_i_j/ ,
86+ to store some common compilation artifacts.
87+ """
88+
7089 self .disable_cache = disable_cache
7190 self .cache_dir = cache_dir
7291 self .cache_file_path = os .path .join (cache_dir , "vllm_compile_cache.py" )
@@ -80,7 +99,8 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
8099 self .cache = ast .literal_eval (f .read ())
81100
82101 self .compiler .initialize_cache (cache_dir = cache_dir ,
83- disable_cache = disable_cache )
102+ disable_cache = disable_cache ,
103+ prefix = prefix )
84104
85105 def save_to_file (self ):
86106 if self .disable_cache or not self .is_cache_updated :
@@ -310,6 +330,25 @@ def call_module(self, target: torch.fx.node.Target,
310330 return output
311331
312332
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily set the global ``model_tag``, restoring it on exit.

    Entering with the tag that is already active is rejected, since that
    would indicate a redundant (likely erroneous) nesting of the same tag.
    """
    global model_tag
    # NOTE: assert is stripped under `python -O`; kept as-is to preserve
    # the exception type (AssertionError) callers would observe.
    assert tag != model_tag, \
        f"Model tag {tag} is the same as the current tag {model_tag}."
    previous_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        # always restore, even if the body raises
        model_tag = previous_tag
350+
351+
313352class VllmBackend :
314353 """The compilation backend for `torch.compile` with vLLM.
315354 It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -341,7 +380,17 @@ class VllmBackend:
341380 def __init__ (
342381 self ,
343382 vllm_config : VllmConfig ,
383+ prefix : str = "" ,
344384 ):
385+
386+ # if the model is initialized with a non-empty prefix,
387+ # then usually it's enough to use that prefix,
388+ # e.g. language_model, vision_model, etc.
389+ # when multiple parts are initialized as independent
390+ # models, we need to use the model_tag to distinguish
391+ # them, e.g. backbone (default), eagle_head, etc.
392+ self .prefix = prefix or model_tag
393+
345394 global global_graph_pool
346395 if global_graph_pool is None :
347396 global_graph_pool = current_platform .graph_pool_handle ()
@@ -441,16 +490,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
441490 )
442491 self .compilation_config .cache_dir = cache_dir
443492
444- if compilation_counter .num_graphs_seen > 0 :
445- cache_dir = self .compilation_config .cache_dir + \
446- f'-{ compilation_counter .num_graphs_seen } '
447- else :
448- cache_dir = self .compilation_config .cache_dir
493+ cache_dir = self .compilation_config .cache_dir
449494 os .makedirs (cache_dir , exist_ok = True )
450495 self .compilation_config .cache_dir = cache_dir
451496 rank = vllm_config .parallel_config .rank
452497 dp_rank = vllm_config .parallel_config .data_parallel_rank
453- local_cache_dir = os .path .join (cache_dir , f"rank_{ rank } _{ dp_rank } " )
498+ local_cache_dir = os .path .join (cache_dir , f"rank_{ rank } _{ dp_rank } " ,
499+ self .prefix )
454500 os .makedirs (local_cache_dir , exist_ok = True )
455501 self .compilation_config .local_cache_dir = local_cache_dir
456502
@@ -462,7 +508,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
462508 logger .info ("Using cache directory: %s for vLLM's torch.compile" ,
463509 local_cache_dir )
464510
465- self .compiler_manager .initialize_cache (local_cache_dir , disable_cache )
511+ self .compiler_manager .initialize_cache (local_cache_dir , disable_cache ,
512+ self .prefix )
466513
467514 # when dynamo calls the backend, it means the bytecode
468515 # transform and analysis are done
0 commit comments