Skip to content

Commit

Permalink
[#6] Feat: Load example
Browse files Browse the repository at this point in the history
  • Loading branch information
jiooum committed Oct 13, 2024
1 parent b2a8b53 commit 3071ff5
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 1 deletion.
23 changes: 23 additions & 0 deletions gradio_demo/getting_started/core/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from gradio_demo.getting_started.core.mongo_client_wrapper import MongoClientWrapper
from gradio_demo.getting_started.models.inputs import FluxInput
from gradio_demo.getting_started.models.outputs import FluxOutput
from sdk.utils.download import download_file


class AppController:
Expand Down Expand Up @@ -65,6 +66,21 @@ def update_guidance_scale(self, user_id: str, guidance_scale: float):
self.__session_infor[user_id]["input"].guidance_scale = guidance_scale
return []

def update_output_url(self, user_id: str, output_url: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["output"].output_url = output_url
return []

def update_output_image(self, user_id: str, output_path: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["output"].output_path = output_path
return []

def update_request_id(self, user_id: str, request_id: str):
assert user_id in self.__session_infor, f"Not found session information of user id: {user_id}"
self.__session_infor[user_id]["output"].output_record_id = request_id
return []

def btn_start_demo_clicked(self):
session_id = self.create_new_session()
return session_id, gradio.Row(visible=True), gradio.Button(visible=False), gradio.Textbox(visible=True)
Expand Down Expand Up @@ -96,6 +112,13 @@ def btn_load_examples_clicked(self):
)
return gradio.Dataset(samples=_examples, visible=True)

def load_image_url(self, request_id: str) -> str:
_request_infor = self.__mongo_db.find_request(request_id=request_id)
_output = FluxOutput(**_request_infor['output'])
if not os.path.exists(_output.output_path):
download_file(url=_output.output_url, save_path=_output.output_path)
return _output.output_path

def run(self, session_id: str):
"""
Expand Down
6 changes: 6 additions & 0 deletions gradio_demo/getting_started/core/mongo_client_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime

from bson import ObjectId

from gradio_demo.getting_started.models.user import User
from sdk.database.mongo_db import MongoDB

Expand Down Expand Up @@ -50,3 +52,7 @@ def get_latest_requests(self, limit: int = 50):
return self._mongo_client.find(self._database, self._request_collection, query={}).sort('create_time', -1)[
:limit
]

def find_request(self, request_id: str):
query = {"_id": ObjectId(request_id)}
return self._mongo_client.find_one(self._database, self._request_collection, query=query)
10 changes: 9 additions & 1 deletion gradio_demo/getting_started/ui/ui.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from cProfile import label

import gradio
from dependency_injector.wiring import Provide, inject
from streamlit.web.server import allow_cross_origin_requests

from gradio_demo.getting_started.containers.app_container import AppContainer
from gradio_demo.getting_started.core.controller import AppController
Expand Down Expand Up @@ -64,6 +67,7 @@ def make_app_ui(
with gradio.Row():
btn_generate_images = gradio.Button("Generate images")
btn_clear = gradio.ClearButton()
btn_load_examples = gradio.Button("Load examples")

with gradio.Column(show_progress=True):
gradio.Markdown(
Expand All @@ -76,7 +80,6 @@ def make_app_ui(
output_image = gradio.Image(type="filepath", format="png", show_download_button=True, interactive=False)

with gradio.Column(show_progress=True):
btn_load_examples = gradio.Button("Load examples")
examples = gradio.Examples(
examples=[
[
Expand Down Expand Up @@ -114,6 +117,11 @@ def make_app_ui(
num_inference_step.change(app_controller.update_num_inference_steps, [session_id, num_inference_step])
generator_seed.change(app_controller.update_generator_seed, [session_id, generator_seed])
guidance_scale.change(app_controller.update_guidance_scale, [session_id, guidance_scale])
request_id.change(app_controller.update_request_id, [session_id, request_id])
output_url.change(app_controller.update_output_url, [session_id, output_url])
output_image.change(app_controller.update_output_image, [session_id, output_image])

examples.dataset.click(app_controller.load_image_url, [request_id], [output_image])

btn_load_examples.click(app_controller.btn_load_examples_clicked, None, [examples.dataset])
btn_clear.add(
Expand Down
24 changes: 24 additions & 0 deletions sdk/utils/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

import requests


def download_file(url: str, save_path: str = None, save_folder: str = '', file_name: str = None):
if save_path is None:
if file_name is None:
file_name = os.path.basename(url)
max_length_filename = os.pathconf('/', 'PC_NAME_MAX')
if len(file_name) > max_length_filename:
split_text = os.path.splitext(file_name)
file_name = split_text[0][: max_length_filename - len(split_text[1])] + split_text[1]
os.makedirs(save_folder, exist_ok=True)
save_path = f'{save_folder}/{file_name}'

response = requests.get(url, timeout=20)
assert response.status_code == 200, f"Failed to reach {url} with the status code of {response.status_code}"
with open(save_path, 'wb') as file:
file.write(response.content)

# Download success
assert os.path.isfile(save_path), f"Downloaded file is not found in: {save_path}"
return save_path

0 comments on commit 3071ff5

Please sign in to comment.