@@ -195,7 +195,6 @@ def compile(
195195        hash_str , file_path  =  None , None 
196196        from  torch ._inductor .codecache  import  (FxGraphCache ,
197197                                               compiled_fx_graph_hash )
198- 
199198        if  torch .__version__ .startswith ("2.5" ):
200199            original_load  =  FxGraphCache .load 
201200            original_load_name  =  "torch._inductor.codecache.FxGraphCache.load" 
@@ -280,6 +279,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
280279                patch ("torch._inductor.codecache.FxGraphCache._get_shape_env" ,
281280                      _get_shape_env ))
282281
282+             from  torch ._functorch ._aot_autograd .autograd_cache  import  (
283+                 AOTAutogradCache )
284+ 
285+             # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache 
286+             if  hasattr (AOTAutogradCache , "_get_shape_env" ):
287+                 stack .enter_context (
288+                     patch (
289+                         "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env" ,
290+                         _get_shape_env ))
291+ 
283292            # for forcing the graph to be cached 
284293            stack .enter_context (
285294                patch (
@@ -325,11 +334,19 @@ def load(self,
325334        assert  isinstance (handle [1 ], str )
326335        hash_str  =  handle [0 ]
327336
337+         from  torch ._functorch ._aot_autograd .autograd_cache  import  (
338+             AOTAutogradCache )
328339        from  torch ._inductor .codecache  import  FxGraphCache 
329340        with  ExitStack () as  exit_stack :
330341            exit_stack .enter_context (
331342                patch ("torch._inductor.codecache.FxGraphCache._get_shape_env" ,
332343                      lambda  * args , ** kwargs : AlwaysHitShapeEnv ()))
344+             # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache 
345+             if  hasattr (AOTAutogradCache , "_get_shape_env" ):
346+                 exit_stack .enter_context (
347+                     patch (
348+                         "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env" ,
349+                         lambda  * args , ** kwargs : AlwaysHitShapeEnv ()))
333350
334351            # Dynamo metrics context, see method for more details. 
335352            exit_stack .enter_context (self .metrics_context ())
0 commit comments