-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
jiooum
committed
Oct 6, 2024
1 parent
b28ef82
commit a9b048e
Showing
25 changed files
with
397 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
[flake8] | ||
max-line-length = 120 | ||
extend-ignore = E203, F401, E501, E402, E714 | ||
extend-ignore = E203, F401, E501, E402, E714, F811 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from gradio_demo.getting_started.containers.app_container import AppContainer | ||
from gradio_demo.getting_started.ui.ui import make_app_ui | ||
|
||
|
||
if __name__ == "__main__": | ||
container = AppContainer() | ||
container.config.from_yaml( | ||
"/mnt/data/0.Workspace/0.SourceCode/project_zero/gradio_demo/getting_started/configs.yaml" | ||
) | ||
container.wire(modules=["gradio_demo.getting_started.ui.ui"]) | ||
|
||
# Make ui | ||
ui = make_app_ui() | ||
# os.environ['GRADIO_TEMP_DIR'] = container.config.gradio.temp_dir() | ||
ui.launch( | ||
server_name=container.config.gradio.server_name(), | ||
server_port=container.config.gradio.server_port(), | ||
share=container.config.gradio.share(), | ||
) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from dependency_injector import containers, providers | ||
|
||
from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper | ||
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper | ||
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper | ||
|
||
|
||
class AppContainer(containers.DeclarativeContainer): | ||
config = providers.Configuration() | ||
|
||
minio_storage = providers.Singleton( | ||
MinioWrapper, | ||
endpoint=config.minio_storage.endpoint, | ||
access_key=config.minio_storage.access_key, | ||
secret_key=config.minio_storage.secret_key, | ||
bucket_name=config.minio_storage.bucket_name, | ||
secure=config.minio_storage.secure, | ||
) | ||
mongo_db = providers.Singleton( | ||
MongoClientWrapper, | ||
endpoint=config.mongo_db.endpoint, | ||
username=config.mongo_db.username, | ||
password=config.mongo_db.password, | ||
database=config.mongo_db.database, | ||
collection=config.mongo_db.collection, | ||
) | ||
flux = providers.Singleton(FluxWrapper) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import gc | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from sdk.models.flux import Flux | ||
|
||
|
||
class FluxWrapper: | ||
def __init__(self, *args, **kwargs): | ||
self._flux: Flux = None | ||
|
||
def run(self, n_items: int = 1, *args, **kwargs): | ||
if self._flux is None: | ||
self._flux = Flux() | ||
|
||
output_images = [] | ||
for i in range(n_items): | ||
img = self._flux.run(*args, **kwargs) | ||
output_images.append(img) | ||
torch.cuda.empty_cache() | ||
gc.collect() | ||
return output_images[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from sdk.cloud_storage.minio_storage import MinioStorage | ||
from sdk.utils.url_handler import cloud_upload | ||
|
||
|
||
class MinioWrapper: | ||
def __init__(self, *args, **kwargs): | ||
self.__minio = MinioStorage(*args, **kwargs) | ||
|
||
def upload(self, file): | ||
try: | ||
public_url, upload_result = cloud_upload(cloud_storage=self.__minio, local_path=file) | ||
print(f"File uploaded to: {public_url}") | ||
return public_url | ||
except Exception as e: | ||
return str(e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import datetime | ||
import uuid | ||
|
||
from gradio_client.documentation import document | ||
|
||
from sdk.database.mongo_db import MongoDB | ||
|
||
|
||
class MongoClientWrapper: | ||
def __init__( | ||
self, | ||
database: str, | ||
collection: str, | ||
username: str, | ||
password: str, | ||
endpoint: str, | ||
): | ||
self._mongo_client = MongoDB(username=username, password=password, endpoint=endpoint) | ||
self._database = database | ||
self._collection = collection | ||
|
||
def create(self, data, *args, **kwargs): | ||
doc = {k.label: v for k, v in data.items()} | ||
create_time = datetime.datetime.now() | ||
doc.update({'create_time': create_time}) | ||
|
||
result = self._mongo_client.insert_one(self._database, self._collection, doc) | ||
inserted_id = str(result.inserted_id) | ||
print(f"Created request_id: {inserted_id}!") | ||
return inserted_id |
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from pydantic import BaseModel, field_validator | ||
|
||
|
||
class FluxInput(BaseModel): | ||
prompt: str | ||
width: int | ||
height: int | ||
num_inference_steps: int | ||
generator_seed: int | ||
guidance_scale: float | ||
|
||
__MIN_SIZE__ = 512 | ||
__MAX_SIZE__ = 1920 | ||
__DEFAULT_SIZE__ = 720 | ||
|
||
__MAX_NUM_INFERENCE_STEPS__ = 10 | ||
__MIN_NUM_INFERENCE_STEPS__ = 1 | ||
__DEFAULT_NUM_INFERENCE_STEPS__ = 4 | ||
|
||
__MAX_GUIDANCE_SCALE__ = 5.0 | ||
__MIN_GUIDANCE_SCALE__ = 0.5 | ||
__DEFAULT_GUIDANCE_SCALE__ = 3.5 | ||
|
||
__DEFAULT_GENERATOR_SEED__ = 12345 | ||
|
||
@field_validator('width', 'height') | ||
def validate_sizes(cls, v): | ||
assert ( | ||
cls.__MIN_SIZE__ <= v <= cls.__MAX_SIZE__ | ||
), f"value must be in range {cls.__MIN_SIZE__} to {cls.__MAX_SIZE__}" | ||
|
||
@field_validator('num_inference_steps') | ||
def validate_num_inference_steps(cls, v): | ||
assert v <= 10, 'num_inference_steps must not be greater than 10' | ||
return v | ||
|
||
@field_validator('guidance_scale') | ||
def validate_guidance_scale(cls, v): | ||
assert ( | ||
cls.__MAX_GUIDANCE_SCALE__ <= v <= cls.__MIN_GUIDANCE_SCALE__ | ||
), f'guidance_scale must be in range {cls.__MIN_GUIDANCE_SCALE__} to {cls.__MAX_GUIDANCE_SCALE__}' | ||
|
||
@field_validator('num_inference_steps') | ||
def validate_num_inference_steps(cls, v): | ||
assert cls.__MIN_NUM_INFERENCE_STEPS__ <= v <= cls.__MAX_NUM_INFERENCE_STEPS__, ( | ||
f'num_inference_steps must be in range ' | ||
f'{cls.__MIN_NUM_INFERENCE_STEPS__} to {cls.__MAX_NUM_INFERENCE_STEPS__}' | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
torch~=2.4.1 | ||
torchvision | ||
einops | ||
fire >= 0.6.0 | ||
huggingface-hub | ||
safetensors | ||
sentencepiece | ||
transformers~=4.44.2 | ||
tokenizers | ||
protobuf | ||
requests | ||
invisible-watermark | ||
optimum-quanto | ||
git+https://github.com/huggingface/diffusers.git | ||
dependency-injector~=4.42.0 | ||
minio~=7.2.9 | ||
gradio~=4.44.1 | ||
pydantic~=2.9.2 | ||
wget~=3.2 | ||
exceptiongroup~=1.2.2 | ||
optimum~=1.22.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .ui import make_app_ui |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import gradio | ||
from dependency_injector.providers import Configuration | ||
from dependency_injector.wiring import Provide, inject | ||
|
||
from gradio_demo.getting_started.containers.app_container import AppContainer | ||
from gradio_demo.getting_started.core.flux_wrapper import FluxWrapper | ||
from gradio_demo.getting_started.core.minio_wrapper import MinioWrapper | ||
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper | ||
from gradio_demo.getting_started.models.inputs import FluxInput | ||
|
||
|
||
@inject | ||
def make_app_ui( | ||
minio_storage: MinioWrapper = Provide[AppContainer.minio_storage], | ||
flux: FluxWrapper = Provide[AppContainer.flux], | ||
mongo_db: MongoClientWrapper = Provide[AppContainer.mongo_db], | ||
config: Configuration = Provide[AppContainer.config], | ||
): | ||
with gradio.Blocks() as demo: | ||
with gradio.Row(show_progress=False): | ||
with gradio.Column(show_progress=False): | ||
gradio.Markdown( | ||
""" | ||
# Inputs | ||
""" | ||
) | ||
input_prompt = gradio.Textbox(label="Input prompt", interactive=True) | ||
image_width = gradio.Slider( | ||
label="Image width", | ||
minimum=FluxInput.__MIN_SIZE__, | ||
maximum=FluxInput.__MAX_SIZE__, | ||
value=FluxInput.__DEFAULT_SIZE__, | ||
step=8, | ||
interactive=True, | ||
) | ||
image_height = gradio.Slider( | ||
label="Image height", | ||
minimum=FluxInput.__MIN_SIZE__, | ||
maximum=FluxInput.__MAX_SIZE__, | ||
value=FluxInput.__DEFAULT_SIZE__, | ||
step=8, | ||
interactive=True, | ||
) | ||
num_inference_step = gradio.Slider( | ||
label="Number of inference steps", | ||
minimum=FluxInput.__MIN_NUM_INFERENCE_STEPS__, | ||
maximum=FluxInput.__MAX_NUM_INFERENCE_STEPS__, | ||
value=FluxInput.__DEFAULT_NUM_INFERENCE_STEPS__, | ||
step=1, | ||
interactive=True, | ||
) | ||
generator_seed = gradio.Number( | ||
label="Generator seed", value=FluxInput.__DEFAULT_GENERATOR_SEED__, interactive=True | ||
) | ||
guidance_scale = gradio.Slider( | ||
label="Guidance scale", | ||
minimum=FluxInput.__MIN_GUIDANCE_SCALE__, | ||
maximum=FluxInput.__MAX_GUIDANCE_SCALE__, | ||
value=FluxInput.__DEFAULT_GUIDANCE_SCALE__, | ||
interactive=True, | ||
) | ||
n_items = gradio.Number(value=1, visible=False) | ||
|
||
btn_generate_images = gradio.Button("Generate images") | ||
|
||
with gradio.Column(show_progress=True): | ||
gradio.Markdown( | ||
""" | ||
# Outputs | ||
""" | ||
) | ||
request_id = gradio.Textbox(label="Request ID") | ||
output_url = gradio.Textbox(label="Output url", interactive=False, visible=True) | ||
output_image = gradio.Image(type="filepath", format="png", show_download_button=True, interactive=False) | ||
btn_generate_images.click( | ||
flux.run, | ||
[n_items, input_prompt, image_width, image_height, num_inference_step, guidance_scale, generator_seed], | ||
[output_image], | ||
) | ||
output_image.change(minio_storage.upload, [output_image], [output_url]) | ||
output_url.change( | ||
mongo_db.create, | ||
inputs={ | ||
input_prompt, | ||
image_width, | ||
image_height, | ||
num_inference_step, | ||
guidance_scale, | ||
generator_seed, | ||
output_url, | ||
}, | ||
outputs=[request_id], | ||
) | ||
return demo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import urllib.parse | ||
|
||
import pymongo | ||
|
||
|
||
class MongoDB: | ||
def __init__(self, username: str, password: str, endpoint: str): | ||
mongodb_uri = f"mongodb://{username}:{urllib.parse.quote(password)}@{endpoint}" | ||
self._client = pymongo.MongoClient(mongodb_uri, authSource="admin", connect=True) | ||
|
||
def insert_one(self, database: str, collection: str, document: dict): | ||
collection = self._client.get_database(database).get_collection(collection) | ||
return collection.insert_one(document) | ||
|
||
def update_one(self, database: str, collection: str, document_id: str, document: dict): | ||
collection = self._client.get_database(database).get_collection(collection) | ||
collection.update_one( | ||
filter={"_id": document_id}, | ||
update={"$set": document}, | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
mongodb = MongoDB(username="admin", password="mongodb123", endpoint="localhost:27017") | ||
mongodb.insert_one(database="app", collection="request", document={"msg": "test"}) |
Empty file.
Oops, something went wrong.