Skip to content

Commit

Permalink
tests(mm): add tests for ModelHash
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious committed Feb 28, 2024
1 parent 47d8017 commit 4801979
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions tests/test_model_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# pyright:reportPrivateUsage=false

from pathlib import Path
from typing import Iterable

import pytest
from blake3 import blake3

from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash

test_cases: list[tuple[ALGORITHM, str]] = [
("md5", "a0cd925fc063f98dbf029eee315060c3"),
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
(
"sha512",
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
),
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
]


@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
file = Path(tmp_path / "test")
file.write_text("model data")
md5 = ModelHash(algorithm).hash(file)
assert md5 == expected_hash


@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
model_hash = ModelHash(algorithm)
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]

for f in files:
f.write_text("data")

md5 = model_hash.hash(tmp_path)

# Manual implementation of composite hash - always uses BLAKE3
composite_hasher = blake3()
for f in files:
h = model_hash.hash(f)
composite_hasher.update(h.encode("utf-8"))

assert md5 == composite_hasher.hexdigest()


def test_model_hash_raises_error_on_invalid_algorithm():
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]


def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
return {str(p) for p in paths}


def test_model_hash_filters_out_non_model_files(tmp_path: Path):
model_files = {
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
}

for i, f in enumerate(model_files):
f.write_text(f"data{i}")

assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)

# Add file that should be ignored - hash should not change
file = tmp_path / "test.icecream"
file.write_text("data")

assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)

# Add file that should not be ignored - hash should change
file = tmp_path / "test.bin"
file.write_text("more data")
model_files.add(file)

assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)

0 comments on commit 4801979

Please sign in to comment.