From 3793d1d99e3161216403f981bb8f27095e9a6edb Mon Sep 17 00:00:00 2001 From: tazlin Date: Sun, 23 Jun 2024 10:55:18 -0400 Subject: [PATCH] feat: `miscellaneous` model manager support --- .../legacy/classes/staging_model_database_records.py | 1 + horde_model_reference/meta_consts.py | 4 ++++ horde_model_reference/model_reference_records.py | 2 ++ stable_diffusion.schema.json | 3 ++- tests/test_scripts.py | 2 +- 5 files changed, 10 insertions(+), 2 deletions(-) diff --git a/horde_model_reference/legacy/classes/staging_model_database_records.py b/horde_model_reference/legacy/classes/staging_model_database_records.py index 06892c1..05171bb 100644 --- a/horde_model_reference/legacy/classes/staging_model_database_records.py +++ b/horde_model_reference/legacy/classes/staging_model_database_records.py @@ -126,4 +126,5 @@ class Staging_StableDiffusion_ModelReference(Legacy_Generic_ModelReference): MODEL_REFERENCE_CATEGORY.gfpgan: StagingLegacy_Generic_ModelRecord, MODEL_REFERENCE_CATEGORY.safety_checker: StagingLegacy_Generic_ModelRecord, MODEL_REFERENCE_CATEGORY.stable_diffusion: Legacy_StableDiffusion_ModelRecord, + MODEL_REFERENCE_CATEGORY.miscellaneous: StagingLegacy_Generic_ModelRecord, } diff --git a/horde_model_reference/meta_consts.py b/horde_model_reference/meta_consts.py index 51c2f98..9a76858 100644 --- a/horde_model_reference/meta_consts.py +++ b/horde_model_reference/meta_consts.py @@ -60,6 +60,7 @@ class MODEL_REFERENCE_CATEGORY(StrEnum): gfpgan = auto() safety_checker = auto() stable_diffusion = auto() + miscellaneous = auto() class MODEL_PURPOSE(StrEnum): @@ -78,6 +79,8 @@ class MODEL_PURPOSE(StrEnum): post_processor = auto() """The model is a post processor (after image generation) of some variety.""" + miscellaneous = auto() + class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum): """An enum of all the stable diffusion baselines.""" @@ -98,4 +101,5 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum): MODEL_REFERENCE_CATEGORY.gfpgan: MODEL_PURPOSE.post_processor, MODEL_REFERENCE_CATEGORY.safety_checker: MODEL_PURPOSE.post_processor, MODEL_REFERENCE_CATEGORY.stable_diffusion: MODEL_PURPOSE.image_generation, + MODEL_REFERENCE_CATEGORY.miscellaneous: MODEL_PURPOSE.miscellaneous, } diff --git a/horde_model_reference/model_reference_records.py b/horde_model_reference/model_reference_records.py index b8db3a7..72ef8ff 100644 --- a/horde_model_reference/model_reference_records.py +++ b/horde_model_reference/model_reference_records.py @@ -252,6 +252,7 @@ class ControlNet_ModelReference(Generic_ModelReference): MODEL_REFERENCE_CATEGORY.gfpgan: Generic_ModelRecord, MODEL_REFERENCE_CATEGORY.safety_checker: Generic_ModelRecord, MODEL_REFERENCE_CATEGORY.codeformer: Generic_ModelRecord, + MODEL_REFERENCE_CATEGORY.miscellaneous: Generic_ModelRecord, } """A lookup for the model record type based on the model category. See also `MODEL_REFERENCE_TYPE_LOOKUP`.""" @@ -264,5 +265,6 @@ class ControlNet_ModelReference(Generic_ModelReference): MODEL_REFERENCE_CATEGORY.gfpgan: Generic_ModelReference, MODEL_REFERENCE_CATEGORY.safety_checker: Generic_ModelReference, MODEL_REFERENCE_CATEGORY.codeformer: Generic_ModelReference, + MODEL_REFERENCE_CATEGORY.miscellaneous: Generic_ModelReference, } """A lookup for the model reference type based on the model category. See also `MODEL_REFERENCE_RECORD_TYPE_LOOKUP`.""" diff --git a/stable_diffusion.schema.json b/stable_diffusion.schema.json index 3475fb6..5ee10b1 100644 --- a/stable_diffusion.schema.json +++ b/stable_diffusion.schema.json @@ -54,7 +54,8 @@ "controlnet", "clip", "blip", - "post_processor" + "post_processor", + "miscellaneous" ], "title": "MODEL_PURPOSE", "type": "string" diff --git a/tests/test_scripts.py b/tests/test_scripts.py index d6cf08d..1032f01 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -10,7 +10,7 @@ def test_download_all_model_references(base_path_for_tests: Path): reference_download_manager = LegacyReferenceDownloadManager(base_path=base_path_for_tests) download_models = reference_download_manager.download_all_legacy_model_references(overwrite_existing=True) - assert len(download_models) == 8 + assert len(download_models) == len(MODEL_REFERENCE_CATEGORY.__members__) def test_validate_stable_diffusion_model_reference(legacy_folder_for_tests: Path):