Skip to content

Commit d2d6e87

Browse files
committed
Add GCS ObjectStorageModel.pull_files test
Signed-off-by: Peter Schuurman <psch@google.com>
1 parent 40f2253 commit d2d6e87

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import glob
5+
import hashlib
56
import os
67
import tempfile
78

89
import huggingface_hub.constants
910

1011
from vllm.model_executor.model_loader.weight_utils import (
1112
download_weights_from_hf)
12-
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
13+
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
14+
is_runai_obj_uri,
1315
list_safetensors)
1416

1517

@@ -34,6 +36,21 @@ def test_runai_list_safetensors_local():
3436
assert len(safetensors) == len(files)
3537

3638

37-
if __name__ == "__main__":
38-
test_is_runai_obj_uri()
39-
test_runai_list_safetensors_local()
39+
def test_runai_pull_files_gcs(monkeypatch):
    """Pull one object via ObjectStorageModel.pull_files and check its MD5.

    The test points ObjectStorageModel at a single text file under a
    ``gs://`` prefix, pulls only that file into ``model.dir`` and compares
    the MD5 digest of the local copy against a pinned value.

    NOTE(review): this presumably needs network access to the public
    ``gcp-public-data-landsat`` bucket -- confirm it is acceptable in CI.
    """
    # NOTE(review): judging by its name, this env var makes the streamer use
    # anonymous GCS credentials -- confirm against the streamer docs.
    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
    filename = "LT08_L1GT_074061_20130309_20170505_01_T2_MTL.txt"
    # The prefix already ends with "/", so the object URL below is built
    # without adding another separator (which would yield ".../T2//<file>").
    gcs_bucket = ("gs://gcp-public-data-landsat/LT08/01/074/061/"
                  "LT08_L1GT_074061_20130309_20170505_01_T2/")
    gcs_url = f"{gcs_bucket}{filename}"
    model = ObjectStorageModel(gcs_url)
    # Restrict the pull to the single file under test.
    model.pull_files(gcs_bucket, allow_pattern=[f"*{filename}"])
    # To re-generate / change URLs:
    # gsutil ls -L gs://<gcs-url> | grep "Hash (md5)" | tr -d ' ' \
    #   | cut -d":" -f2 | base64 -d | xxd -p
    expected_checksum = "f60dea775da1392434275b311b31a431"
    hasher = hashlib.new("md5")
    with open(os.path.join(model.dir, filename), 'rb') as f:
        # Read the file in chunks to handle large files efficiently
        for chunk in iter(lambda: f.read(4096), b''):
            hasher.update(chunk)
    actual_checksum = hasher.hexdigest()
    assert actual_checksum == expected_checksum

0 commit comments

Comments
 (0)