Skip to content

Commit

Permalink
refactor: add T2I parameter annotations
Browse files Browse the repository at this point in the history
This commit adds parameter annotations to the T2I pipeline similar to
how it is done in the rest of the pipelines. Descriptions will be added
in a subsequenty commit.
  • Loading branch information
rickstaa committed Aug 1, 2024
1 parent 8833c77 commit e14bb84
Show file tree
Hide file tree
Showing 15 changed files with 973 additions and 123 deletions.
842 changes: 842 additions & 0 deletions openapi.json

Large diffs are not rendered by default.

25 changes: 8 additions & 17 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,18 @@

import PIL
import torch
from diffusers import (
AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from app.pipelines.base import Pipeline
from app.pipelines.utils import (SafetyChecker, get_model_dir,
get_torch_device, is_lightning_model,
is_turbo_model)
from diffusers import (AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline, UNet2DConditionModel)
from huggingface_hub import file_download, hf_hub_download
from PIL import ImageFile
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion runner/app/pipelines/optim/sfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import logging

from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile
from sfast.compilers.diffusion_pipeline_compiler import (CompilationConfig,
compile)

logger = logging.getLogger(__name__)

Expand Down
21 changes: 6 additions & 15 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,12 @@
import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
split_prompt,
)
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
StableDiffusion3Pipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from app.pipelines.utils import (SafetyChecker, get_model_dir,
get_torch_device, is_lightning_model,
is_turbo_model, split_prompt)
from diffusers import (AutoPipelineForText2Image, EulerDiscreteScheduler,
StableDiffusion3Pipeline, StableDiffusionXLPipeline,
UNet2DConditionModel)
from diffusers.models import AutoencoderKL
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file
Expand Down
13 changes: 4 additions & 9 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (SafetyChecker, get_model_dir,
get_torch_device, is_lightning_model,
is_turbo_model)
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile

from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)
Expand Down
14 changes: 4 additions & 10 deletions runner/app/pipelines/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
"""This module contains several utility functions that are used across the pipelines module."""

from app.pipelines.utils.utils import (
SafetyChecker,
get_model_dir,
get_model_path,
get_torch_device,
is_lightning_model,
is_turbo_model,
split_prompt,
validate_torch_device,
)
from app.pipelines.utils.utils import (SafetyChecker, get_model_dir,
get_model_path, get_torch_device,
is_lightning_model, is_turbo_model,
split_prompt, validate_torch_device)
3 changes: 2 additions & 1 deletion runner/app/routes/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.audio import AudioConversionError
from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error
from app.routes.util import (HTTPError, TextResponse, file_exceeds_max_size,
http_error)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down
8 changes: 4 additions & 4 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import random
from typing import Annotated

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import (HTTPError, ImageResponse, http_error,
image_to_data_url)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()
Expand Down
3 changes: 2 additions & 1 deletion runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, VideoResponse, http_error, image_to_data_url
from app.routes.util import (HTTPError, VideoResponse, http_error,
image_to_data_url)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down
29 changes: 17 additions & 12 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import os
import random
from typing import Annotated

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from app.routes.util import (HTTPError, ImageResponse, http_error,
image_to_data_url)
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from pydantic import BaseModel, Field

router = APIRouter()

Expand All @@ -18,16 +20,19 @@
class TextToImageParams(BaseModel):
# TODO: Make model_id and other None properties optional once Go codegen tool
# supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
model_id: str = ""
prompt: str
height: int = None
width: int = None
guidance_scale: float = 7.5
negative_prompt: str = ""
safety_check: bool = True
seed: int = None
num_inference_steps: int = 50 # NOTE: Hardcoded due to varying pipeline values.
num_images_per_prompt: int = 1
model_id: Annotated[
str,
Field(default="", description=""),
]
prompt: Annotated[str, Field(description="")]
height: Annotated[int, Field(default=576, description="")]
width: Annotated[int, Field(default=1024, description="")]
guidance_scale: Annotated[float, Field(default=7.5, description="")]
negative_prompt: Annotated[str, Field(default="", description="")]
safety_check: Annotated[bool, Field(default=True, description="")]
seed: Annotated[int, Field(default=None, description="")]
num_inference_steps: Annotated[int, Field(default=50, description="")]
num_images_per_prompt: Annotated[int, Field(default=1, description="")]


RESPONSES = {
Expand Down
8 changes: 4 additions & 4 deletions runner/app/routes/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import random
from typing import Annotated

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import (HTTPError, ImageResponse, http_error,
image_to_data_url)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()
Expand Down
10 changes: 2 additions & 8 deletions runner/gen_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,8 @@

import yaml
from app.main import app, use_route_names_as_operation_ids
from app.routes import (
audio_to_text,
health,
image_to_image,
image_to_video,
text_to_image,
upscale,
)
from app.routes import (audio_to_text, health, image_to_image, image_to_video,
text_to_image, upscale)
from fastapi.openapi.utils import get_openapi

# Specify Endpoints for OpenAPI schema generation.
Expand Down
3 changes: 2 additions & 1 deletion runner/modal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import os
from pathlib import Path

from app.main import config_logging, load_route, use_route_names_as_operation_ids
from app.main import (config_logging, load_route,
use_route_names_as_operation_ids)
from app.routes import health
from modal import Image, Secret, Stub, Volume, asgi_app, enter, method

Expand Down
20 changes: 16 additions & 4 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -700,47 +700,59 @@
"model_id": {
"type": "string",
"title": "Model Id",
"description": "The ID of the model to use for image generation.",
"default": ""
},
"prompt": {
"type": "string",
"title": "Prompt"
"title": "Prompt",
"description": "The text prompt to generate the image from."
},
"height": {
"type": "integer",
"title": "Height"
"title": "Height",
"description": "The height of the output image in pixels.",
"default": 576
},
"width": {
"type": "integer",
"title": "Width"
"title": "Width",
"description": "The width of the output image in pixels.",
"default": 1024
},
"guidance_scale": {
"type": "number",
"title": "Guidance Scale",
"description": "The guidance scale for image generation.",
"default": 7.5
},
"negative_prompt": {
"type": "string",
"title": "Negative Prompt",
"description": "The negative prompt to avoid certain features in the image.",
"default": ""
},
"safety_check": {
"type": "boolean",
"title": "Safety Check",
"description": "Whether to perform a safety check on the generated image.",
"default": true
},
"seed": {
"type": "integer",
"title": "Seed"
"title": "Seed",
"description": "The seed for random number generation."
},
"num_inference_steps": {
"type": "integer",
"title": "Num Inference Steps",
"description": "The number of inference steps for image generation.",
"default": 50
},
"num_images_per_prompt": {
"type": "integer",
"title": "Num Images Per Prompt",
"description": "The number of images to generate per prompt.",
"default": 1
}
},
Expand Down
Loading

0 comments on commit e14bb84

Please sign in to comment.