Skip to content

Commit

Permalink
Merge pull request #7 from Jiooum102/6-getting-started-with-gradio
Browse files Browse the repository at this point in the history
Getting started with gradio
  • Loading branch information
Jiooum102 authored Oct 16, 2024
2 parents c009265 + 5de4893 commit 5cb825e
Show file tree
Hide file tree
Showing 32 changed files with 788 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 120
extend-ignore = E203, F401, E501, E402, E714, F811
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,6 @@ pip-selfcheck.json
.ionide

# End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all,visualstudiocode,venv

# Data
data/
Empty file added gradio_demo/__init__.py
Empty file.
Empty file.
19 changes: 19 additions & 0 deletions gradio_demo/getting_started/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from gradio_demo.getting_started.containers.app_container import AppContainer
from gradio_demo.getting_started.ui.ui import make_app_ui


if __name__ == "__main__":
container = AppContainer()
container.config.from_yaml(
"/mnt/data/0.Workspace/0.SourceCode/project_zero/gradio_demo/getting_started/configs.yaml"
)
container.wire(modules=["gradio_demo.getting_started.ui.ui"])

# Make ui
ui = make_app_ui()
# os.environ['GRADIO_TEMP_DIR'] = container.config.gradio.temp_dir()
ui.launch(
server_name=container.config.gradio.server_name(),
server_port=container.config.gradio.server_port(),
share=container.config.gradio.share(),
)
20 changes: 20 additions & 0 deletions gradio_demo/getting_started/configs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
gradio:
server_name: "0.0.0.0"
server_port: 7001
share: false
temp_dir: '/mnt/data/0.Workspace/0.SourceCode/project_zero/data/temp'

minio_storage:
endpoint: "192.168.194.2:9000"
access_key: "jKYA67p9LclmAhr30cDO"
secret_key: "ovN1SNMTpIJxRKlps71JZBeypECoFP6LaFstXC3S"
bucket_name: "public"
secure: false

mongo_db:
endpoint: "192.168.194.2:27017"
username: "admin"
password: "mongodb123"
database: "app"
requests_collection: "requests"
users_collection: "users"
Empty file.
37 changes: 37 additions & 0 deletions gradio_demo/getting_started/containers/app_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from dependency_injector import containers, providers

from gradio_demo.getting_started.core.controller import AppController
from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper


class AppContainer(containers.DeclarativeContainer):
config = providers.Configuration()

minio_storage = providers.Factory(
MinioWrapper,
endpoint=config.minio_storage.endpoint,
access_key=config.minio_storage.access_key,
secret_key=config.minio_storage.secret_key,
bucket_name=config.minio_storage.bucket_name,
secure=config.minio_storage.secure,
)
mongo_db = providers.Singleton(
MongoClientWrapper,
endpoint=config.mongo_db.endpoint,
username=config.mongo_db.username,
password=config.mongo_db.password,
database=config.mongo_db.database,
users_collection=config.mongo_db.users_collection,
requests_collection=config.mongo_db.requests_collection,
)
flux = providers.Singleton(FluxWrapper)

app_controller = providers.Singleton(
AppController,
minio_storage=minio_storage,
flux=flux,
mongo_db=mongo_db,
config=config,
)
Empty file.
156 changes: 156 additions & 0 deletions gradio_demo/getting_started/core/controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os
import uuid
from typing import Dict, Union

import gradio

from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper
from gradio_demo.getting_started.models.inputs import FluxInput
from gradio_demo.getting_started.models.outputs import FluxOutput
from sdk.utils.download import download_file


class AppController:
def __init__(
self,
minio_storage: MinioWrapper,
flux: FluxWrapper,
mongo_db: MongoClientWrapper,
config: Dict,
):
self.__minio_storage = minio_storage
self.__flux = flux
self.__mongo_db = mongo_db
self.__config = config

self.__session_infor: Dict[str, Dict[str, Union[FluxInput, FluxOutput]]] = {}

def create_new_session(self):
session_id = str(uuid.uuid4())
self.__session_infor[session_id] = {
'input': FluxInput(),
'output': FluxOutput(),
}
return session_id

