 import copy
 import logging
 import os
 import pickle
+import pickletools
 import shutil
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

 import torch
-from torch._inductor.codecache import FxGraphCachePickler
+from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo._settings import (
+    _SETTINGS_TO_BE_ENGINE_INVARIANT,
+    CompilationSettings,
+)

 _LOGGER: logging.Logger = logging.getLogger(__name__)

+UnpackedCacheHit = Tuple[
+    bytes,
+    List[str],
+    List[str],
+    Sequence[Input],
+    CompilationSettings,
+    Optional[Dict[str, Any]],
+]
+

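For orientation, a cache hit is a six-element tuple in exactly the order spelled out by the `UnpackedCacheHit` alias above. A minimal sketch of consuming one, with placeholder values standing in for a real `unpack()` result:

    # Placeholder values only; a real entry comes from BaseEngineCache.unpack().
    entry = (b"<serialized TRT engine>", ["input_0"], ["output_0"], [], None, None)
    engine_bytes, input_names, output_names, input_specs, settings, weight_map = entry
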
 class BaseEngineCache(ABC):

@@ -24,7 +40,11 @@ def __init__(
         pass

     @staticmethod
-    def get_hash(gm: torch.fx.GraphModule) -> str:
+    def get_hash(
+        gm: torch.fx.GraphModule,
+        input_specs: Sequence[Input],
+        settings: CompilationSettings,
+    ) -> str:
         """Get the hash value of the GraphModule

         Args:
@@ -39,7 +59,21 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
         for name, param in new_gm.named_parameters():
             param.data.zero_()

-        hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+        graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+
+        input_spec_strs = [str(i) for i in input_specs]
+        input_specs_data = pickle.dumps(input_spec_strs)
+        input_specs_data = pickletools.optimize(input_specs_data)
+        input_specs_hash = sha256_hash(input_specs_data)
+
+        invariant_engine_specs = [
+            str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
+        ]
+        engine_specs_data = pickle.dumps(invariant_engine_specs)
+        engine_specs_data = pickletools.optimize(engine_specs_data)
+        engine_specs_hash = sha256_hash(engine_specs_data)
+
+        hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash

         return hash_val

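The resulting key is the concatenation of three digests: the FX graph hash (computed after zeroing parameters, so weight values do not perturb the key), a digest of the stringified input specs, and a digest of the engine-invariant subset of the compilation settings. A self-contained sketch of the spec-digest step, substituting hashlib.sha256 for inductor's sha256_hash helper and using illustrative spec strings:

    import hashlib
    import pickle
    import pickletools
    from typing import List

    def stable_digest(items: List[str]) -> str:
        # Pickle deterministically (optimize() drops unused PUT opcodes), then hash.
        data = pickletools.optimize(pickle.dumps(items))
        return hashlib.sha256(data).hexdigest()

    # Illustrative stand-ins for str(Input) and str(setting) values:
    input_specs_hash = stable_digest(["Input(shape=(1, 3, 224, 224), dtype=float32)"])
    engine_specs_hash = stable_digest(["precision=fp16", "workspace_size=0"])

Because both digests are taken over str() representations, any change to an input shape or an engine-invariant setting yields a different cache key, while weight-only changes do not.
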
@@ -48,6 +84,8 @@ def pack(
         serialized_engine: bytes,
         input_names: List[str],
         output_names: List[str],
+        input_specs: Sequence[Input],
+        compilation_settings: CompilationSettings,
         weight_name_map: Optional[Dict[Any, Any]],
     ) -> bytes:
         """Pack serialized engine, input names, output names, and weight map into a single blob
@@ -56,40 +94,83 @@ def pack(
             serialized_engine (bytes): serialized TRT engine
             input_names (List[str]): input names of TRT engine
             output_names (List[str]): output names of TRT engine
+            input_specs (Sequence[Input]): input specs of TRT engine
+            compilation_settings (CompilationSettings): compilation settings of TRT engine
             weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting

         Returns:
             bytes: packed blob
         """
+
+        settings = copy.deepcopy(compilation_settings)
         return pickle.dumps(
             {
                 "serialized_engine": bytes(serialized_engine),
                 "input_names": input_names,
                 "output_names": output_names,
+                "input_specs": input_specs,
+                "compilation_settings": settings,
                 "weight_name_map": weight_name_map,
             }
         )

     @staticmethod
-    def unpack(
-        packed_obj: bytes,
-    ) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
+    def unpack(packed_obj: bytes) -> UnpackedCacheHit:
         """Unpack packed blob into serialized engine, input names, output names, and weight map

         Args:
             packed_obj (bytes): packed blob

         Returns:
-            Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
+            Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, compilation settings, weight name map
84126 """
85127 unpacked = pickle .loads (packed_obj )
86128 return (
87129 unpacked ["serialized_engine" ],
88130 unpacked ["input_names" ],
89131 unpacked ["output_names" ],
132+ unpacked ["input_specs" ],
133+ unpacked ["compilation_settings" ],
90134 unpacked ["weight_name_map" ],
91135 )
92136
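Since pack and unpack are symmetric pickle round-trips over a dict, their contract can be smoke-tested without TensorRT. A sketch under the assumption that None/empty stand-ins are acceptable at runtime (real callers pass a Sequence[Input] and a CompilationSettings; copy.deepcopy(None) is a no-op):

    # Hypothetical round-trip; placeholder values in the typed slots.
    blob = BaseEngineCache.pack(
        serialized_engine=b"<engine bytes>",
        input_names=["input_0"],
        output_names=["output_0"],
        input_specs=[],             # would be Sequence[Input]
        compilation_settings=None,  # would be CompilationSettings
        weight_name_map=None,
    )
    engine, in_names, out_names, specs, settings, weight_map = BaseEngineCache.unpack(blob)
    assert engine == b"<engine bytes>" and in_names == ["input_0"]
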
+    def insert(
+        self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
+    ) -> None:
+        """
+        Insert a cache entry into the engine cache.
+
+        Args:
+            hash (str): The hash value of the GraphModule.
+            entry (Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[Any, Any]]]): The cache entry to be inserted.
+            *args: Variable length argument list passed to ``save``.
+            **kwargs: Arbitrary keyword arguments passed to ``save``.
+
+        Returns:
+            None
+        """
+        packed_cache_info = BaseEngineCache.pack(*entry)
+        return self.save(hash, packed_cache_info, *args, **kwargs)
+
+    def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
+        """
+        Check if a cache entry exists for the given hash.
+
+        Args:
+            hash (str): The hash value of the GraphModule.
+            *args: Variable length argument list passed to ``load``.
+            **kwargs: Arbitrary keyword arguments passed to ``load``.
+
+        Returns:
+            Optional[Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
+        """
+        packed_cache_info = self.load(hash, *args, **kwargs)
+
+        if packed_cache_info:
+            return BaseEngineCache.unpack(packed_cache_info)
+        else:
+            return None
+
     @abstractmethod
     def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         """Store blob in cache
@@ -203,11 +284,7 @@ def LRU() -> None:
         else:
             LRU()

-    def save(
-        self,
-        hash: str,
-        blob: bytes,
-    ) -> None:
+    def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         blob_size = len(blob)
         if blob_size > self.total_engine_cache_size:
             _LOGGER.warning(
@@ -244,7 +321,7 @@ def save(
                 f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}."
             )

-    def load(self, hash: str) -> Optional[bytes]:
+    def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
         directory = os.path.join(self.engine_cache_dir, hash)
         if os.path.exists(directory):
             blob_path = os.path.join(directory, "blob.bin")
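The disk cache above persists each blob at engine_cache_dir/<hash>/blob.bin and evicts least-recently-used entries when space runs low. The same save/load contract can be satisfied by a much smaller backend; a minimal in-memory sketch (not part of this change) that makes the abstract interface concrete:

    from typing import Any, Dict, Optional

    class InMemoryEngineCache(BaseEngineCache):
        # Hypothetical backend: keeps packed blobs in a plain dict, no eviction.

        def __init__(self) -> None:
            self._store: Dict[str, bytes] = {}

        def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
            self._store[hash] = blob

        def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
            return self._store.get(hash)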