From 992695db3a6f80d4b11e200cd41a5bd61fe31ae3 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 11 Sep 2024 11:19:48 +0800 Subject: [PATCH] update stablediffusion_model.py --- .../models/stablediffusion_model.py | 247 +++++++++++------- 1 file changed, 158 insertions(+), 89 deletions(-) diff --git a/src/agentscope/models/stablediffusion_model.py b/src/agentscope/models/stablediffusion_model.py index a01b1e8f2..3b0a166a1 100644 --- a/src/agentscope/models/stablediffusion_model.py +++ b/src/agentscope/models/stablediffusion_model.py @@ -2,19 +2,25 @@ """Model wrapper for stable diffusion models.""" from abc import ABC import base64 +import json +import time from typing import Any, Optional, Union, List, Sequence +import requests +from loguru import logger + from . import ModelWrapperBase, ModelResponse +from ..constants import _DEFAULT_MAX_RETRIES +from ..constants import _DEFAULT_RETRY_INTERVAL from ..message import Msg from ..manager import FileManager -import requests from ..utils.common import _convert_to_str class StableDiffusionWrapperBase(ModelWrapperBase, ABC): """The base class for stable-diffusion model wrappers. - To use SD API, please + To use SD-webui API, please 1. First download stable-diffusion-webui from https://github.com/AUTOMATIC1111/stable-diffusion-webui and install it with 'webui-user.bat' 2. Move your checkpoint to 'models/Stable-diffusion' folder @@ -23,77 +29,176 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC): query the available parameters on the http://localhost:7860/docs page """ - model_type: str - """The type of the model wrapper, which is to identify the model wrapper - class in model configuration.""" - - options: dict - """A dict contains the options for stable-diffusion option API. - Modifications made through this parameter are persistent, meaning they will - remain in effect for subsequent generation requests until explicitly changed or reset. - e.g. 
{"sd_model_checkpoint": "Anything-V3.0-pruned", "CLIP_stop_at_last_layers": 2}""" + model_type: str = "stable_diffusion" def __init__( self, config_name: str, - options: dict = None, + host: str = "127.0.0.1:7860", + base_url: Optional[Union[str, None]] = None, + use_https: bool = False, generate_args: dict = None, - url: Optional[Union[str, None]] = None, + headers: dict = None, + options: dict = None, + timeout: int = 30, + max_retries: int = _DEFAULT_MAX_RETRIES, + retry_interval: int = _DEFAULT_RETRY_INTERVAL, **kwargs: Any, ) -> None: - """Initialize the model wrapper for SD-webui API. + """ + Initializes the SD-webui API client. Args: - options (`dict`, default `None`): + config_name (`str`): + The name of the model config. + host (`str`, default `"127.0.0.1:7860"`): + The host and port (`host:port`) of the stable-diffusion webui server. + base_url (`str`, default `None`): + Base URL for the stable-diffusion webui services. If not provided, it will be generated based on `host` and `use_https`. + use_https (`bool`, default `False`): + Whether to generate the base URL with HTTPS protocol or HTTP. + generate_args (`dict`, default `None`): + The extra keyword arguments used in SD api generation, + e.g. `{"steps": 50}`. + headers (`dict`, default `None`): + HTTP request headers. + options (`dict`, default `None`): The keyword arguments to change the webui settings such as model or CLIP skip, this changes will persist across sessions. e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned", "CLIP_stop_at_last_layers": 2}`. - generate_args (`dict`, default `None`): - The extra keyword arguments used in SD-webui api generation, - e.g. `steps`, `seed`. - url (`str`, default `None`): - The url of the SD-webui server. - Defaults to `None`, which is http://127.0.0.1:7860. 
""" - if url is None: - url = "http://127.0.0.1:7860" + # If base_url is not provided, construct it based on whether HTTPS is used + if base_url is None: + if use_https: + base_url = f"https://{host}" + else: + base_url = f"http://{host}" - self.url = url + self.base_url = base_url + self.options_url = f"{base_url}/sdapi/v1/options" self.generate_args = generate_args or {} - options_url = f"{self.url}/sdapi/v1/options" - # Get the current default model - default_model_name = ( - requests.get(options_url) - .json()["sd_model_checkpoint"] - .split("[")[0] - .strip() + # Initialize the HTTP session and update the request headers + self.session = requests.Session() + if headers: + self.session.headers.update(headers) + + # Set options if provided + if options: + self._set_options(options) + + # Get the default model name from the web-options + model_name = self._get_options()["sd_model_checkpoint"].split("[")[0].strip() + # Update the model name if override_settings is provided in generate_args + if self.generate_args.get("override_settings"): + model_name = generate_args["override_settings"].get( + "sd_model_checkpoint", model_name + ) + + super().__init__(config_name=config_name, model_name=model_name) + + self.timeout = timeout + self.max_retries = max_retries + self.retry_interval = retry_interval + + @property + def url(self): + """SD-webui API endpoint URL""" + raise NotImplementedError() + + def _get_options(self) -> dict: + response = self.session.get(url=self.options_url) + if response.status_code != 200: + logger.error(f"Failed to get options with {response.json()}") + raise RuntimeError(f"Failed to get options with {response.json()}") + return response.json() + + def _set_options(self, options) -> None: + response = self.session.post(url=self.options_url, json=options) + if response.status_code != 200: + logger.error(json.dumps(options, indent=4)) + raise RuntimeError(f"Failed to set options with {response.json()}") + else: + logger.info("Optionsset 
successfully") + + def _invoke_model(self, payload: dict) -> dict: + """Invoke SD webui API and record the invocation if needed""" + # step1: prepare post requests + for i in range(1, self.max_retries + 1): + response = self.session.post(url=self.url, json=payload) + + if response.status_code == requests.codes.ok: + break + + if i < self.max_retries: + logger.warning( + f"Failed to call the model with " + f"requests.codes == {response.status_code}, retry " + f"{i + 1}/{self.max_retries} times", + ) + time.sleep(i * self.retry_interval) + + # step2: record model invocation + # record the model api invocation, which will be skipped if + # `FileManager.save_api_invocation` is `False` + self._save_model_invocation( + arguments=payload, + response=response.json(), ) - if options is not None: - # Update webui options if needed - requests.post(options_url, json=options) - model_name = options.get("sd_model_checkpoint", default_model_name) + # step3: return the response json + if response.status_code == requests.codes.ok: + return response.json() else: - model_name = default_model_name + logger.error(json.dumps({"url": self.url, "json": payload}, indent=4)) + raise RuntimeError( + f"Failed to call the model with {response.json()}", + ) - super().__init__(config_name=config_name, model_name=model_name) + def _parse_response(self, response: dict) -> ModelResponse: + """Parse the response json data into ModelResponse""" + return ModelResponse(raw=response) + + def __call__(self, **kwargs: Any) -> ModelResponse: + payload = { + **self.generate_args, + **kwargs, + } + response = self._invoke_model(payload) + return self._parse_response(response) - def format( - self, - *args: Union[Msg, Sequence[Msg]], - ) -> Union[List[dict], str]: - raise RuntimeError( - f"Model Wrapper [{type(self).__name__}] doesn't " - f"need to format the input. 
Please try to use the " - f"model wrapper directly.", - ) class StableDiffusionTxt2imgWrapper(StableDiffusionWrapperBase): + """Stable Diffusion txt2img API wrapper""" model_type: str = "sd_txt2img" + @property + def url(self): + return f"{self.base_url}/sdapi/v1/txt2img" + + def _parse_response(self, response: dict) -> ModelResponse: + session_parameters = response["parameters"] + size = f"{session_parameters['width']}*{session_parameters['height']}" + image_count = session_parameters["batch_size"] * session_parameters["n_iter"] + + self.monitor.update_image_tokens( + model_name=self.model_name, + image_count=image_count, + resolution=size, + ) + + # Get image base64code as a list + images = response["images"] + b64_images = [base64.b64decode(image) for image in images] + + file_manager = FileManager.get_instance() + # Return local url + image_urls = [file_manager.save_image(_) for _ in b64_images] + text = "Image saved to " + "\n".join(image_urls) + return ModelResponse(text=text, image_urls=image_urls, raw=response) + def __call__( self, prompt: str, @@ -109,13 +214,11 @@ def __call__( https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API#api-guide-by-kilvoctu or http://localhost:7860/docs for more detailed arguments. - Returns: `ModelResponse`: A list of image local urls in image_urls field and the raw response in raw field. 
""" - # step1: prepare keyword arguments payload = { "prompt": prompt, @@ -124,49 +227,15 @@ def __call__( } # step2: forward to generate response - txt2img_url = f"{self.url}/sdapi/v1/txt2img" - response = requests.post(url=txt2img_url, json=payload) - - if response.status_code != requests.codes.ok: - error_msg = f" Status code: {response.status_code}," - raise RuntimeError(error_msg) - - # step3: record the model api invocation if needed - output = response.json() - self._save_model_invocation( - arguments={ - "model": self.model_name, - **payload, - }, - response=output, - ) - - # step4: update monitor accordingly - session_parameters = output["parameters"] - size = f"{session_parameters['width']}*{session_parameters['height']}" - image_count = session_parameters["batch_size"] * session_parameters["n_iter"] + response = self._invoke_model(payload) - self.monitor.update_image_tokens( - model_name=self.model_name, - image_count=image_count, - resolution=size, - ) - - # step5: return response - # Get image base64code as a list - images = output["images"] - b64_images = [base64.b64decode(image) for image in images] - - file_manager = FileManager.get_instance() - # Return local url - urls = [file_manager.save_image(_) for _ in b64_images] - text = "Image saved to " + "\n".join(urls) - return ModelResponse(text=text, image_urls=urls, raw=response) + # step3: parse the response + return self._parse_response(response) def format(self, *args: Msg | Sequence[Msg]) -> List[dict] | str: - # This is a temporary implementation to focus on the prompt - # on single-turn image generation by preserving only the system prompt and - # the last user message. This logic might change in the future to support + # This is a temporary implementation to focus on the prompt + # on single-turn image generation by preserving only the system prompt and + # the last user message. 
This logic might change in the future to support # more complex conversational scenarios if len(args) == 0: raise ValueError( @@ -204,7 +273,7 @@ def format(self, *args: Msg | Sequence[Msg]) -> List[dict] | str: content_components = [] # Add system prompt at the beginning if provided - if sys_prompt is not None: + if sys_prompt: content_components.append(sys_prompt) # Add the last user message if the user messages is not empty if len(user_messages) > 0: