Merge branch 'next' into next-allow-symlinks-instead-of-copy

lstein authored Feb 28, 2024
2 parents 5686287 + 4418c11, commit 1a582e0

Showing 318 changed files with 7,196 additions and 10,772 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/lint-frontend.yml
@@ -36,8 +36,10 @@ jobs:
       - name: Typescript
         run: 'pnpm run lint:tsc'
       - name: Madge
-        run: 'pnpm run lint:madge'
+        run: 'pnpm run lint:dpdm'
       - name: ESLint
         run: 'pnpm run lint:eslint'
       - name: Prettier
         run: 'pnpm run lint:prettier'
+      - name: Knip
+        run: 'pnpm run lint:knip'
190 changes: 75 additions & 115 deletions invokeai/app/api/routers/model_manager.py
@@ -9,11 +9,11 @@
 
 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 from starlette.exceptions import HTTPException
 from typing_extensions import Annotated
 
-from invokeai.app.services.model_install import ModelInstallJob, ModelSource
+from invokeai.app.services.model_install import ModelInstallJob
 from invokeai.app.services.model_records import (
     DuplicateModelException,
     InvalidModelException,
@@ -165,6 +165,27 @@ async def list_model_records(
     return ModelsList(models=found_models)
 
 
+@model_manager_router.get(
+    "/get_by_attrs",
+    operation_id="get_model_records_by_attrs",
+    response_model=AnyModelConfig,
+)
+async def get_model_records_by_attrs(
+    name: str = Query(description="The name of the model"),
+    type: ModelType = Query(description="The type of the model"),
+    base: BaseModelType = Query(description="The base model of the model"),
+) -> AnyModelConfig:
+    """Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
+    model manager, which identified models by a combination of name, base and type."""
+    configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
+        base_model=base, model_type=type, model_name=name
+    )
+    if not configs:
+        raise HTTPException(status_code=404, detail="No model found with these attributes")
+
+    return configs[0]
+
+
 @model_manager_router.get(
     "/i/{key}",
     operation_id="get_model_record",
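As a quick orientation for the new route, here is a hedged sketch of a client call; the `/api/v2/models` prefix, port, and example attribute values are assumptions rather than facts from this diff:

```python
import requests

# Hypothetical base URL; the real router prefix is defined elsewhere in the app.
BASE = "http://localhost:9090/api/v2/models"

resp = requests.get(
    f"{BASE}/get_by_attrs",
    params={"name": "stable-diffusion-v1-5", "type": "main", "base": "sd-1"},
)
resp.raise_for_status()  # the handler raises 404 when nothing matches
config = resp.json()     # the first matching AnyModelConfig record
```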
@@ -202,15 +223,14 @@ async def list_model_summary(
 
 
 @model_manager_router.get(
-    "/meta/i/{key}",
+    "/i/{key}/metadata",
     operation_id="get_model_metadata",
     responses={
         200: {
             "description": "The model metadata was retrieved successfully",
             "content": {"application/json": {"example": example_model_metadata}},
         },
         400: {"description": "Bad request"},
-        404: {"description": "No metadata available"},
     },
 )
 async def get_model_metadata(
@@ -219,8 +239,7 @@ async def get_model_metadata(
     """Get a model metadata object."""
     record_store = ApiDependencies.invoker.services.model_manager.store
     result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
-    if not result:
-        raise HTTPException(status_code=404, detail="No metadata for a model with this key")
 
     return result
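A companion sketch for the renamed metadata route, reusing the assumed `BASE` above. With the `if not result` guard removed, the handler appears to return null rather than a 404 when a model has no stored metadata:

```python
model_key = "abc123"  # hypothetical model key from the record store
resp = requests.get(f"{BASE}/i/{model_key}/metadata")  # was /meta/i/{key}
metadata = resp.json()  # a metadata object, or None when none is stored
```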


@@ -235,6 +254,11 @@ async def list_tags() -> Set[str]:
     return result
 
 
+class FoundModel(BaseModel):
+    path: str = Field(description="Path to the model")
+    is_installed: bool = Field(description="Whether or not the model is already installed")
+
+
 @model_manager_router.get(
     "/scan_folder",
     operation_id="scan_for_models",
@@ -243,11 +267,11 @@
         400: {"description": "Invalid directory path"},
     },
     status_code=200,
-    response_model=List[pathlib.Path],
+    response_model=List[FoundModel],
 )
 async def scan_for_models(
     scan_path: str = Query(description="Directory path to search for models", default=None),
-) -> List[pathlib.Path]:
+) -> List[FoundModel]:
     path = pathlib.Path(scan_path)
     if not scan_path or not path.is_dir():
         raise HTTPException(
@@ -257,13 +281,46 @@ async def scan_for_models(
 
     search = ModelSearch()
     try:
-        models_found = list(search.search(path))
+        found_model_paths = search.search(path)
+        models_path = ApiDependencies.invoker.services.configuration.models_path
+
+        # If the search path includes the main models directory, we need to exclude core models from the list.
+        # TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
+        # without needing to crawl the filesystem.
+        core_models_path = pathlib.Path(models_path, "core").resolve()
+        non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
+
+        installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
+        resolved_installed_model_paths: list[str] = []
+        installed_model_sources: list[str] = []
+
+        # This call lists all installed models.
+        for model in installed_models:
+            path = pathlib.Path(model.path)
+            # If the model has a source, we need to add it to the list of installed sources.
+            if model.source:
+                installed_model_sources.append(model.source)
+            # If the path is not absolute, that means it is in the app models directory, and we need to join it with
+            # the models path before resolving.
+            if not path.is_absolute():
+                resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
+                continue
+            resolved_installed_model_paths.append(str(path.resolve()))
+
+        scan_results: list[FoundModel] = []
+
+        # Check if the model is installed by comparing the resolved paths, appending to the scan result.
+        for p in non_core_model_paths:
+            path = str(p)
+            is_installed = path in resolved_installed_model_paths or path in installed_model_sources
+            found_model = FoundModel(path=path, is_installed=is_installed)
+            scan_results.append(found_model)
     except Exception as e:
         raise HTTPException(
             status_code=500,
             detail=f"An error occurred while searching the directory: {e}",
         )
-    return models_found
+    return scan_results
 
 
 @model_manager_router.get(
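The installed-or-not bookkeeping above rests on `pathlib` resolution; a small self-contained illustration of the two comparisons follows (all paths invented):

```python
import pathlib

models_dir = pathlib.Path("/data/invokeai/models")

# Core models are filtered out with Path.is_relative_to (Python 3.9+).
core = pathlib.Path(models_dir, "core").resolve()
scanned = pathlib.Path("/data/invokeai/models/core/upscaler.pt").resolve()
print(scanned.is_relative_to(core))  # True -> excluded from scan results

# A record with a relative path lives under the app models directory, so it
# is joined with the models path and resolved before comparison.
recorded = pathlib.Path("sd-1/main/model.safetensors")
print(pathlib.Path(models_dir, recorded).resolve())
# /data/invokeai/models/sd-1/main/model.safetensors
```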
Expand Down Expand Up @@ -382,8 +439,8 @@ async def add_model_record(


@model_manager_router.post(
"/heuristic_install",
operation_id="heuristic_install_model",
"/install",
operation_id="install_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
@@ -392,8 +449,9 @@ async def add_model_record(
     },
     status_code=201,
 )
-async def heuristic_install(
-    source: str,
+async def install_model(
+    source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
+    # TODO(MM2): Can we type this?
     config: Optional[Dict[str, Any]] = Body(
         description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
         default=None,
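A sketch of calling the consolidated route from a client, again assuming the `BASE` prefix above; the source URL and the config override are illustrative only:

```python
# The source travels in the query string; the optional config dict is the JSON body.
resp = requests.post(
    f"{BASE}/install",
    params={"source": "https://www.civitai.com/models/123456"},
    json={"name": "my-model"},  # overrides auto-probed config fields
)
job = resp.json()  # a ModelInstallJob; poll its status to track progress
```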
@@ -434,106 +492,7 @@ async def heuristic_install(
         result: ModelInstallJob = installer.heuristic_import(
             source=source,
             config=config,
         )
         logger.info(f"Started installation of {source}")
     except UnknownModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=424, detail=str(e))
     except InvalidModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=415)
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
     return result
 
 
-@model_manager_router.post(
-    "/install",
-    operation_id="import_model",
-    responses={
-        201: {"description": "The model imported successfully"},
-        415: {"description": "Unrecognized file/folder format"},
-        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
-        409: {"description": "There is already a model corresponding to this path or repo_id"},
-    },
-    status_code=201,
-)
-async def import_model(
-    source: ModelSource,
-    config: Optional[Dict[str, Any]] = Body(
-        description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
-        default=None,
-    ),
-) -> ModelInstallJob:
-    """Install a model using its local path, repo_id, or remote URL.
-    Models will be downloaded, probed, configured and installed in a
-    series of background threads. The return object has a `status` attribute
-    that can be used to monitor progress.
-    The source object is a discriminated Union of LocalModelSource,
-    HFModelSource and URLModelSource. Set the "type" field to the
-    appropriate value:
-    * To install a local path using LocalModelSource, pass a source of form:
-    ```
-    {
-        "type": "local",
-        "path": "/path/to/model",
-        "inplace": false
-    }
-    ```
-    The "inplace" flag, if true, will register the model in place in its
-    current filesystem location. Otherwise, the model will be copied
-    into the InvokeAI models directory.
-    * To install a HuggingFace repo_id using HFModelSource, pass a source of form:
-    ```
-    {
-        "type": "hf",
-        "repo_id": "stabilityai/stable-diffusion-2.0",
-        "variant": "fp16",
-        "subfolder": "vae",
-        "access_token": "f5820a918aaf01"
-    }
-    ```
-    The `variant`, `subfolder` and `access_token` fields are optional.
-    * To install a remote model using an arbitrary URL, pass:
-    ```
-    {
-        "type": "url",
-        "url": "http://www.civitai.com/models/123456",
-        "access_token": "f5820a918aaf01"
-    }
-    ```
-    The `access_token` field is optional.
-    The model's configuration record will be probed and filled in
-    automatically. To override the default guesses, pass "metadata"
-    with a Dict containing the attributes you wish to override.
-    Installation occurs in the background. Either use list_model_install_jobs()
-    to poll for completion, or listen on the event bus for the following events:
-    * "model_install_running"
-    * "model_install_completed"
-    * "model_install_error"
-    On successful completion, the event's payload will contain the field "key"
-    containing the installed ID of the model. On an error, the event's payload
-    will contain the fields "error_type" and "error" describing the nature of the
-    error and its traceback, respectively.
-    """
-    logger = ApiDependencies.invoker.services.logger
-
-    try:
-        installer = ApiDependencies.invoker.services.model_manager.install
-        result: ModelInstallJob = installer.import_model(
-            source=source,
-            config=config,
-            access_token=access_token,
-        )
-        logger.info(f"Started installation of {source}")
-    except UnknownModelException as e:
Expand Down Expand Up @@ -669,6 +628,7 @@ async def convert_model(
Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
@@ -685,7 +645,7 @@ async def convert_model(
         raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
 
     # loading the model will convert it into a cached diffusers file
-    loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
+    model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
 
     # Get the path of the converted model from the loader
     cache_path = loader.convert_cache.cache_path(key)
43 changes: 38 additions & 5 deletions invokeai/app/invocations/image.py
@@ -22,11 +22,7 @@
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 
-from .baseinvocation import (
-    BaseInvocation,
-    Classification,
-    invocation,
-)
+from .baseinvocation import BaseInvocation, Classification, invocation
 
 
 @invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
@@ -934,3 +930,40 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         image_dto = context.images.save(image=image)
 
         return ImageOutput.build(image_dto)
+
+
+@invocation(
+    "canvas_paste_back",
+    title="Canvas Paste Back",
+    tags=["image", "combine"],
+    category="image",
+    version="1.0.0",
+)
+class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
+
+    source_image: ImageField = InputField(description="The source image")
+    target_image: ImageField = InputField(default=None, description="The target image")
+    mask: ImageField = InputField(
+        description="The mask to use when pasting",
+    )
+    mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
+
+    def _prepare_mask(self, mask: Image.Image) -> Image.Image:
+        mask_array = numpy.array(mask)
+        kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
+        dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
+        dilated_mask = Image.fromarray(dilated_mask_array)
+        if self.mask_blur > 0:
+            mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+        return ImageOps.invert(mask.convert("L"))
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        source_image = context.images.get_pil(self.source_image.image_name)
+        target_image = context.images.get_pil(self.target_image.image_name)
+        mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
+
+        source_image.paste(target_image, (0, 0), mask)
+
+        image_dto = context.images.save(image=source_image)
+        return ImageOutput.build(image_dto)
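For orientation, a standalone sketch of what `_prepare_mask` computes when `mask_blur` is nonzero; the synthetic mask and sizes are invented for the example:

```python
import numpy
import cv2
from PIL import Image, ImageFilter, ImageOps

mask_blur = 8
mask = Image.new("L", (64, 64), 0)
mask.paste(255, (16, 16, 48, 48))  # white square marks the pasted region

mask_array = numpy.array(mask)
kernel = numpy.ones((mask_blur, mask_blur), numpy.uint8)
# cv2.erode shrinks the white region, pulling the seam inward before feathering.
eroded = cv2.erode(mask_array, kernel, iterations=3)
feathered = Image.fromarray(eroded).filter(ImageFilter.GaussianBlur(mask_blur))
prepared = ImageOps.invert(feathered.convert("L"))  # inverted, as in _prepare_mask
```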
4 changes: 2 additions & 2 deletions invokeai/app/invocations/metadata.py
@@ -33,7 +33,7 @@ class MetadataItemField(BaseModel):
 class LoRAMetadataField(BaseModel):
     """LoRA Metadata Field"""
 
-    lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
+    model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
     weight: float = Field(description=FieldDescriptions.lora_weight)
 
 
@@ -114,7 +114,7 @@ def invoke(self, context: InvocationContext) -> MetadataOutput:
     ]
 
 
-@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
+@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.1.1")
 class CoreMetadataInvocation(BaseInvocation):
     """Collects core generation metadata into a MetadataField"""
 