feat(mm): faster hashing for spinning disk HDDs
BLAKE3 has poor performance on spinning disks when parallelized. See BLAKE3-team/BLAKE3#31.

- Replace `skip_model_hash` setting with `hashing_algorithm`. Any algorithm we support is accepted.
- Add `random` algorithm: hashes a UUID with BLAKE3 to create a random "hash". Equivalent to the previous skip functionality.
- Add `blake3_single` algorithm: hashes on a single thread using BLAKE3, fixing the aforementioned performance issue (see the sketch below)
- Update model probe to accept the algorithm to hash with as an optional arg, defaulting to `blake3`
- Update all call sites of the probe to use the app's configured hashing algorithm
- Update the external `scripts/probe-model.py` script that probes models
- Update tests
- Move ModelHash into its own module to avoid circular import issues
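
The three named strategies differ only in how a file's digest is produced. A quick sketch of the resulting API (the model path is illustrative, not from this diff):

```py
from pathlib import Path

from invokeai.backend.model_hash.model_hash import ModelHash

path = Path("/models/sd15.safetensors")  # illustrative
ModelHash("blake3").hash(path)         # parallel BLAKE3: fastest on SSDs
ModelHash("blake3_single").hash(path)  # single-threaded BLAKE3: best on HDDs
ModelHash("random").hash(path)         # BLAKE3 of a fresh UUID: not a real hash
ModelHash("sha256").hash(path)         # any hashlib algorithm is also accepted
```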
psychedelicious committed Mar 14, 2024
1 parent 8287fcf commit eb6e654
Showing 6 changed files with 78 additions and 33 deletions.
4 changes: 3 additions & 1 deletion invokeai/app/services/config/config_default.py
@@ -179,6 +179,8 @@ class InvokeBatch(InvokeAISettings):
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict

+from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
+
from .config_base import InvokeAISettings

INIT_FILE = Path("invokeai.yaml")
@@ -360,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings):
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)

# MODEL INSTALL
-    skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.", json_schema_extra=Categories.ModelInstall)
+    hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorithm for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall)
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
default=None,
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
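
A minimal configuration sketch for the new field (programmatic form shown; presumably the same key can be set in `invokeai.yaml`, per `INIT_FILE` above):

```py
from invokeai.app.services.config.config_default import InvokeAIAppConfig

# pydantic-settings model: the field can be passed directly at construction.
config = InvokeAIAppConfig(hashing_algorithm="blake3_single")
assert config.hashing_algorithm == "blake3_single"
```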
11 changes: 2 additions & 9 deletions invokeai/app/services/model_install/model_install_default.py
@@ -22,7 +22,6 @@
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -154,10 +153,7 @@ def install_path(
model_path = Path(model_path)
config = config or {}

-        if self._app_config.skip_model_hash:
-            config["hash"] = uuid_string()
-
-        info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
+        info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)

if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@@ -585,10 +581,7 @@ def _register(
) -> str:
config = config or {}

-        if self._app_config.skip_model_hash:
-            config["hash"] = uuid_string()
-
-        info = info or ModelProbe.probe(model_path, config)
+        info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)

model_path = model_path.resolve()

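
Both install paths now thread the configured algorithm into the probe rather than special-casing a skip flag. A call sketch (illustrative path; the `hash` attribute on the returned config record is assumed from the fields set in `probe` below):

```py
from pathlib import Path

from invokeai.backend.model_manager import ModelProbe

info = ModelProbe.probe(Path("/models/sd15.safetensors"), hash_algo="blake3_single")
print(info.hash)  # stable BLAKE3 digest, computed single-threaded
```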
50 changes: 37 additions & 13 deletions invokeai/backend/model_hash/model_hash.py (renamed from invokeai/backend/model_manager/hash.py)
@@ -1,12 +1,4 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""

import hashlib
import os
@@ -15,9 +7,9 @@

from blake3 import blake3

-MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
+from invokeai.app.util.misc import uuid_string

-ALGORITHM = Literal[
+HASHING_ALGORITHMS = Literal[
"md5",
"sha1",
"sha224",
@@ -33,7 +25,10 @@
"shake_128",
"shake_256",
"blake3",
"blake3_single",
"random",
]
+MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")


class ModelHash:
@@ -53,6 +48,8 @@ class ModelHash:
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
+
+    A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.
Usage:
```py
# BLAKE3 hash
@@ -62,11 +59,17 @@
```
"""

-    def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
+    def __init__(
+        self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
+    ) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "blake3_single":
self._hash_file = self._blake3_single
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
elif algorithm == "random":
self._hash_file = self._random
else:
raise ValueError(f"Algorithm {algorithm} not available")

@@ -137,7 +140,7 @@ def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> lis

@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
"""Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.
Args:
file_path: Path to the file to hash
@@ -150,7 +153,21 @@ def _blake3(file_path: Path) -> str:
return file_hasher.hexdigest()

@staticmethod
-    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
+    def _blake3_single(file_path: Path) -> str:
+        """Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.
+
+        Args:
+            file_path: Path to the file to hash
+
+        Returns:
+            Hexdigest of the hash of the file
+        """
+        file_hasher = blake3()
+        file_hasher.update_mmap(file_path)
+        return file_hasher.hexdigest()
+
+    @staticmethod
+    def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
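
The `_blake3` body is partly elided in this view; judging from `_blake3_single`, the parallel variant presumably enables multithreading on the same memory-mapped API. A side-by-side sketch using the blake3 Python bindings (file name illustrative):

```py
from blake3 import blake3

# Parallel: worker threads hash a memory-mapped file concurrently.
# Fast on SSDs, but the scattered reads thrash a spinning disk's head.
parallel = blake3(max_threads=blake3.AUTO)
parallel.update_mmap("model.safetensors")
print(parallel.hexdigest())

# Single-threaded (`blake3_single`): one sequential pass, HDD-friendly.
single = blake3()
single.update_mmap("model.safetensors")
print(single.hexdigest())
```

Both produce the identical digest; only the throughput characteristics differ. The test `test_model_hash_blake3_matches_blake3_single` below checks exactly this.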
@@ -172,6 +189,13 @@ def hashlib_hasher(file_path: Path) -> str:

return hashlib_hasher

+    @staticmethod
+    def _random(_file_path: Path) -> str:
+        """Returns a random string. This is not a hash.
+
+        The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
+        return blake3(uuid_string().encode()).hexdigest()

@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
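
`uuid_string` (imported from `invokeai.app.util.misc`) is not shown in this diff; presumably it is a thin wrapper along these lines:

```py
import uuid

def uuid_string() -> str:
    # Assumed implementation: a random UUIDv4 rendered as a string.
    return str(uuid.uuid4())
```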
8 changes: 3 additions & 5 deletions invokeai/backend/model_manager/probe.py
@@ -9,6 +9,7 @@

import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
+from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings

from .config import (
@@ -24,7 +25,6 @@
ModelVariantType,
SchedulerPredictionType,
)
-from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta

CkptType = Dict[str, Any]
@@ -113,9 +113,7 @@ def register_probe(

@classmethod
def probe(
-        cls,
-        model_path: Path,
-        fields: Optional[Dict[str, Any]] = None,
+        cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
) -> AnyModelConfig:
"""
Probe the model at model_path and return its configuration record.
@@ -160,7 +158,7 @@ def probe(
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

fields["default_settings"] = (
fields.get("default_settings") or probe.get_default_settings(fields["name"])
12 changes: 11 additions & 1 deletion scripts/probe-model.py
@@ -4,20 +4,30 @@

import argparse
from pathlib import Path
+from typing import get_args

+from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe

+algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))

parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
"model_path",
type=Path,
nargs="+",
)
+parser.add_argument(
+    "--hash_algo",
+    type=str,
+    default="blake3",
+    help=f"Hashing algorithm to use (default: blake3), one of: {algos}",
+)
args = parser.parse_args()

for path in args.model_path:
try:
-        info = ModelProbe.probe(path)
+        info = ModelProbe.probe(path, hash_algo=args.hash_algo)
print(f"{path}:{info.model_dump_json(indent=4)}")
except InvalidModelConfigException as exc:
print(exc)
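
An example invocation of the updated script (model path illustrative):

```sh
python scripts/probe-model.py /models/sd15.safetensors --hash_algo blake3_single
```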
26 changes: 22 additions & 4 deletions tests/test_model_hash.py
@@ -6,9 +6,9 @@
import pytest
from blake3 import blake3

-from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
+from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash

-test_cases: list[tuple[ALGORITHM, str]] = [
+test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
("md5", "a0cd925fc063f98dbf029eee315060c3"),
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
@@ -21,15 +21,15 @@


@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
-def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
+def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str):
file = Path(tmp_path / "test")
file.write_text("model data")
md5 = ModelHash(algorithm).hash(file)
assert md5 == expected_hash


@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
-def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
+def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
model_hash = ModelHash(algorithm)
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]

@@ -47,6 +47,24 @@ def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
assert md5 == composite_hasher.hexdigest()


+def test_model_hash_blake3_matches_blake3_single(tmp_path: Path):
+    model_hash = ModelHash("blake3")
+    model_hash_simple = ModelHash("blake3_single")
+
+    file = tmp_path / "test.bin"
+    file.write_text("model data")
+
+    assert model_hash.hash(file) == model_hash_simple.hash(file)
+
+
+def test_model_hash_random_algorithm(tmp_path: Path):
+    model_hash = ModelHash("random")
+    file = tmp_path / "test.bin"
+    file.write_text("model data")
+
+    assert model_hash.hash(file) != model_hash.hash(file)


def test_model_hash_raises_error_on_invalid_algorithm():
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
