Skip to content

Commit

Permalink
Merge pull request #112 from Haidra-Org/main
Browse files Browse the repository at this point in the history
feat: miscellaneous model manager support
  • Loading branch information
tazlin authored Jun 23, 2024
2 parents 507bea9 + 3793d1d commit bf08ac8
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jobs:
- uses: actions/checkout@v3
- name: Run pre-commit
uses: pre-commit/action@v3.0.0
with:
extra_args: --all-files

no_extra_fields:
env:
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ jobs:
- uses: actions/checkout@v3
- name: Run pre-commit
uses: pre-commit/action@v3.0.0
with:
extra_args: --all-files

no_extra_fields:
env:
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.2.0
rev: 24.4.2
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.4.7
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.8.0'
rev: 'v1.10.0'
hooks:
- id: mypy
additional_dependencies: [pydantic, types-requests, types-pytz, types-setuptools, types-urllib3, StrEnum]
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,5 @@ class Staging_StableDiffusion_ModelReference(Legacy_Generic_ModelReference):
MODEL_REFERENCE_CATEGORY.gfpgan: StagingLegacy_Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.safety_checker: StagingLegacy_Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.stable_diffusion: Legacy_StableDiffusion_ModelRecord,
MODEL_REFERENCE_CATEGORY.miscellaneous: StagingLegacy_Generic_ModelRecord,
}
4 changes: 4 additions & 0 deletions horde_model_reference/meta_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class MODEL_REFERENCE_CATEGORY(StrEnum):
gfpgan = auto()
safety_checker = auto()
stable_diffusion = auto()
miscellaneous = auto()


class MODEL_PURPOSE(StrEnum):
Expand All @@ -78,6 +79,8 @@ class MODEL_PURPOSE(StrEnum):
post_processor = auto()
"""The model is a post processor (after image generation) of some variety."""

miscellaneous = auto()


class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
"""An enum of all the stable diffusion baselines."""
Expand All @@ -98,4 +101,5 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
MODEL_REFERENCE_CATEGORY.gfpgan: MODEL_PURPOSE.post_processor,
MODEL_REFERENCE_CATEGORY.safety_checker: MODEL_PURPOSE.post_processor,
MODEL_REFERENCE_CATEGORY.stable_diffusion: MODEL_PURPOSE.image_generation,
MODEL_REFERENCE_CATEGORY.miscellaneous: MODEL_PURPOSE.miscellaneous,
}
2 changes: 2 additions & 0 deletions horde_model_reference/model_reference_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class ControlNet_ModelReference(Generic_ModelReference):
MODEL_REFERENCE_CATEGORY.gfpgan: Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.safety_checker: Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.codeformer: Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.miscellaneous: Generic_ModelRecord,
}
"""A lookup for the model record type based on the model category. See also `MODEL_REFERENCE_TYPE_LOOKUP`."""

Expand All @@ -264,5 +265,6 @@ class ControlNet_ModelReference(Generic_ModelReference):
MODEL_REFERENCE_CATEGORY.gfpgan: Generic_ModelReference,
MODEL_REFERENCE_CATEGORY.safety_checker: Generic_ModelReference,
MODEL_REFERENCE_CATEGORY.codeformer: Generic_ModelReference,
MODEL_REFERENCE_CATEGORY.miscellaneous: Generic_ModelReference,
}
"""A lookup for the model reference type based on the model category. See also `MODEL_REFERENCE_RECORD_TYPE_LOOKUP`."""
12 changes: 6 additions & 6 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pytest==8.0.1
mypy==1.8.0
black==24.2.0
ruff==0.2.2
tox~=4.13.0
pre-commit~=3.6.2
pytest==8.2.1
mypy==1.10.0
black==24.4.2
ruff==0.4.7
tox~=4.15.0
pre-commit~=3.7.1
build>=0.10.0
coverage>=7.2.7

Expand Down
3 changes: 2 additions & 1 deletion stable_diffusion.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
"controlnet",
"clip",
"blip",
"post_processor"
"post_processor",
"miscellaneous"
],
"title": "MODEL_PURPOSE",
"type": "string"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def test_download_all_model_references(base_path_for_tests: Path):
reference_download_manager = LegacyReferenceDownloadManager(base_path=base_path_for_tests)
download_models = reference_download_manager.download_all_legacy_model_references(overwrite_existing=True)
assert len(download_models) == 8
assert len(download_models) == len(MODEL_REFERENCE_CATEGORY.__members__)


def test_validate_stable_diffusion_model_reference(legacy_folder_for_tests: Path):
Expand Down

0 comments on commit bf08ac8

Please sign in to comment.