Add stable-diffusion model wrapper #438

Open — wants to merge 15 commits into `main`
27 changes: 27 additions & 0 deletions examples/conversation_with_stablediffusion_model/README.md
@@ -0,0 +1,27 @@
# Conversation with Stable-diffusion model

This example will show
- How to use Stable Diffusion models in AgentScope.

In this example, you can interact in a conversational format to generate images.
Once the image is generated, the agent will respond with the local file path where the image is saved.

## Prerequisites

You need to satisfy the following requirements to run this example:

- Install Stable Diffusion Web UI by following the instructions at [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- Launch the Stable Diffusion Web UI with the `--api` argument.
- Ensure that your host can successfully access `http://127.0.0.1:7860/` (the default) or whichever host and port you chose.
- Install the latest version of AgentScope by
```bash
git clone https://github.com/modelscope/agentscope.git
cd agentscope
pip install -e .
```
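Before running the example, you can optionally verify that the webui API is reachable. A minimal sketch using only the Python standard library; the URL below assumes the default host and port, so adjust it to your setup:

```python
import urllib.error
import urllib.request


def webui_is_up(base_url: str, timeout: float = 3.0) -> bool:
    """Return True if the SD-webui API answers at ``base_url``."""
    try:
        with urllib.request.urlopen(
            f"{base_url}/sdapi/v1/options", timeout=timeout
        ):
            return True
    except (urllib.error.URLError, OSError):
        return False


if __name__ == "__main__":
    print(webui_is_up("http://127.0.0.1:7860"))
```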

## Running the Example
Run the example and type your image prompts.
```bash
python conversation_with_stablediffusion_model.py
```
54 changes: 54 additions & 0 deletions examples/conversation_with_stablediffusion_model/conversation_with_stablediffusion_model.py
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""A simple example for conversation between user and stable-diffusion agent."""
import agentscope
from agentscope.agents import DialogAgent
from agentscope.agents.user_agent import UserAgent


def main() -> None:
"""A basic conversation demo"""

agentscope.init(
model_configs=[
{
"model_type": "sd_txt2img",
"config_name": "sd",
"options": {
"sd_model_checkpoint": "xxxxxx",
"CLIP_stop_at_last_layers": 2,
}, # global settings, for detailed parameters
# please refer to 127.0.0.1:7860/docs#/default/get_config_sdapi_v1_options_get
"generate_args": {
"steps": 50,
"n_iter": 1,
"override_settings": {
"CLIP_stop_at_last_layers": 3,
# settings effective only for this conversation;
# accepts the same parameter names as the global settings
},
},
},
],
project="txt2img-Agent Conversation",
save_api_invoke=True,
)

# Init two agents
dialog_agent = DialogAgent(
name="Assistant",
sys_prompt="high definition,dreamy",  # replace with your desired image-style prompt
model_config_name="sd", # replace by your model config name
)
user_agent = UserAgent()

# start the conversation between user and assistant
msg = None
while True:
msg = user_agent(msg)
if msg.content == "exit":
break
msg = dialog_agent(msg)


if __name__ == "__main__":
main()
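The relationship between the global `options` and the per-request `override_settings` in the config above can be sketched as a plain dict merge — a simplification of what the webui applies server-side, shown only to illustrate the precedence:

```python
from typing import Optional


def effective_settings(
    global_options: dict,
    override_settings: Optional[dict] = None,
) -> dict:
    """Per-request override_settings take precedence over global options."""
    return {**global_options, **(override_settings or {})}


# Mirrors the example config: the override wins for this request only.
merged = effective_settings(
    {"sd_model_checkpoint": "xxxxxx", "CLIP_stop_at_last_layers": 2},
    {"CLIP_stop_at_last_layers": 3},
)
print(merged["CLIP_stop_at_last_layers"])  # 3
```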
4 changes: 4 additions & 0 deletions src/agentscope/models/__init__.py
@@ -41,6 +41,9 @@
from .yi_model import (
YiChatWrapper,
)
from .stablediffusion_model import (
StableDiffusionTxt2imgWrapper,
)

__all__ = [
"ModelWrapperBase",
@@ -64,6 +67,7 @@
"ZhipuAIEmbeddingWrapper",
"LiteLLMChatWrapper",
"YiChatWrapper",
"StableDiffusionTxt2imgWrapper",
Collaborator: Please change the name to StableDiffusionImageSynthesisWrapper.

Author: Since the SD API also provides an img2img interface, both txt2img and img2img represent image synthesis. I'm concerned that naming the wrapper StableDiffusionImageSynthesisWrapper to represent text-to-image functionality might lead to confusion in the future if we decide to add an image-to-image wrapper.

Collaborator: @DavdGao Please have a look into this, thanks.

]


284 changes: 284 additions & 0 deletions src/agentscope/models/stablediffusion_model.py
@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
"""Model wrapper for stable diffusion models."""
from abc import ABC
import base64
import json
import time
from typing import Any, Optional, Union, List, Sequence

import requests
from loguru import logger

from . import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_RETRY_INTERVAL
from ..message import Msg
from ..manager import FileManager
from ..utils.common import _convert_to_str


class StableDiffusionWrapperBase(ModelWrapperBase, ABC):
"""The base class for stable-diffusion model wrappers.

To use SD-webui API, please
1. First download stable-diffusion-webui from https://github.com/AUTOMATIC1111/stable-diffusion-webui and
install it with 'webui-user.bat'
2. Move your checkpoint to 'models/Stable-diffusion' folder
3. Start launch.py with the '--api' parameter to start the server
After that, you can use the SD-webui API and
query the available parameters on the http://localhost:7860/docs page
"""

model_type: str = "stable_diffusion"

def __init__(
self,
config_name: str,
host: str = "127.0.0.1:7860",
base_url: Optional[str] = None,
use_https: bool = False,
generate_args: dict = None,
headers: dict = None,
options: dict = None,
timeout: int = 30,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_interval: int = _DEFAULT_RETRY_INTERVAL,
**kwargs: Any,
) -> None:
"""
Initializes the SD-webui API client.

Args:
config_name (`str`):
The name of the model config.
host (`str`, default `"127.0.0.1:7860"`):
The host and port of the stable-diffusion webui server.
base_url (`str`, default `None`):
Base URL for the stable-diffusion webui services. If not provided, it will be generated based on `host` and `use_https`.
use_https (`bool`, default `False`):
Whether to generate the base URL with HTTPS protocol or HTTP.
generate_args (`dict`, default `None`):
The extra keyword arguments used in SD api generation,
e.g. `{"steps": 50}`.
headers (`dict`, default `None`):
HTTP request headers.
options (`dict`, default `None`):
The keyword arguments to change the webui settings
such as the model or CLIP skip; these changes will persist across sessions,
e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned", "CLIP_stop_at_last_layers": 2}`.
"""
# If base_url is not provided, construct it based on whether HTTPS is used
if base_url is None:
if use_https:
base_url = f"https://{host}"
else:
base_url = f"http://{host}"

self.base_url = base_url
self.options_url = f"{base_url}/sdapi/v1/options"
self.generate_args = generate_args or {}

# Initialize the HTTP session and update the request headers
self.session = requests.Session()
if headers:
self.session.headers.update(headers)

# Set options if provided
if options:
self._set_options(options)

# Get the default model name from the web-options
model_name = self._get_options()["sd_model_checkpoint"].split("[")[0].strip()
# Update the model name if override_settings is provided in generate_args
if self.generate_args.get("override_settings"):
model_name = self.generate_args["override_settings"].get(
"sd_model_checkpoint", model_name,
)

super().__init__(config_name=config_name, model_name=model_name)

self.timeout = timeout
self.max_retries = max_retries
self.retry_interval = retry_interval
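The URL fallback at the top of `__init__` can be isolated as a small helper for illustration (a sketch mirroring the constructor logic, not part of the PR):

```python
from typing import Optional


def resolve_base_url(
    host: str,
    use_https: bool = False,
    base_url: Optional[str] = None,
) -> str:
    """An explicit base_url wins; otherwise build one from host and scheme."""
    if base_url is not None:
        return base_url
    scheme = "https" if use_https else "http"
    return f"{scheme}://{host}"
```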

@property
def url(self) -> str:
"""SD-webui API endpoint URL"""
raise NotImplementedError()

def _get_options(self) -> dict:
response = self.session.get(url=self.options_url)
if response.status_code != 200:
logger.error(f"Failed to get options with {response.json()}")
raise RuntimeError(f"Failed to get options with {response.json()}")
return response.json()

def _set_options(self, options) -> None:
response = self.session.post(url=self.options_url, json=options)
if response.status_code != 200:
logger.error(json.dumps(options, indent=4))
raise RuntimeError(f"Failed to set options with {response.json()}")
else:
logger.info("Options set successfully")

def _invoke_model(self, payload: dict) -> dict:
"""Invoke SD webui API and record the invocation if needed"""
# step1: prepare post requests
for i in range(1, self.max_retries + 1):
response = self.session.post(url=self.url, json=payload)

if response.status_code == requests.codes.ok:
break

if i < self.max_retries:
logger.warning(
f"Failed to call the model with status code "
f"{response.status_code}, retry "
f"{i + 1}/{self.max_retries}",
)
time.sleep(i * self.retry_interval)

# step2: record model invocation
# record the model api invocation, which will be skipped if
# `FileManager.save_api_invocation` is `False`
self._save_model_invocation(
arguments=payload,
response=response.json(),
)

# step3: return the response json
if response.status_code == requests.codes.ok:
return response.json()
else:
logger.error(json.dumps({"url": self.url, "json": payload}, indent=4))
raise RuntimeError(
f"Failed to call the model with {response.json()}",
)
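The retry policy in `_invoke_model` — up to `max_retries` attempts with a linearly growing sleep of `i * retry_interval` between them — can be exercised in isolation. A sketch where `post` stands in for the HTTP call and returns a status code:

```python
import time
from typing import Callable


def post_with_retries(
    post: Callable[[], int],
    max_retries: int = 3,
    retry_interval: float = 0.0,
) -> int:
    """Return the first 200 status; raise after max_retries failed attempts."""
    status = 0
    for i in range(1, max_retries + 1):
        status = post()
        if status == 200:
            return status
        if i < max_retries:
            time.sleep(i * retry_interval)  # linear backoff
    raise RuntimeError(f"Failed after {max_retries} attempts, last status {status}")
```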

def _parse_response(self, response: dict) -> ModelResponse:
"""Parse the response json data into ModelResponse"""
return ModelResponse(raw=response)

def __call__(self, **kwargs: Any) -> ModelResponse:
payload = {
**self.generate_args,
**kwargs,
}
response = self._invoke_model(payload)
return self._parse_response(response)



class StableDiffusionTxt2imgWrapper(StableDiffusionWrapperBase):
"""Stable Diffusion txt2img API wrapper"""

model_type: str = "sd_txt2img"

@property
def url(self):
return f"{self.base_url}/sdapi/v1/txt2img"

def _parse_response(self, response: dict) -> ModelResponse:
session_parameters = response["parameters"]
size = f"{session_parameters['width']}*{session_parameters['height']}"
image_count = session_parameters["batch_size"] * session_parameters["n_iter"]

self.monitor.update_image_tokens(
model_name=self.model_name,
image_count=image_count,
resolution=size,
)

# Get image base64code as a list
images = response["images"]
b64_images = [base64.b64decode(image) for image in images]

file_manager = FileManager.get_instance()
# Return local url
image_urls = [file_manager.save_image(_) for _ in b64_images]
text = "Image saved to " + "\n".join(image_urls)
return ModelResponse(text=text, image_urls=image_urls, raw=response)

def __call__(
self,
prompt: str,
**kwargs: Any,
) -> ModelResponse:
"""
Args:
prompt (`str`):
The prompt string to generate images from.
**kwargs (`Any`):
The keyword arguments to SD-webui txt2img API, e.g.
`n_iter`, `steps`, `seed`, `width`, etc. Please refer to
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API#api-guide-by-kilvoctu
or http://localhost:7860/docs
for more detailed arguments.
Returns:
`ModelResponse`:
A list of image local urls in image_urls field and the
raw response in raw field.
"""
# step1: prepare keyword arguments
payload = {
"prompt": prompt,
**self.generate_args,
**kwargs,  # call-time kwargs take precedence, as in the base class
}

# step2: forward to generate response
response = self._invoke_model(payload)

# step3: parse the response
return self._parse_response(response)

def format(self, *args: Msg | Sequence[Msg]) -> List[dict] | str:
# This is a temporary implementation that focuses the prompt on
# single-turn image generation by preserving only the system prompt
# and the last user message. This logic may change in the future to
# support more complex conversational scenarios.
if len(args) == 0:
raise ValueError(
"At least one message should be provided. An empty message "
"list is not allowed.",
)

# Parse all information into a list of messages
input_msgs = []
for _ in args:
if _ is None:
continue
if isinstance(_, Msg):
input_msgs.append(_)
elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _):
input_msgs.extend(_)
else:
raise TypeError(
f"The input should be a Msg object or a list "
f"of Msg objects, got {type(_)}.",
)

# record user message history as a list of strings
user_messages = []
sys_prompt = None
for i, unit in enumerate(input_msgs):
if i == 0 and unit.role == "system":
# if system prompt is available, place it at the beginning
sys_prompt = _convert_to_str(unit.content)
elif unit.role == "user":
# Merge user messages into a conversation history prompt
user_messages.append(_convert_to_str(unit.content))
else:
continue

content_components = []
# Add system prompt at the beginning if provided
if sys_prompt:
content_components.append(sys_prompt)
# Add the last user message if the user message list is not empty
if len(user_messages) > 0:
content_components.append(user_messages[-1])

prompt = ",".join(content_components)

return prompt
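The flattening that `format` performs — keep the leading system prompt, keep only the last user message, comma-join them — can be reproduced standalone with a minimal stand-in for `Msg` (a sketch for illustration; `MiniMsg` and `flatten_prompt` are hypothetical names):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MiniMsg:
    """Minimal stand-in for agentscope's Msg, for illustration only."""
    role: str
    content: str


def flatten_prompt(msgs: List[MiniMsg]) -> str:
    parts = []
    if msgs and msgs[0].role == "system":
        parts.append(msgs[0].content)  # system prompt goes first
    user_contents = [m.content for m in msgs if m.role == "user"]
    if user_contents:
        parts.append(user_contents[-1])  # only the last user message survives
    return ",".join(parts)
```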