Skip to content

Commit

Permalink
[#6] Feat: Flux demo
Browse files Browse the repository at this point in the history
  • Loading branch information
jiooum committed Oct 6, 2024
1 parent b28ef82 commit a9b048e
Show file tree
Hide file tree
Showing 25 changed files with 397 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 120
extend-ignore = E203, F401, E501, E402, E714
extend-ignore = E203, F401, E501, E402, E714, F811
19 changes: 19 additions & 0 deletions gradio_demo/getting_started/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from gradio_demo.getting_started.containers.app_container import AppContainer
from gradio_demo.getting_started.ui.ui import make_app_ui


if __name__ == "__main__":
container = AppContainer()
container.config.from_yaml(
"/mnt/data/0.Workspace/0.SourceCode/project_zero/gradio_demo/getting_started/configs.yaml"
)
container.wire(modules=["gradio_demo.getting_started.ui.ui"])

# Make ui
ui = make_app_ui()
# os.environ['GRADIO_TEMP_DIR'] = container.config.gradio.temp_dir()
ui.launch(
server_name=container.config.gradio.server_name(),
server_port=container.config.gradio.server_port(),
share=container.config.gradio.share(),
)
File renamed without changes.
27 changes: 27 additions & 0 deletions gradio_demo/getting_started/containers/app_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dependency_injector import containers, providers

from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper


class AppContainer(containers.DeclarativeContainer):
config = providers.Configuration()

minio_storage = providers.Singleton(
MinioWrapper,
endpoint=config.minio_storage.endpoint,
access_key=config.minio_storage.access_key,
secret_key=config.minio_storage.secret_key,
bucket_name=config.minio_storage.bucket_name,
secure=config.minio_storage.secure,
)
mongo_db = providers.Singleton(
MongoClientWrapper,
endpoint=config.mongo_db.endpoint,
username=config.mongo_db.username,
password=config.mongo_db.password,
database=config.mongo_db.database,
collection=config.mongo_db.collection,
)
flux = providers.Singleton(FluxWrapper)
File renamed without changes.
23 changes: 23 additions & 0 deletions gradio_demo/getting_started/core/flux_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import gc

import numpy as np
import torch

from sdk.models.flux import Flux


class FluxWrapper:
def __init__(self, *args, **kwargs):
self._flux: Flux = None

def run(self, n_items: int = 1, *args, **kwargs):
if self._flux is None:
self._flux = Flux()

output_images = []
for i in range(n_items):
img = self._flux.run(*args, **kwargs)
output_images.append(img)
torch.cuda.empty_cache()
gc.collect()
return output_images[0]
15 changes: 15 additions & 0 deletions gradio_demo/getting_started/core/minio_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sdk.cloud_storage.minio_storage import MinioStorage
from sdk.utils.url_handler import cloud_upload


class MinioWrapper:
def __init__(self, *args, **kwargs):
self.__minio = MinioStorage(*args, **kwargs)

def upload(self, file):
try:
public_url, upload_result = cloud_upload(cloud_storage=self.__minio, local_path=file)
print(f"File uploaded to: {public_url}")
return public_url
except Exception as e:
return str(e)
30 changes: 30 additions & 0 deletions gradio_demo/getting_started/core/mongo_client_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import datetime
import uuid

from gradio_client.documentation import document

from sdk.database.mongo_db import MongoDB


class MongoClientWrapper:
def __init__(
self,
database: str,
collection: str,
username: str,
password: str,
endpoint: str,
):
self._mongo_client = MongoDB(username=username, password=password, endpoint=endpoint)
self._database = database
self._collection = collection

def create(self, data, *args, **kwargs):
doc = {k.label: v for k, v in data.items()}
create_time = datetime.datetime.now()
doc.update({'create_time': create_time})

result = self._mongo_client.insert_one(self._database, self._collection, doc)
inserted_id = str(result.inserted_id)
print(f"Created request_id: {inserted_id}!")
return inserted_id
41 changes: 0 additions & 41 deletions gradio_demo/getting_started/getting_started_demo.py

This file was deleted.

Empty file.
48 changes: 48 additions & 0 deletions gradio_demo/getting_started/models/inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pydantic import BaseModel, field_validator


class FluxInput(BaseModel):
prompt: str
width: int
height: int
num_inference_steps: int
generator_seed: int
guidance_scale: float

__MIN_SIZE__ = 512
__MAX_SIZE__ = 1920
__DEFAULT_SIZE__ = 720

__MAX_NUM_INFERENCE_STEPS__ = 10
__MIN_NUM_INFERENCE_STEPS__ = 1
__DEFAULT_NUM_INFERENCE_STEPS__ = 4

__MAX_GUIDANCE_SCALE__ = 5.0
__MIN_GUIDANCE_SCALE__ = 0.5
__DEFAULT_GUIDANCE_SCALE__ = 3.5

__DEFAULT_GENERATOR_SEED__ = 12345

@field_validator('width', 'height')
def validate_sizes(cls, v):
assert (
cls.__MIN_SIZE__ <= v <= cls.__MAX_SIZE__
), f"value must be in range {cls.__MIN_SIZE__} to {cls.__MAX_SIZE__}"

@field_validator('num_inference_steps')
def validate_num_inference_steps(cls, v):
assert v <= 10, 'num_inference_steps must not be greater than 10'
return v

@field_validator('guidance_scale')
def validate_guidance_scale(cls, v):
assert (
cls.__MAX_GUIDANCE_SCALE__ <= v <= cls.__MIN_GUIDANCE_SCALE__
), f'guidance_scale must be in range {cls.__MIN_GUIDANCE_SCALE__} to {cls.__MAX_GUIDANCE_SCALE__}'

@field_validator('num_inference_steps')
def validate_num_inference_steps(cls, v):
assert cls.__MIN_NUM_INFERENCE_STEPS__ <= v <= cls.__MAX_NUM_INFERENCE_STEPS__, (
f'num_inference_steps must be in range '
f'{cls.__MIN_NUM_INFERENCE_STEPS__} to {cls.__MAX_NUM_INFERENCE_STEPS__}'
)
21 changes: 21 additions & 0 deletions gradio_demo/getting_started/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
torch~=2.4.1
torchvision
einops
fire >= 0.6.0
huggingface-hub
safetensors
sentencepiece
transformers~=4.44.2
tokenizers
protobuf
requests
invisible-watermark
optimum-quanto
git+https://github.com/huggingface/diffusers.git
dependency-injector~=4.42.0
minio~=7.2.9
gradio~=4.44.1
pydantic~=2.9.2
wget~=3.2
exceptiongroup~=1.2.2
optimum~=1.22.0
1 change: 1 addition & 0 deletions gradio_demo/getting_started/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .ui import make_app_ui
94 changes: 94 additions & 0 deletions gradio_demo/getting_started/ui/ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import gradio
from dependency_injector.providers import Configuration
from dependency_injector.wiring import Provide, inject

from gradio_demo.getting_started.containers.app_container import AppContainer
from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper
from gradio_demo.getting_started.models.inputs import FluxInput


@inject
def make_app_ui(
minio_storage: MinioWrapper = Provide[AppContainer.minio_storage],
flux: FluxWrapper = Provide[AppContainer.flux],
mongo_db: MongoClientWrapper = Provide[AppContainer.mongo_db],
config: Configuration = Provide[AppContainer.config],
):
with gradio.Blocks() as demo:
with gradio.Row(show_progress=False):
with gradio.Column(show_progress=False):
gradio.Markdown(
"""
# Inputs
"""
)
input_prompt = gradio.Textbox(label="Input prompt", interactive=True)
image_width = gradio.Slider(
label="Image width",
minimum=FluxInput.__MIN_SIZE__,
maximum=FluxInput.__MAX_SIZE__,
value=FluxInput.__DEFAULT_SIZE__,
step=8,
interactive=True,
)
image_height = gradio.Slider(
label="Image height",
minimum=FluxInput.__MIN_SIZE__,
maximum=FluxInput.__MAX_SIZE__,
value=FluxInput.__DEFAULT_SIZE__,
step=8,
interactive=True,
)
num_inference_step = gradio.Slider(
label="Number of inference steps",
minimum=FluxInput.__MIN_NUM_INFERENCE_STEPS__,
maximum=FluxInput.__MAX_NUM_INFERENCE_STEPS__,
value=FluxInput.__DEFAULT_NUM_INFERENCE_STEPS__,
step=1,
interactive=True,
)
generator_seed = gradio.Number(
label="Generator seed", value=FluxInput.__DEFAULT_GENERATOR_SEED__, interactive=True
)
guidance_scale = gradio.Slider(
label="Guidance scale",
minimum=FluxInput.__MIN_GUIDANCE_SCALE__,
maximum=FluxInput.__MAX_GUIDANCE_SCALE__,
value=FluxInput.__DEFAULT_GUIDANCE_SCALE__,
interactive=True,
)
n_items = gradio.Number(value=1, visible=False)

btn_generate_images = gradio.Button("Generate images")

with gradio.Column(show_progress=True):
gradio.Markdown(
"""
# Outputs
"""
)
request_id = gradio.Textbox(label="Request ID")
output_url = gradio.Textbox(label="Output url", interactive=False, visible=True)
output_image = gradio.Image(type="filepath", format="png", show_download_button=True, interactive=False)
btn_generate_images.click(
flux.run,
[n_items, input_prompt, image_width, image_height, num_inference_step, guidance_scale, generator_seed],
[output_image],
)
output_image.change(minio_storage.upload, [output_image], [output_url])
output_url.change(
mongo_db.create,
inputs={
input_prompt,
image_width,
image_height,
num_inference_step,
guidance_scale,
generator_seed,
output_url,
},
outputs=[request_id],
)
return demo
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ pytest~=7.4.2
dependency-injector>=4.0,<5.0
minio~=7.2.9
gradio~=4.44.1

pillow~=10.4.0
torch~=2.4.1
diffusers~=0.31.0.dev0
transformers~=4.43.3
pymongo~=4.10.1
pydantic~=2.8.2
Empty file added sdk/__init__.py
Empty file.
Empty file added sdk/cloud_storage/__init__.py
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import os
import uuid

from dependency_injector.providers import AbstractSingleton, Configuration
from minio import Minio

from utils.cloud_storage.base_cloud_storage import AbstractCloudStorage
from sdk.cloud_storage.base_cloud_storage import AbstractCloudStorage


class MinioStorage(AbstractCloudStorage):
Expand Down
Empty file added sdk/database/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions sdk/database/mongo_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import urllib.parse

import pymongo


class MongoDB:
def __init__(self, username: str, password: str, endpoint: str):
mongodb_uri = f"mongodb://{username}:{urllib.parse.quote(password)}@{endpoint}"
self._client = pymongo.MongoClient(mongodb_uri, authSource="admin", connect=True)

def insert_one(self, database: str, collection: str, document: dict):
collection = self._client.get_database(database).get_collection(collection)
return collection.insert_one(document)

def update_one(self, database: str, collection: str, document_id: str, document: dict):
collection = self._client.get_database(database).get_collection(collection)
collection.update_one(
filter={"_id": document_id},
update={"$set": document},
)


if __name__ == '__main__':
mongodb = MongoDB(username="admin", password="mongodb123", endpoint="localhost:27017")
mongodb.insert_one(database="app", collection="request", document={"msg": "test"})
Empty file added sdk/models/__init__.py
Empty file.
Loading

0 comments on commit a9b048e

Please sign in to comment.