def update_input_prompt(self, user_id: str, new_input_prompt: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"

self.__session_infor[user_id]['input'].prompt = new_input_prompt
return []

def update_width(self, user_id: str, width: int):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]['input'].width = width
return []

def update_height(self, user_id: str, height: int):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["input"].height = height
return []

def update_num_inference_steps(self, user_id: str, num_inference_steps: int):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["input"].num_inference_steps = num_inference_steps
return []

def update_generator_seed(self, user_id: str, generator_seed: int):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]['input'].generator_seed = generator_seed
return []

def update_guidance_scale(self, user_id: str, guidance_scale: float):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["input"].guidance_scale = guidance_scale
return []

def update_output_url(self, user_id: str, output_url: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["output"].output_url = output_url
return []

def update_output_image(self, user_id: str, output_path: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["output"].output_path = output_path
return []

def update_request_id(self, user_id: str, request_id: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["output"].output_record_id = request_id
return []

def btn_start_demo_clicked(self):
session_id = self.create_new_session()
return session_id, gradio.Row(visible=True), gradio.Button(visible=False), gradio.Textbox(visible=True)

def get_examples(self, limit: int = 5):
records = self.__mongo_db.get_latest_requests(limit=limit)
_examples = []
for record in records:
request_id = str(record["_id"])
_input = FluxInput()
_input.prompt = record['input']['prompt']
_input.width = record['input']['width']
_input.height = record['input']['height']
_input.num_inference_steps = record['input']['num_inference_steps']
_input.generator_seed = record['input']['generator_seed']
_input.guidance_scale = record['input']['guidance_scale']
_output = FluxOutput(**record["output"])
_examples.append(
[
request_id,
_input.prompt,
_output.output_url,
_input.width,
_input.height,
_input.num_inference_steps,
_input.generator_seed,
_input.guidance_scale,
]
)
return _examples

def btn_load_examples_clicked(self):
_examples = self.get_examples(limit=50)
return gradio.Dataset(samples=_examples, visible=True)

def load_image_url(self, request_id: str) -> str:
_request_infor = self.__mongo_db.find_request(request_id=request_id)
_output = FluxOutput(**_request_infor['output'])
if not os.path.exists(_output.output_path):
download_file(url=_output.output_url, save_path=_output.output_path)
return _output.output_path

def run(self, session_id: str):
"""
:param session_id: User ID
:return:
"""

# Run flux
model_input = self.__session_infor[session_id]['input']
output_image = self.__flux.run(**model_input.model_dump())

# Upload minio
tmp_dir: str = self.__config["gradio"]["temp_dir"]
output_path = os.path.join(tmp_dir, f"{uuid.uuid4()}.png")
output_image.save(output_path)

output_url = self.__minio_storage.upload(output_path)

output = FluxOutput(output_url=output_url, output_path=output_path)

# Update db
record_info = {
"user_id": session_id,
"input": model_input.model_dump(),
"output": output.model_dump(exclude={'output_record_id'}),
}
request_id = self.__mongo_db.insert_one_request(record_info)

output.output_record_id = request_id
self.__session_infor[session_id]['output'] = output
return output_path, output_url, request_id
24 changes: 24 additions & 0 deletions gradio_demo/getting_started/core/flux_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import gc

import PIL.Image
import torch

from sdk.models.flux import Flux


class FluxWrapper:
def __init__(self, *args, **kwargs):
self._flux: Flux = None

def run(self, n_items: int = 1, *args, **kwargs) -> PIL.Image.Image:
if self._flux is None:
self._flux = Flux()

output_images = []
for i in range(n_items):
img = self._flux.run(*args, **kwargs)
output_images.append(img)

torch.cuda.empty_cache()
gc.collect()
return output_images[0]
15 changes: 15 additions & 0 deletions gradio_demo/getting_started/core/minio_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sdk.cloud_storage.minio_storage import MinioStorage
from sdk.utils.url_handler import cloud_upload


class MinioWrapper:
def __init__(self, *args, **kwargs):
self.__minio = MinioStorage(*args, **kwargs)

def upload(self, file):
try:
public_url, upload_result = cloud_upload(cloud_storage=self.__minio, local_path=file)
print(f"File uploaded to: {public_url}")
return public_url
except Exception as e:
return str(e)
58 changes: 58 additions & 0 deletions gradio_demo/getting_started/core/mongo_client_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import datetime

from bson import ObjectId

from gradio_demo.getting_started.models.user import User
from sdk.database.mongo_db import MongoDB


class MongoClientWrapper:
def __init__(
self,
database: str,
users_collection: str,
requests_collection: str,
username: str,
password: str,
endpoint: str,
):
self._mongo_client = MongoDB(username=username, password=password, endpoint=endpoint)
self._database = database
self._users_collection = users_collection
self._request_collection = requests_collection

def insert_request(self, data, *args, **kwargs):
doc = {k.label: v for k, v in data.items()}
return self.insert_one_request(doc, *args, **kwargs)

def insert_one_request(self, data: dict, *args, **kwargs):
create_time = datetime.datetime.now()
data.update({'create_time': create_time})
result = self._mongo_client.insert_one(self._database, self._request_collection, data)
inserted_id = str(result.inserted_id)
return inserted_id

def insert_email(self, email: str):
query = {"email": email}
query_result = self._mongo_client.find_one(self._database, self._users_collection, query)

# Email already existed in db
if query_result is not None:
user = User(**query_result)
else:
# Create new user
user = User(email=email)
insert_result = self._mongo_client.insert_one(
self._database, self._users_collection, document=user.model_dump()
)
print(f"Inserted new user with document id: {insert_result.inserted_id}")
return user.user_id

def get_latest_requests(self, limit: int = 50):
return self._mongo_client.find(self._database, self._request_collection, query={}).sort('create_time', -1)[
:limit
]

def find_request(self, request_id: str):
query = {"_id": ObjectId(request_id)}
return self._mongo_client.find_one(self._database, self._request_collection, query=query)
Empty file.
48 changes: 48 additions & 0 deletions gradio_demo/getting_started/models/inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pydantic import BaseModel, field_validator


class FluxInput(BaseModel):
__MIN_SIZE__ = 512
__MAX_SIZE__ = 1920
__DEFAULT_SIZE__ = 720

__MAX_NUM_INFERENCE_STEPS__ = 10
__MIN_NUM_INFERENCE_STEPS__ = 1
__DEFAULT_NUM_INFERENCE_STEPS__ = 4

__MAX_GUIDANCE_SCALE__ = 5.0
__MIN_GUIDANCE_SCALE__ = 0.5
__DEFAULT_GUIDANCE_SCALE__ = 3.5

__DEFAULT_GENERATOR_SEED__ = 12345

prompt: str = ""
width: int = __DEFAULT_SIZE__
height: int = __DEFAULT_SIZE__
num_inference_steps: int = __DEFAULT_NUM_INFERENCE_STEPS__
generator_seed: int = __DEFAULT_GENERATOR_SEED__
guidance_scale: float = __DEFAULT_GUIDANCE_SCALE__

@field_validator('width', 'height')
def validate_sizes(cls, v):
assert (
cls.__MIN_SIZE__ <= v <= cls.__MAX_SIZE__
), f"value must be in range {cls.__MIN_SIZE__} to {cls.__MAX_SIZE__}"

@field_validator('num_inference_steps')
def validate_num_inference_steps(cls, v):
assert v <= 10, 'num_inference_steps must not be greater than 10'
return v

@field_validator('guidance_scale')
def validate_guidance_scale(cls, v):
assert (
cls.__MIN_GUIDANCE_SCALE__ <= v <= cls.__MAX_GUIDANCE_SCALE__
), f'guidance_scale must be in range {cls.__MIN_GUIDANCE_SCALE__} to {cls.__MAX_GUIDANCE_SCALE__}'

@field_validator('num_inference_steps')
def validate_num_inference_steps(cls, v):
assert cls.__MIN_NUM_INFERENCE_STEPS__ <= v <= cls.__MAX_NUM_INFERENCE_STEPS__, (
f'num_inference_steps must be in range '
f'{cls.__MIN_NUM_INFERENCE_STEPS__} to {cls.__MAX_NUM_INFERENCE_STEPS__}'
)
Loading

0 comments on commit 5cb825e

Please sign in to comment.