Skip to content

Commit

Permalink
feat: miscellaneous model manager support
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Jun 23, 2024
1 parent 319265b commit 3793d1d
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,5 @@ class Staging_StableDiffusion_ModelReference(Legacy_Generic_ModelReference):
MODEL_REFERENCE_CATEGORY.gfpgan: StagingLegacy_Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.safety_checker: StagingLegacy_Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.stable_diffusion: Legacy_StableDiffusion_ModelRecord,
MODEL_REFERENCE_CATEGORY.miscellaneous: StagingLegacy_Generic_ModelRecord,
}
4 changes: 4 additions & 0 deletions horde_model_reference/meta_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class MODEL_REFERENCE_CATEGORY(StrEnum):
gfpgan = auto()
safety_checker = auto()
stable_diffusion = auto()
miscellaneous = auto()


class MODEL_PURPOSE(StrEnum):
Expand All @@ -78,6 +79,8 @@ class MODEL_PURPOSE(StrEnum):
post_processor = auto()
"""The model is a post processor (after image generation) of some variety."""

miscellaneous = auto()


class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
"""An enum of all the stable diffusion baselines."""
Expand All @@ -98,4 +101,5 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
MODEL_REFERENCE_CATEGORY.gfpgan: MODEL_PURPOSE.post_processor,
MODEL_REFERENCE_CATEGORY.safety_checker: MODEL_PURPOSE.post_processor,
MODEL_REFERENCE_CATEGORY.stable_diffusion: MODEL_PURPOSE.image_generation,
MODEL_REFERENCE_CATEGORY.miscellaneous: MODEL_PURPOSE.miscellaneous,
}
2 changes: 2 additions & 0 deletions horde_model_reference/model_reference_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class ControlNet_ModelReference(Generic_ModelReference):
MODEL_REFERENCE_CATEGORY.gfpgan: Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.safety_checker: Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.codeformer: Generic_ModelRecord,
MODEL_REFERENCE_CATEGORY.miscellaneous: Generic_ModelRecord,
}
"""A lookup for the model record type based on the model category. See also `MODEL_REFERENCE_TYPE_LOOKUP`."""

Expand All @@ -264,5 +265,6 @@ class ControlNet_ModelReference(Generic_ModelReference):
MODEL_REFERENCE_CATEGORY.gfpgan: Generic_ModelReference,
MODEL_REFERENCE_CATEGORY.safety_checker: Generic_ModelReference,
MODEL_REFERENCE_CATEGORY.codeformer: Generic_ModelReference,
MODEL_REFERENCE_CATEGORY.miscellaneous: Generic_ModelReference,
}
"""A lookup for the model reference type based on the model category. See also `MODEL_REFERENCE_RECORD_TYPE_LOOKUP`."""
3 changes: 2 additions & 1 deletion stable_diffusion.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
"controlnet",
"clip",
"blip",
"post_processor"
"post_processor",
"miscellaneous"
],
"title": "MODEL_PURPOSE",
"type": "string"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def test_download_all_model_references(base_path_for_tests: Path):
reference_download_manager = LegacyReferenceDownloadManager(base_path=base_path_for_tests)
download_models = reference_download_manager.download_all_legacy_model_references(overwrite_existing=True)
assert len(download_models) == 8
assert len(download_models) == len(MODEL_REFERENCE_CATEGORY.__members__)


def test_validate_stable_diffusion_model_reference(legacy_folder_for_tests: Path):
Expand Down

0 comments on commit 3793d1d

Please sign in to comment.