Skip to content

Commit 2750c6c

Browse files
authored
misc: checksum check when downloading artifacts (#1761)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description checks the sha256 hash when downloading cubins from artifactory, using the generated `checksum.txt` in each cubin directory. ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 80bdea5 commit 2750c6c

File tree

5 files changed

+124
-31
lines changed

5 files changed

+124
-31
lines changed

β€Žflashinfer/artifacts.pyβ€Ž

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232

3333
from .jit.cubin_loader import (
3434
FLASHINFER_CUBINS_REPOSITORY,
35-
download_file,
3635
safe_urljoin,
3736
FLASHINFER_CUBIN_DIR,
37+
download_file,
38+
verify_cubin,
3839
)
3940

4041

@@ -78,50 +79,109 @@ def get_available_cubin_files(
7879
return tuple()
7980

8081

81-
@dataclass(frozen=True)
8282
class ArtifactPath:
83-
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
83+
TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
8484
TRTLLM_GEN_BMM: str = (
8585
"56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
8686
)
8787
TRTLLM_GEN_GEMM: str = (
8888
"1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
8989
)
90-
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn"
91-
DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm"
90+
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
91+
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
9292

9393

9494
@dataclass(frozen=True)
9595
class MetaInfoHash:
96+
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
9697
TRTLLM_GEN_FMHA: str = (
97-
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
98+
"d26dbf837f40ff2dcd964094ab6e1b3f2424edda5979c313f5262655161fce98"
9899
)
99100
TRTLLM_GEN_BMM: str = (
100101
"4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
101102
)
102-
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
103103
TRTLLM_GEN_GEMM: str = (
104104
"bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340"
105105
)
106106

107107

108-
def get_cubin_file_list() -> Generator[str, None, None]:
108+
class CheckSumHash:
109+
TRTLLM_GEN_FMHA: str = (
110+
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
111+
)
112+
TRTLLM_GEN_BMM: str = (
113+
"8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934"
114+
)
115+
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
116+
TRTLLM_GEN_GEMM: str = (
117+
"15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9"
118+
)
119+
map_checksums: dict[str, str] = {
120+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
121+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
122+
safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
123+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
124+
}
125+
126+
127+
def get_checksums(subdirs):
128+
checksums = {}
129+
for subdir in subdirs:
130+
uri = safe_urljoin(
131+
FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt")
132+
)
133+
checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin(subdir, "checksums.txt")
134+
download_file(uri, checksum_path)
135+
with open(checksum_path, "r") as f:
136+
for line in f:
137+
sha256, filename = line.strip().split()
138+
139+
# Distinguish between all meta info header files
140+
if ".h" in filename:
141+
filename = safe_urljoin(subdir, filename)
142+
checksums[filename] = sha256
143+
return checksums
144+
145+
146+
def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
109147
base = FLASHINFER_CUBINS_REPOSITORY
110148

111-
# The meta info header files first.
112-
yield safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")
113-
yield safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")
114-
yield safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")
115-
116-
# All the actual kernel cubin's.
117-
for kernel in [
149+
cubin_dirs = [
118150
ArtifactPath.TRTLLM_GEN_FMHA,
119151
ArtifactPath.TRTLLM_GEN_BMM,
120152
ArtifactPath.TRTLLM_GEN_GEMM,
121153
ArtifactPath.DEEPGEMM,
122-
]:
123-
for name in get_available_cubin_files(safe_urljoin(base, kernel)):
124-
yield safe_urljoin(kernel, name)
154+
]
155+
156+
# Get checksums of all files
157+
checksums = get_checksums(cubin_dirs)
158+
159+
# The meta info header files first.
160+
yield (
161+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
162+
checksums[
163+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")
164+
],
165+
)
166+
yield (
167+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
168+
checksums[
169+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")
170+
],
171+
)
172+
yield (
173+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
174+
checksums[
175+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")
176+
],
177+
)
178+
179+
# All the actual kernel cubin's.
180+
for cubin_dir in cubin_dirs:
181+
checksum_path = safe_urljoin(cubin_dir, "checksums.txt")
182+
yield (checksum_path, CheckSumHash.map_checksums[checksum_path])
183+
for name in get_available_cubin_files(safe_urljoin(base, cubin_dir)):
184+
yield (safe_urljoin(cubin_dir, name), checksums[name])
125185

126186

127187
def download_artifacts() -> None:
@@ -130,8 +190,7 @@ def download_artifacts() -> None:
130190
# use a shared session to make use of HTTP keep-alive and reuse of
131191
# HTTPS connections.
132192
session = requests.Session()
133-
134-
cubin_files = list(get_cubin_file_list())
193+
cubin_files = list(get_subdir_file_list())
135194
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
136195
with tqdm_logging_redirect(
137196
total=len(cubin_files), desc="Downloading cubins"
@@ -142,7 +201,7 @@ def update_pbar_cb(_) -> None:
142201

143202
with ThreadPoolExecutor(num_threads) as pool:
144203
futures = []
145-
for name in cubin_files:
204+
for name, _ in cubin_files:
146205
source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name)
147206
local_path = FLASHINFER_CUBIN_DIR / name
148207
# Ensure parent directory exists
@@ -159,13 +218,19 @@ def update_pbar_cb(_) -> None:
159218
if not all_success:
160219
raise RuntimeError("Failed to download cubins")
161220

221+
# Check checksums of all downloaded cubins
222+
for name, checksum in cubin_files:
223+
local_path = FLASHINFER_CUBIN_DIR / name
224+
if not verify_cubin(str(local_path), checksum):
225+
raise RuntimeError("Failed to download cubins: checksum mismatch")
226+
162227

163228
def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
164229
"""
165230
Check which cubins are already downloaded and return (num_downloaded, total).
166231
Does not download any cubins.
167232
"""
168-
cubin_files = get_cubin_file_list()
233+
cubin_files = get_subdir_file_list()
169234

170235
def _check_file_status(file_name: str) -> tuple[str, bool]:
171236
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
@@ -174,7 +239,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
174239
exists = os.path.isfile(local_path)
175240
return (file_name, exists)
176241

177-
return tuple(_check_file_status(file_name) for file_name in cubin_files)
242+
return tuple(_check_file_status(file_name) for file_name, _ in cubin_files)
178243

179244

180245
def clear_cubin():

β€Žflashinfer/jit/attention/modules.pyβ€Ž

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1575,7 +1575,8 @@ def gen_trtllm_gen_fmha_module():
15751575

15761576
# use `get_cubin` to get "flashinferMetaInfo.h"
15771577
metainfo = get_cubin(
1578-
f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_FMHA
1578+
f"{include_path}/{header_name}.h",
1579+
MetaInfoHash.TRTLLM_GEN_FMHA,
15791580
)
15801581

15811582
# make sure "flashinferMetaInfo.h" is downloaded or cached

β€Žflashinfer/jit/cubin_loader.pyβ€Ž

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,35 @@ def download_file(
136136
return False
137137

138138

139+
def get_meta_hash(checksum_path: str) -> str:
140+
"""
141+
Load the file from local cache (checksums.txt)
142+
and get the hash of corresponding flashinferMetaInfo.h file
143+
"""
144+
local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt")
145+
with open(local_path, "r") as f:
146+
for line in f:
147+
sha256, filename = line.strip().split()
148+
if ".h" in filename:
149+
return sha256
150+
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
151+
152+
153+
def verify_cubin(cubin_path: str, expected_sha256: str) -> bool:
154+
"""
155+
Verify the cubin file against the sha256 checksum.
156+
"""
157+
with open(cubin_path, "rb") as f:
158+
data = f.read()
159+
actual_sha256 = hashlib.sha256(data).hexdigest()
160+
if actual_sha256 != expected_sha256:
161+
logger.warning(
162+
f"sha256 mismatch (expected {expected_sha256} actual {actual_sha256}) for {cubin_path}"
163+
)
164+
return False
165+
return True
166+
167+
139168
def load_cubin(cubin_path: str, sha256: str) -> bytes:
140169
"""
141170
Load a cubin from the provide local path and

β€Žflashinfer/jit/fused_moe.pyβ€Ž

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
181181
header_name = "flashinferMetaInfo"
182182

183183
# use `get_cubin` to get "flashinferMetaInfo.h"
184-
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM)
184+
metainfo = get_cubin(
185+
f"{include_path}/{header_name}.h",
186+
MetaInfoHash.TRTLLM_GEN_BMM,
187+
)
185188
# make sure "flashinferMetaInfo.h" is downloaded or cached
186189
assert metainfo, f"{header_name}.h not found"
187190

β€Žtests/utils/test_load_cubin_compile_race_condition.pyβ€Ž

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,12 @@ def worker_process(temp_dir):
3535
os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir
3636

3737
# Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads
38-
from flashinfer.artifacts import ArtifactPath, MetaInfoHash
39-
from flashinfer.jit.cubin_loader import get_cubin
38+
from flashinfer.artifacts import ArtifactPath
4039

4140
# Define the target file - same for all processes
4241
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
4342
header_name = "flashinferMetaInfo"
4443

45-
# Use get_cubin to get "flashinferMetaInfo.h"
46-
# Note: all processes target the same file name
47-
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) # noqa: F841
48-
4944
# Read the file from FLASHINFER_CUBIN_DIR
5045
# NOTE(Zihao): instead of using metainfo, we directly read from the file path,
5146
# that aligns with how we compile the kernel.

0 commit comments

Comments
Β (0)