-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from Jiooum102/6-getting-started-with-gradio
Getting started with gradio
- Loading branch information
Showing
32 changed files
with
788 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[flake8] | ||
max-line-length = 120 | ||
extend-ignore = E203, F401, E501, E402, E714, F811 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from gradio_demo.getting_started.containers.app_container import AppContainer | ||
from gradio_demo.getting_started.ui.ui import make_app_ui | ||
|
||
|
||
if __name__ == "__main__": | ||
container = AppContainer() | ||
container.config.from_yaml( | ||
"/mnt/data/0.Workspace/0.SourceCode/project_zero/gradio_demo/getting_started/configs.yaml" | ||
) | ||
container.wire(modules=["gradio_demo.getting_started.ui.ui"]) | ||
|
||
# Make ui | ||
ui = make_app_ui() | ||
# os.environ['GRADIO_TEMP_DIR'] = container.config.gradio.temp_dir() | ||
ui.launch( | ||
server_name=container.config.gradio.server_name(), | ||
server_port=container.config.gradio.server_port(), | ||
share=container.config.gradio.share(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
gradio: | ||
server_name: "0.0.0.0" | ||
server_port: 7001 | ||
share: false | ||
temp_dir: '/mnt/data/0.Workspace/0.SourceCode/project_zero/data/temp' | ||
|
||
minio_storage: | ||
endpoint: "192.168.194.2:9000" | ||
access_key: "jKYA67p9LclmAhr30cDO" | ||
secret_key: "ovN1SNMTpIJxRKlps71JZBeypECoFP6LaFstXC3S" | ||
bucket_name: "public" | ||
secure: false | ||
|
||
mongo_db: | ||
endpoint: "192.168.194.2:27017" | ||
username: "admin" | ||
password: "mongodb123" | ||
database: "app" | ||
requests_collection: "requests" | ||
users_collection: "users" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from dependency_injector import containers, providers | ||
|
||
from gradio_demo.getting_started.core.controller import AppController | ||
from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper | ||
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper | ||
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper | ||
|
||
|
||
class AppContainer(containers.DeclarativeContainer): | ||
config = providers.Configuration() | ||
|
||
minio_storage = providers.Factory( | ||
MinioWrapper, | ||
endpoint=config.minio_storage.endpoint, | ||
access_key=config.minio_storage.access_key, | ||
secret_key=config.minio_storage.secret_key, | ||
bucket_name=config.minio_storage.bucket_name, | ||
secure=config.minio_storage.secure, | ||
) | ||
mongo_db = providers.Singleton( | ||
MongoClientWrapper, | ||
endpoint=config.mongo_db.endpoint, | ||
username=config.mongo_db.username, | ||
password=config.mongo_db.password, | ||
database=config.mongo_db.database, | ||
users_collection=config.mongo_db.users_collection, | ||
requests_collection=config.mongo_db.requests_collection, | ||
) | ||
flux = providers.Singleton(FluxWrapper) | ||
|
||
app_controller = providers.Singleton( | ||
AppController, | ||
minio_storage=minio_storage, | ||
flux=flux, | ||
mongo_db=mongo_db, | ||
config=config, | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import os | ||
import uuid | ||
from typing import Dict, Union | ||
|
||
import gradio | ||
|
||
from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper | ||
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper | ||
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper | ||
from gradio_demo.getting_started.models.inputs import FluxInput | ||
from gradio_demo.getting_started.models.outputs import FluxOutput | ||
from sdk.utils.download import download_file | ||
|
||
|
||
class AppController: | ||
def __init__( | ||
self, | ||
minio_storage: MinioWrapper, | ||
flux: FluxWrapper, | ||
mongo_db: MongoClientWrapper, | ||
config: Dict, | ||
): | ||
self.__minio_storage = minio_storage | ||
self.__flux = flux | ||
self.__mongo_db = mongo_db | ||
self.__config = config | ||
|
||
self.__session_infor: Dict[str, Dict[str, Union[FluxInput, FluxOutput]]] = {} | ||
|
||
def create_new_session(self): | ||
session_id = str(uuid.uuid4()) | ||
self.__session_infor[session_id] = { | ||
'input': FluxInput(), | ||
'output': FluxOutput(), | ||
} | ||
return session_id | ||
|
||
def update_input_prompt(self, user_id: str, new_input_prompt: str): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
|
||
self.__session_infor[user_id]['input'].prompt = new_input_prompt | ||
return [] | ||
|
||
def update_width(self, user_id: str, width: int): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]['input'].width = width | ||
return [] | ||
|
||
def update_height(self, user_id: str, height: int): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]["input"].height = height | ||
return [] | ||
|
||
def update_num_inference_steps(self, user_id: str, num_inference_steps: int): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]["input"].num_inference_steps = num_inference_steps | ||
return [] | ||
|
||
def update_generator_seed(self, user_id: str, generator_seed: int): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]['input'].generator_seed = generator_seed | ||
return [] | ||
|
||
def update_guidance_scale(self, user_id: str, guidance_scale: float): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]["input"].guidance_scale = guidance_scale | ||
return [] | ||
|
||
def update_output_url(self, user_id: str, output_url: str): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]["output"].output_url = output_url | ||
return [] | ||
|
||
def update_output_image(self, user_id: str, output_path: str): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]["output"].output_path = output_path | ||
return [] | ||
|
||
def update_request_id(self, user_id: str, request_id: str): | ||
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}" | ||
self.__session_infor[user_id]["output"].output_record_id = request_id | ||
return [] | ||
|
||
def btn_start_demo_clicked(self): | ||
session_id = self.create_new_session() | ||
return session_id, gradio.Row(visible=True), gradio.Button(visible=False), gradio.Textbox(visible=True) | ||
|
||
def get_examples(self, limit: int = 5): | ||
records = self.__mongo_db.get_latest_requests(limit=limit) | ||
_examples = [] | ||
for record in records: | ||
request_id = str(record["_id"]) | ||
_input = FluxInput() | ||
_input.prompt = record['input']['prompt'] | ||
_input.width = record['input']['width'] | ||
_input.height = record['input']['height'] | ||
_input.num_inference_steps = record['input']['num_inference_steps'] | ||
_input.generator_seed = record['input']['generator_seed'] | ||
_input.guidance_scale = record['input']['guidance_scale'] | ||
_output = FluxOutput(**record["output"]) | ||
_examples.append( | ||
[ | ||
request_id, | ||
_input.prompt, | ||
_output.output_url, | ||
_input.width, | ||
_input.height, | ||
_input.num_inference_steps, | ||
_input.generator_seed, | ||
_input.guidance_scale, | ||
] | ||
) | ||
return _examples | ||
|
||
def btn_load_examples_clicked(self): | ||
_examples = self.get_examples(limit=50) | ||
return gradio.Dataset(samples=_examples, visible=True) | ||
|
||
def load_image_url(self, request_id: str) -> str: | ||
_request_infor = self.__mongo_db.find_request(request_id=request_id) | ||
_output = FluxOutput(**_request_infor['output']) | ||
if not os.path.exists(_output.output_path): | ||
download_file(url=_output.output_url, save_path=_output.output_path) | ||
return _output.output_path | ||
|
||
def run(self, session_id: str): | ||
""" | ||
:param session_id: User ID | ||
:return: | ||
""" | ||
|
||
# Run flux | ||
model_input = self.__session_infor[session_id]['input'] | ||
output_image = self.__flux.run(**model_input.model_dump()) | ||
|
||
# Upload minio | ||
tmp_dir: str = self.__config["gradio"]["temp_dir"] | ||
output_path = os.path.join(tmp_dir, f"{uuid.uuid4()}.png") | ||
output_image.save(output_path) | ||
|
||
output_url = self.__minio_storage.upload(output_path) | ||
|
||
output = FluxOutput(output_url=output_url, output_path=output_path) | ||
|
||
# Update db | ||
record_info = { | ||
"user_id": session_id, | ||
"input": model_input.model_dump(), | ||
"output": output.model_dump(exclude={'output_record_id'}), | ||
} | ||
request_id = self.__mongo_db.insert_one_request(record_info) | ||
|
||
output.output_record_id = request_id | ||
self.__session_infor[session_id]['output'] = output | ||
return output_path, output_url, request_id |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import gc | ||
|
||
import PIL.Image | ||
import torch | ||
|
||
from sdk.models.flux import Flux | ||
|
||
|
||
class FluxWrapper: | ||
def __init__(self, *args, **kwargs): | ||
self._flux: Flux = None | ||
|
||
def run(self, n_items: int = 1, *args, **kwargs) -> PIL.Image.Image: | ||
if self._flux is None: | ||
self._flux = Flux() | ||
|
||
output_images = [] | ||
for i in range(n_items): | ||
img = self._flux.run(*args, **kwargs) | ||
output_images.append(img) | ||
|
||
torch.cuda.empty_cache() | ||
gc.collect() | ||
return output_images[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from sdk.cloud_storage.minio_storage import MinioStorage | ||
from sdk.utils.url_handler import cloud_upload | ||
|
||
|
||
class MinioWrapper: | ||
def __init__(self, *args, **kwargs): | ||
self.__minio = MinioStorage(*args, **kwargs) | ||
|
||
def upload(self, file): | ||
try: | ||
public_url, upload_result = cloud_upload(cloud_storage=self.__minio, local_path=file) | ||
print(f"File uploaded to: {public_url}") | ||
return public_url | ||
except Exception as e: | ||
return str(e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import datetime | ||
|
||
from bson import ObjectId | ||
|
||
from gradio_demo.getting_started.models.user import User | ||
from sdk.database.mongo_db import MongoDB | ||
|
||
|
||
class MongoClientWrapper: | ||
def __init__( | ||
self, | ||
database: str, | ||
users_collection: str, | ||
requests_collection: str, | ||
username: str, | ||
password: str, | ||
endpoint: str, | ||
): | ||
self._mongo_client = MongoDB(username=username, password=password, endpoint=endpoint) | ||
self._database = database | ||
self._users_collection = users_collection | ||
self._request_collection = requests_collection | ||
|
||
def insert_request(self, data, *args, **kwargs): | ||
doc = {k.label: v for k, v in data.items()} | ||
return self.insert_one_request(doc, *args, **kwargs) | ||
|
||
def insert_one_request(self, data: dict, *args, **kwargs): | ||
create_time = datetime.datetime.now() | ||
data.update({'create_time': create_time}) | ||
result = self._mongo_client.insert_one(self._database, self._request_collection, data) | ||
inserted_id = str(result.inserted_id) | ||
return inserted_id | ||
|
||
def insert_email(self, email: str): | ||
query = {"email": email} | ||
query_result = self._mongo_client.find_one(self._database, self._users_collection, query) | ||
|
||
# Email already existed in db | ||
if query_result is not None: | ||
user = User(**query_result) | ||
else: | ||
# Create new user | ||
user = User(email=email) | ||
insert_result = self._mongo_client.insert_one( | ||
self._database, self._users_collection, document=user.model_dump() | ||
) | ||
print(f"Inserted new user with document id: {insert_result.inserted_id}") | ||
return user.user_id | ||
|
||
def get_latest_requests(self, limit: int = 50): | ||
return self._mongo_client.find(self._database, self._request_collection, query={}).sort('create_time', -1)[ | ||
:limit | ||
] | ||
|
||
def find_request(self, request_id: str): | ||
query = {"_id": ObjectId(request_id)} | ||
return self._mongo_client.find_one(self._database, self._request_collection, query=query) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from pydantic import BaseModel, field_validator | ||
|
||
|
||
class FluxInput(BaseModel): | ||
__MIN_SIZE__ = 512 | ||
__MAX_SIZE__ = 1920 | ||
__DEFAULT_SIZE__ = 720 | ||
|
||
__MAX_NUM_INFERENCE_STEPS__ = 10 | ||
__MIN_NUM_INFERENCE_STEPS__ = 1 | ||
__DEFAULT_NUM_INFERENCE_STEPS__ = 4 | ||
|
||
__MAX_GUIDANCE_SCALE__ = 5.0 | ||
__MIN_GUIDANCE_SCALE__ = 0.5 | ||
__DEFAULT_GUIDANCE_SCALE__ = 3.5 | ||
|
||
__DEFAULT_GENERATOR_SEED__ = 12345 | ||
|
||
prompt: str = "" | ||
width: int = __DEFAULT_SIZE__ | ||
height: int = __DEFAULT_SIZE__ | ||
num_inference_steps: int = __DEFAULT_NUM_INFERENCE_STEPS__ | ||
generator_seed: int = __DEFAULT_GENERATOR_SEED__ | ||
guidance_scale: float = __DEFAULT_GUIDANCE_SCALE__ | ||
|
||
@field_validator('width', 'height') | ||
def validate_sizes(cls, v): | ||
assert ( | ||
cls.__MIN_SIZE__ <= v <= cls.__MAX_SIZE__ | ||
), f"value must be in range {cls.__MIN_SIZE__} to {cls.__MAX_SIZE__}" | ||
|
||
@field_validator('num_inference_steps') | ||
def validate_num_inference_steps(cls, v): | ||
assert v <= 10, 'num_inference_steps must not be greater than 10' | ||
return v | ||
|
||
@field_validator('guidance_scale') | ||
def validate_guidance_scale(cls, v): | ||
assert ( | ||
cls.__MIN_GUIDANCE_SCALE__ <= v <= cls.__MAX_GUIDANCE_SCALE__ | ||
), f'guidance_scale must be in range {cls.__MIN_GUIDANCE_SCALE__} to {cls.__MAX_GUIDANCE_SCALE__}' | ||
|
||
@field_validator('num_inference_steps') | ||
def validate_num_inference_steps(cls, v): | ||
assert cls.__MIN_NUM_INFERENCE_STEPS__ <= v <= cls.__MAX_NUM_INFERENCE_STEPS__, ( | ||
f'num_inference_steps must be in range ' | ||
f'{cls.__MIN_NUM_INFERENCE_STEPS__} to {cls.__MAX_NUM_INFERENCE_STEPS__}' | ||
) |
Oops, something went wrong.