 logger = init_logger(__name__)
 
 
-def pprint(x):
-    pass
+class InductorHashCache:
+    """
+    Disk format: a Python list of tuples, each tuple being
+    (runtime_shape, graph_index, hash_str).
+    We use a list of tuples for readability.
+
+    In-memory format: a defaultdict of dicts, where the key is
+    runtime_shape and the value is a dict mapping graph_index
+    to hash_str.
+
+    The data is essentially `Dict[Optional[int], Dict[int, str]]`.
+    We don't use json here because json doesn't support int keys.
+
+    TODO: better off-the-shelf solution to serialize the data?
+    """
+
+    def __init__(self, cache_dir: str, disabled: bool = False):
+        self.cache: defaultdict = defaultdict(dict)
+        self.disabled = disabled
+        self.cache_dir = cache_dir
+        self.cache_file_path = os.path.join(cache_dir,
+                                            "inductor_hash_cache.py")
+        if disabled:
+            return
+        # set flags so that Inductor and Triton store their caches
+        # in cache_dir; users then only need to copy cache_dir
+        # to another machine to reuse the cache.
+        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        os.makedirs(inductor_cache, exist_ok=True)
+        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
+        triton_cache = os.path.join(cache_dir, "triton_cache")
+        os.makedirs(triton_cache, exist_ok=True)
+        os.environ["TRITON_CACHE_DIR"] = triton_cache
+        if os.path.exists(self.cache_file_path):
+            with open(self.cache_file_path) as f:
+                self.deserialize(f.read())
+
+    def deserialize(self, data: str):
+        # we use ast.literal_eval to parse the data
+        # because it is a safe way to parse Python literals;
+        # do not use eval(), which is unsafe.
+        try:
+            list_data = ast.literal_eval(data)
+            for runtime_shape, graph_index, hash_str in list_data:
+                self.cache[runtime_shape][graph_index] = hash_str
+        except Exception as ex:
+            logger.warning("Unable to read cache: %s, error: %s",
+                           self.cache_file_path, ex)
+            self.cache.clear()
+            self.disabled = True
+
+    def serialize(self) -> str:
+        data = []
+        for runtime_shape, graph_index_to_hash_str in self.cache.items():
+            for graph_index, hash_str in graph_index_to_hash_str.items():
+                data.append((runtime_shape, graph_index, hash_str))
+        printer = pprint.PrettyPrinter(indent=4)
+        return printer.pformat(data)
+
+    def save_to_file(self):
+        if self.disabled:
+            return
+        with open(self.cache_file_path, "w") as f:
+            f.write(self.serialize())
+
+    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
+        if self.disabled:
+            return False
+        runtime_shape, graph_index = key
+        return (runtime_shape in self.cache
+                and graph_index in self.cache[runtime_shape])
+
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+        if self.disabled:
+            raise KeyError("cannot read from disabled cache")
+        runtime_shape, graph_index = key
+        return self.cache[runtime_shape][graph_index]
+
+    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+        # setitem on a disabled cache is fine, because we
+        # don't actually write to disk
+        runtime_shape, graph_index = key
+        self.cache[runtime_shape][graph_index] = value
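For reference, a minimal usage sketch of the class above (not part of the diff; the temporary directories and hash strings are illustrative). Keys are (runtime_shape, graph_index) tuples, with runtime_shape=None denoting the general shape:

import tempfile

cache = InductorHashCache(tempfile.mkdtemp())
cache[(None, 0)] = "hash_for_general_shape"  # None = the general shape
cache[(16, 0)] = "hash_for_shape_16"
serialized = cache.serialize()
# e.g. "[(None, 0, 'hash_for_general_shape'), (16, 0, 'hash_for_shape_16')]"

restored = InductorHashCache(tempfile.mkdtemp())
restored.deserialize(serialized)
assert (16, 0) in restored and restored[(16, 0)] == "hash_for_shape_16"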
+
+
+class AlwaysHitShapeEnv:
+    """
+    Why do we need this class:
+
+    For normal `torch.compile` usage, every compilation will have
+    one Dynamo bytecode compilation and one Inductor compilation.
+    The Inductor compilation happens under the context of the
+    Dynamo bytecode compilation, and that context is used to
+    determine the dynamic shape information, etc.
+
+    For our use case, we only run Dynamo bytecode compilation once,
+    and run Inductor compilation multiple times with different shapes
+    plus a general shape. The compilation for specific shapes happens
+    outside of the context of the Dynamo bytecode compilation. At that
+    time, we don't have a shape environment to provide to Inductor,
+    and its code cache lookup would fail.
+
+    By providing a dummy shape environment that always hits, we make
+    the Inductor code cache lookup always succeed, so we can compile
+    the graph for different shapes as needed.
+
+    The following dummy methods were obtained by trial and error
+    until things worked.
+    """
+
+    def __init__(self) -> None:
+        self.guards: List[Any] = []
+
+    def evaluate_guards_expression(self, *args, **kwargs):
+        return True
+
+    def get_pruned_guards(self, *args, **kwargs):
+        return []
+
+    def produce_guards_expression(self, *args, **kwargs):
+        return ""
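To make the contract concrete, a small sketch (assuming, per the docstring, that Inductor's cache lookup only touches these three methods; the guard expression string below is made up):

env = AlwaysHitShapeEnv()
assert env.evaluate_guards_expression("Eq(s0, 16)") is True  # every guard passes
assert env.get_pruned_guards() == []  # no guards survive pruning
assert env.produce_guards_expression() == ""  # empty guard expression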
 
 
 def wrap_inductor(graph: fx.GraphModule,
@@ -369,6 +486,7 @@ def configure_post_pass(self):
         inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
@@ -385,16 +503,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.configure_post_pass()
 
         if ("before_split_graph"
-                in self.compilation_configs.pass_config.dump_graph_stages):
-            dump_graph(self.compilation_configs.pass_config, graph.graph,
+                in self.compilation_config.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_config.pass_config, graph.graph,
                        "before_split_graph")
 
         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_config.splitting_ops)
 
         if ("after_split_graph"
-                in self.compilation_configs.pass_config.dump_graph_stages):
-            dump_graph(self.compilation_configs.pass_config,
+                in self.compilation_config.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_config.pass_config,
                        self.split_gm.graph, "after_split_graph")
 
         compilation_counter.num_piecewise_graphs_seen += len(
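As the checks in this hunk show, graph dumping is gated on stage names listed in pass_config.dump_graph_stages. A hypothetical configuration enabling dumps around the split step (how this config object is actually constructed is vLLM-internal and assumed here):

compilation_config.pass_config.dump_graph_stages = [
    "before_split_graph",
    "after_split_graph",
]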
@@ -541,13 +659,11 @@ def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
             self.check_for_ending_compilation()
-            pprint(f"RUN GENERAL 1")
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
         if runtime_shape not in self.concrete_size_entries:
             # we don't need to do anything for this shape
-            pprint(f"RUN GENERAL 2 - {runtime_shape}")
             return self.compiled_graph_for_general_shape(*args)
 
         entry = self.concrete_size_entries[runtime_shape]
@@ -574,7 +690,6 @@ def __call__(self, *args) -> Any:
                 self.check_for_ending_compilation()
 
         if not entry.use_cudagraph:
-            pprint(f"RUN STATIC {runtime_shape}")
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
@@ -586,7 +701,6 @@ def __call__(self, *args) -> Any:
                     entry.num_finished_warmup,
                     self.compilation_config.cudagraph_num_of_warmups,
                     runtime_shape)
-                pprint(f"RUN STATIC CUDAGRAPH WARMUP 1 {runtime_shape}")
                 return entry.runnable(*args)
 
             if self.is_first_graph:
@@ -617,7 +731,6 @@ def __call__(self, *args) -> Any:
             # mind-exploding: carefully manage the reference and memory.
             with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's cudagraph pool
-                pprint(f"RUN STATIC CUDAGRAPH WARMUP 2 {runtime_shape}")
                 output = entry.runnable(*args)
                 if self.is_last_graph:
                     # by converting it to weak ref,
@@ -649,6 +762,5 @@ def __call__(self, *args) -> Any:
                     f" Expected {entry.input_addresses}, got {new_input_addresses}"
                 )
 
-        pprint(f"RUN STATIC CUDAGRAPH REPLAY {runtime_shape}")
         entry.cudagraph.replay()
         return entry.output
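For context, a standalone sketch of the warmup/capture/replay pattern implemented above, using PyTorch's public CUDA graph API (requires a CUDA device; the shapes and toy runnable are illustrative, and the real code additionally shares memory across graphs via pool=self.graph_pool):

import torch

static_input = torch.zeros(16, device="cuda")

def runnable(x):
    return x * 2 + 1

# warmup runs happen outside capture, so one-time work (allocator
# warmup, autotuning) is not recorded into the graph
for _ in range(2):
    runnable(static_input)
torch.cuda.synchronize()

cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cudagraph):
    # capture records kernels, not results; `static_output` lives in
    # the graph's memory pool, like `entry.output` above
    static_output = runnable(static_input)

# replay re-runs the recorded kernels on the captured buffers, so new
# inputs must be copied into the same addresses first (hence the
# input-address check above)
static_input.copy_(torch.ones(16, device="cuda"))
cudagraph.replay()
print(static_output)  # now reflects the updated input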