3232
3333from .jit .cubin_loader import (
3434 FLASHINFER_CUBINS_REPOSITORY ,
35- download_file ,
3635 safe_urljoin ,
3736 FLASHINFER_CUBIN_DIR ,
37+ download_file ,
38+ verify_cubin ,
3839)
3940
4041
@@ -78,50 +79,109 @@ def get_available_cubin_files(
7879 return tuple ()
7980
8081
81- @dataclass (frozen = True )
8282class ArtifactPath :
83- TRTLLM_GEN_FMHA : str = "7206d64e67f4c8949286246d6e2e07706af5d223 /fmha/trtllm-gen"
83+ TRTLLM_GEN_FMHA : str = "a72d85b019dc125b9f711300cb989430f762f5a6 /fmha/trtllm-gen/ "
8484 TRTLLM_GEN_BMM : str = (
8585 "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
8686 )
8787 TRTLLM_GEN_GEMM : str = (
8888 "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
8989 )
90- CUDNN_SDPA : str = "4c623163877c8fef5751c9c7a59940cd2baae02e /fmha/cudnn"
91- DEEPGEMM : str = "51d730202c9eef782f06ecc950005331d85c5d4b /deep-gemm"
90+ CUDNN_SDPA : str = "a72d85b019dc125b9f711300cb989430f762f5a6 /fmha/cudnn/ "
91+ DEEPGEMM : str = "a72d85b019dc125b9f711300cb989430f762f5a6 /deep-gemm/ "
9292
9393
9494@dataclass (frozen = True )
9595class MetaInfoHash :
96+ DEEPGEMM : str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
9697 TRTLLM_GEN_FMHA : str = (
97- "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d "
98+ "d26dbf837f40ff2dcd964094ab6e1b3f2424edda5979c313f5262655161fce98 "
9899 )
99100 TRTLLM_GEN_BMM : str = (
100101 "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
101102 )
102- DEEPGEMM : str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
103103 TRTLLM_GEN_GEMM : str = (
104104 "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340"
105105 )
106106
107107
108- def get_cubin_file_list () -> Generator [str , None , None ]:
108+ class CheckSumHash :
109+ TRTLLM_GEN_FMHA : str = (
110+ "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
111+ )
112+ TRTLLM_GEN_BMM : str = (
113+ "8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934"
114+ )
115+ DEEPGEMM : str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
116+ TRTLLM_GEN_GEMM : str = (
117+ "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9"
118+ )
119+ map_checksums : dict [str , str ] = {
120+ safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "checksums.txt" ): TRTLLM_GEN_FMHA ,
121+ safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "checksums.txt" ): TRTLLM_GEN_BMM ,
122+ safe_urljoin (ArtifactPath .DEEPGEMM , "checksums.txt" ): DEEPGEMM ,
123+ safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "checksums.txt" ): TRTLLM_GEN_GEMM ,
124+ }
125+
126+
127+ def get_checksums (subdirs ):
128+ checksums = {}
129+ for subdir in subdirs :
130+ uri = safe_urljoin (
131+ FLASHINFER_CUBINS_REPOSITORY , safe_urljoin (subdir , "checksums.txt" )
132+ )
133+ checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin (subdir , "checksums.txt" )
134+ download_file (uri , checksum_path )
135+ with open (checksum_path , "r" ) as f :
136+ for line in f :
137+ sha256 , filename = line .strip ().split ()
138+
139+ # Distinguish between all meta info header files
140+ if ".h" in filename :
141+ filename = safe_urljoin (subdir , filename )
142+ checksums [filename ] = sha256
143+ return checksums
144+
145+
146+ def get_subdir_file_list () -> Generator [tuple [str , str ], None , None ]:
109147 base = FLASHINFER_CUBINS_REPOSITORY
110148
111- # The meta info header files first.
112- yield safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" )
113- yield safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" )
114- yield safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" )
115-
116- # All the actual kernel cubin's.
117- for kernel in [
149+ cubin_dirs = [
118150 ArtifactPath .TRTLLM_GEN_FMHA ,
119151 ArtifactPath .TRTLLM_GEN_BMM ,
120152 ArtifactPath .TRTLLM_GEN_GEMM ,
121153 ArtifactPath .DEEPGEMM ,
122- ]:
123- for name in get_available_cubin_files (safe_urljoin (base , kernel )):
124- yield safe_urljoin (kernel , name )
154+ ]
155+
156+ # Get checksums of all files
157+ checksums = get_checksums (cubin_dirs )
158+
159+ # The meta info header files first.
160+ yield (
161+ safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" ),
162+ checksums [
163+ safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" )
164+ ],
165+ )
166+ yield (
167+ safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" ),
168+ checksums [
169+ safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" )
170+ ],
171+ )
172+ yield (
173+ safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" ),
174+ checksums [
175+ safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" )
176+ ],
177+ )
178+
179+ # All the actual kernel cubin's.
180+ for cubin_dir in cubin_dirs :
181+ checksum_path = safe_urljoin (cubin_dir , "checksums.txt" )
182+ yield (checksum_path , CheckSumHash .map_checksums [checksum_path ])
183+ for name in get_available_cubin_files (safe_urljoin (base , cubin_dir )):
184+ yield (safe_urljoin (cubin_dir , name ), checksums [name ])
125185
126186
127187def download_artifacts () -> None :
@@ -130,8 +190,7 @@ def download_artifacts() -> None:
130190 # use a shared session to make use of HTTP keep-alive and reuse of
131191 # HTTPS connections.
132192 session = requests .Session ()
133-
134- cubin_files = list (get_cubin_file_list ())
193+ cubin_files = list (get_subdir_file_list ())
135194 num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
136195 with tqdm_logging_redirect (
137196 total = len (cubin_files ), desc = "Downloading cubins"
@@ -142,7 +201,7 @@ def update_pbar_cb(_) -> None:
142201
143202 with ThreadPoolExecutor (num_threads ) as pool :
144203 futures = []
145- for name in cubin_files :
204+ for name , _ in cubin_files :
146205 source = safe_urljoin (FLASHINFER_CUBINS_REPOSITORY , name )
147206 local_path = FLASHINFER_CUBIN_DIR / name
148207 # Ensure parent directory exists
@@ -159,13 +218,19 @@ def update_pbar_cb(_) -> None:
159218 if not all_success :
160219 raise RuntimeError ("Failed to download cubins" )
161220
221+ # Check checksums of all downloaded cubins
222+ for name , checksum in cubin_files :
223+ local_path = FLASHINFER_CUBIN_DIR / name
224+ if not verify_cubin (str (local_path ), checksum ):
225+ raise RuntimeError ("Failed to download cubins: checksum mismatch" )
226+
162227
163228def get_artifacts_status () -> tuple [tuple [str , bool ], ...]:
164229 """
165230 Check which cubins are already downloaded and return (num_downloaded, total).
166231 Does not download any cubins.
167232 """
168- cubin_files = get_cubin_file_list ()
233+ cubin_files = get_subdir_file_list ()
169234
170235 def _check_file_status (file_name : str ) -> tuple [str , bool ]:
171236 # get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
@@ -174,7 +239,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
174239 exists = os .path .isfile (local_path )
175240 return (file_name , exists )
176241
177- return tuple (_check_file_status (file_name ) for file_name in cubin_files )
242+ return tuple (_check_file_status (file_name ) for file_name , _ in cubin_files )
178243
179244
180245def clear_cubin ():
0 commit comments