Skip to content

Commit

Permalink
FEAT: Support sdxl-turbo (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Dec 26, 2023
1 parent c6b133d commit fa54e1b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 6 additions & 0 deletions xinference/model/image/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
"model_id": "stabilityai/sd-turbo",
"model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c"
},
{
"model_name": "sdxl-turbo",
"model_family": "stable_diffusion",
"model_id": "stabilityai/sdxl-turbo",
"model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b"
},
{
"model_name": "stable-diffusion-v1.5",
"model_family": "stable_diffusion",
Expand Down
8 changes: 6 additions & 2 deletions xinference/model/image/tests/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,19 @@ def test_restful_api_for_image_with_mlsd_controlnet(setup):
logger.info("test result %s", r)


def test_restful_api_for_sd_turbo(setup):
@pytest.mark.parametrize("model_name", ["sd-turbo", "sdxl-turbo"])
def test_restful_api_for_sd_turbo(setup, model_name):
if model_name == "sdxl-turbo":
pytest.skip("sdxl-turbo cost too many resources.")

endpoint, _ = setup
from ....client import Client

client = Client(endpoint)

model_uid = client.launch_model(
model_uid="my_controlnet",
model_name="sd-turbo",
model_name=model_name,
model_type="image",
)
model = client.get_model(model_uid)
Expand Down

0 comments on commit fa54e1b

Please sign in to comment.