Skip to content

Commit

Permalink
Merge pull request #9 from greenmagenta/client-handling
Browse files Browse the repository at this point in the history
Update Client handling
  • Loading branch information
gsbm authored Jun 9, 2024
2 parents 0a0a42b + eccb5eb commit b14fe3b
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions autosculptor/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from bpy.types import Operator
from .utils import ensure_gradio_installed, install_gradio

# Clients API config
client_config = {
"shap_e": "hysts/Shap-E",
"sdxl": "hysts/SDXL",
"one_2_3_45": "https://one-2-3-45-one-2-3-45.hf.space/",
"dreamgaussian": "https://jiawei011-dreamgaussian.hf.space/",
"instantmesh": "TencentARC/InstantMesh",
"triposr": "stabilityai/TripoSR"
}

class InstallDependenciesOperator(Operator):
bl_idname = "wm.install_dependencies"
bl_label = "Install Dependencies"
Expand Down Expand Up @@ -67,6 +77,9 @@ def run_pipeline(self, autosculptor_props):
image_height = autosculptor_props.image_height
api_key = autosculptor_props.api_key

if api_key == "":
api_key = None

for _ in range(batch_count):
# Get seed for generation
seed = autosculptor_props.seed
Expand Down Expand Up @@ -140,7 +153,8 @@ def generate_model(self, prompt, seed, guidance_scale, num_inference_steps, mode

def generate_shape_e_model(self, api_key, prompt, seed, guidance_scale, num_inference_steps):
from gradio_client import Client
client = Client("hysts/Shap-E", hf_token=api_key)

client = Client(client_config["shap_e"], hf_token=api_key) if api_key else Client(client_config["shap_e"])
result = client.predict(
prompt=prompt,
seed=seed,
Expand All @@ -152,7 +166,7 @@ def generate_shape_e_model(self, api_key, prompt, seed, guidance_scale, num_infe

def generate_sdxl_shape_e_model(self, api_key, prompt, seed, guidance_scale, num_inference_steps, image_width, image_height):
from gradio_client import Client, handle_file
client1 = Client("hysts/SDXL", hf_token=api_key)
client1 = Client(client_config["sdxl"], hf_token=api_key) if api_key else Client(client_config["sdxl"])
image = client1.predict(
prompt=prompt,
negative_prompt="",
Expand All @@ -167,13 +181,13 @@ def generate_sdxl_shape_e_model(self, api_key, prompt, seed, guidance_scale, num
)
image_path = image

client2 = Client("https://one-2-3-45-one-2-3-45.hf.space/", hf_token=api_key)
client2 = Client(client_config["one_2_3_45"], hf_token=api_key) if api_key else Client(client_config["one_2_3_45"])
segmented_img_filepath = client2.predict(
image_path,
api_name="/preprocess"
)
)

client3 = Client("hysts/Shap-E", hf_token=api_key)
client3 = Client(client_config["shap_e"], hf_token=api_key) if api_key else Client(client_config["shap_e"])
result = client3.predict(
image=handle_file(segmented_img_filepath),
seed=seed,
Expand All @@ -185,7 +199,7 @@ def generate_sdxl_shape_e_model(self, api_key, prompt, seed, guidance_scale, num

def generate_sdxl_dreamgaussian_model(self, api_key, prompt, seed, guidance_scale, num_inference_steps, image_width, image_height):
from gradio_client import Client, handle_file
client1 = Client("hysts/SDXL", hf_token=api_key)
client1 = Client(client_config["sdxl"], hf_token=api_key) if api_key else Client(client_config["sdxl"])
image = client1.predict(
prompt=prompt,
negative_prompt="",
Expand All @@ -200,17 +214,17 @@ def generate_sdxl_dreamgaussian_model(self, api_key, prompt, seed, guidance_scal
)
image_path = image

client2 = Client("https://one-2-3-45-one-2-3-45.hf.space/", hf_token=api_key)
client2 = Client(client_config["one_2_3_45"], hf_token=api_key) if api_key else Client(client_config["one_2_3_45"])
elevation_angle_deg = client2.predict(
image_path,
True,
api_name="/estimate_elevation"
)
)

if elevation_angle_deg < -90 or elevation_angle_deg > 90:
elevation_angle_deg = 0

client3 = Client("https://jiawei011-dreamgaussian.hf.space/", hf_token=api_key)
client3 = Client(client_config["dreamgaussian"], hf_token=api_key) if api_key else Client(client_config["dreamgaussian"])
result = client3.predict(
image_path,
True,
Expand All @@ -221,7 +235,7 @@ def generate_sdxl_dreamgaussian_model(self, api_key, prompt, seed, guidance_scal

def generate_sdxl_instantmesh_model(self, api_key, prompt, seed, guidance_scale, num_inference_steps, image_width, image_height):
from gradio_client import Client, handle_file
client1 = Client("hysts/SDXL", hf_token=api_key)
client1 = Client(client_config["sdxl"], hf_token=api_key) if api_key else Client(client_config["sdxl"])
image = client1.predict(
prompt=prompt,
negative_prompt="",
Expand All @@ -236,7 +250,7 @@ def generate_sdxl_instantmesh_model(self, api_key, prompt, seed, guidance_scale,
)
image_path = image

client2 = Client("TencentARC/InstantMesh", hf_token=api_key)
client2 = Client(client_config["instantmesh"], hf_token=api_key) if api_key else Client(client_config["instantmesh"])
processed_image = client2.predict(
input_image=handle_file(image_path),
do_remove_background=True,
Expand All @@ -257,7 +271,7 @@ def generate_sdxl_instantmesh_model(self, api_key, prompt, seed, guidance_scale,

def generate_sdxl_triposr_model(self, api_key, prompt, seed, guidance_scale, num_inference_steps, image_width, image_height):
from gradio_client import Client, handle_file
client1 = Client("hysts/SDXL", hf_token=api_key)
client1 = Client(client_config["sdxl"], hf_token=api_key) if api_key else Client(client_config["sdxl"])
image = client1.predict(
prompt=prompt,
negative_prompt="",
Expand All @@ -272,7 +286,7 @@ def generate_sdxl_triposr_model(self, api_key, prompt, seed, guidance_scale, num
)
image_path = image

client2 = Client("stabilityai/TripoSR", hf_token=api_key)
client2 = Client(client_config["triposr"], hf_token=api_key) if api_key else Client(client_config["triposr"])
processed_image = client2.predict(
handle_file(image_path),
True,
Expand Down

0 comments on commit b14fe3b

Please sign in to comment.