Skip to content

Commit

Permalink
Merge pull request #81 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: tolerate extra fields in legacy model reference file
  • Loading branch information
tazlin authored Mar 1, 2024
2 parents 73fe8e3 + 48a257e commit 9ef7cd7
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 12 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,33 @@ jobs:
- name: Run pre-commit
uses: pre-commit/action@v3.0.0

no_extra_fields:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
HORDELIB_CI_ONGOING: "1"
TESTS_ONGOING: "1"
runs-on: ubuntu-latest
strategy:
matrix:
python: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
- name: Install any required packages
run: |
python -m pip install --upgrade pip
pip install --upgrade -r requirements.dev.txt
- name: Run no_extra_fields check # Enabled by HORDELIB_CI_ONGOING
run: tox -e tests

build:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
TESTS_ONGOING: "1"
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
25 changes: 25 additions & 0 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,34 @@ jobs:
- name: Run pre-commit
uses: pre-commit/action@v3.0.0

no_extra_fields:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
HORDELIB_CI_ONGOING: "1"
TESTS_ONGOING: "1"
runs-on: ubuntu-latest
strategy:
matrix:
python: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
- name: Install any required packages
run: |
python -m pip install --upgrade pip
pip install --upgrade -r requirements.dev.txt
- name: Run no_extra_fields check # Enabled by HORDELIB_CI_ONGOING
run: tox -e tests
build:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
TESTS_ONGOING: "1"
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class RawLegacy_StableDiffusion_ModelRecord(BaseModel):
# This is a better representation of the legacy model reference than the one in `staging_model_database_records.py`
# which is a hybrid representation of the legacy model reference and the new model reference format.

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="allow")

name: str
baseline: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def validate_model_file_has_sha256sum(self):
class StagingLegacy_Config_DownloadRecord(BaseModel):
"""An entry in the `config` field of a `StagingLegacy_Generic_ModelRecord`."""

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="allow")

file_name: str
file_path: str = ""
Expand All @@ -54,7 +54,7 @@ class StagingLegacy_Config_DownloadRecord(BaseModel):
class StagingLegacy_Generic_ModelRecord(BaseModel):
"""This is a helper class, a hybrid representation of the legacy model reference and the new format."""

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="allow")

name: str
type: str
Expand Down Expand Up @@ -100,7 +100,7 @@ class Legacy_StableDiffusion_ModelRecord(StagingLegacy_Generic_ModelRecord):
class Legacy_Generic_ModelReference(BaseModel):
"""A helper class to convert the legacy model reference to the new model reference format."""

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="allow")

models: Mapping[str, StagingLegacy_Generic_ModelRecord]

Expand Down
16 changes: 15 additions & 1 deletion horde_model_reference/legacy/validate_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
)


def validate_legacy_stable_diffusion_db(sd_db: Path, write_to_path: Path | None = None) -> bool:
def validate_legacy_stable_diffusion_db(
sd_db: Path,
write_to_path: Path | None = None,
fail_on_extra: bool = False,
) -> bool:
raw_json_sd_db: str
with open(sd_db) as sd_db_file:
raw_json_sd_db = sd_db_file.read()
Expand Down Expand Up @@ -42,8 +46,18 @@ def validate_legacy_stable_diffusion_db(sd_db: Path, write_to_path: Path | None
},
indent=4,
)

correct_json_layout += "\n" # Add a newline to the end of the file, for consistency with formatters.

any_extra_fields = False
for key, record in parsed_db_records.items():
if record.model_extra:
logger.error(f"Extra fields found in {key}: {record.model_extra}")
any_extra_fields = True

if any_extra_fields and fail_on_extra:
raise ValueError("Extra fields found in stable diffusion model database.")

if raw_json_sd_db != correct_json_layout:
logger.error("Invalid stable diffusion model database.")
if write_to_path:
Expand Down
2 changes: 1 addition & 1 deletion legacy_stable_diffusion.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
"type": "object"
},
"RawLegacy_StableDiffusion_ModelRecord": {
"additionalProperties": false,
"additionalProperties": true,
"description": "A model entry in the legacy model reference.",
"properties": {
"name": {
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import os
import sys
from pathlib import Path

import pytest
from loguru import logger

from horde_model_reference.path_consts import LEGACY_REFERENCE_FOLDER_NAME

os.environ["TESTS_ONGOING"] = "1"


@pytest.fixture(scope="session")
def env_var_checks() -> None:
"""Check for required environment variables."""

assert "TESTS_ONGOING" in os.environ, "Environment variable 'TESTS_ONGOING' not set."


@pytest.fixture(scope="session")
def base_path_for_tests() -> Path:
Expand All @@ -30,6 +41,9 @@ def setup_logging(base_path_for_tests: Path):
"sink": base_path_for_tests.joinpath("test_log.txt"),
"format": "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
},
# add sinks for stdout and stderr
{"sink": sys.stdout, "format": "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"},
{"sink": sys.stderr, "format": "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"},
],
)

Expand Down
23 changes: 17 additions & 6 deletions tests/test_scripts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

from horde_model_reference.legacy.download_live_legacy_dbs import LegacyReferenceDownloadManager
Expand All @@ -13,9 +14,19 @@ def test_download_all_model_references(base_path_for_tests: Path):


def test_validate_stable_diffusion_model_reference(legacy_folder_for_tests: Path):
assert validate_legacy_stable_diffusion_db(
sd_db=get_model_reference_file_path(
MODEL_REFERENCE_CATEGORY.stable_diffusion,
base_path=legacy_folder_for_tests,
),
)
if os.environ.get("HORDELIB_CI_ONGOING"):
assert validate_legacy_stable_diffusion_db(
sd_db=get_model_reference_file_path(
MODEL_REFERENCE_CATEGORY.stable_diffusion,
base_path=legacy_folder_for_tests,
),
fail_on_extra=True,
)
else:
assert validate_legacy_stable_diffusion_db(
sd_db=get_model_reference_file_path(
MODEL_REFERENCE_CATEGORY.stable_diffusion,
base_path=legacy_folder_for_tests,
),
fail_on_extra=False,
)
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ commands = pre-commit run --all-files --show-diff-on-failure
[testenv:tests]
description = install pytest in a virtual environment and invoke it on the tests folder
skip_install = false
passenv = HORDELIB_CI_ONGOING
deps =
pytest>=7
pytest-sugar
Expand Down

0 comments on commit 9ef7cd7

Please sign in to comment.