add sd-webui extension #401

Merged: 8 commits, Dec 13, 2023
41 changes: 41 additions & 0 deletions onediff_sd_webui_extensions/README.md
# Stable-Diffusion-WebUI-OneDiff

- [Installation Guide](#installation-guide)
- [Extensions Usage](#extensions-usage)

## Installation Guide

It is recommended to create a Python virtual environment in advance, for example `conda create -n sd-webui python=3.10`.

```bash
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
git clone https://github.com/Oneflow-Inc/onediff.git
cd stable-diffusion-webui && git checkout 4afaaf8  # the tested git commit id
cp -r ../onediff/onediff_sd_webui_extensions extensions/

# Install all of stable-diffusion-webui's dependencies.
venv_dir=- bash webui.sh --port=8080

# Exit the webui server, then pin the components that conflict with onediff.
cd repositories/generative-models && git checkout 9d759324 && cd -
pip install -U einops==0.7.0
```

## Run the stable-diffusion-webui service

```bash
cd stable-diffusion-webui
python webui.py --port 8080
```

Access http://server:8080/ from a web browser, replacing `server` with your server's hostname or IP address.

## Extensions Usage

Type a prompt in the text box, such as `a black dog`, then click the `Generate` button in the upper-right corner to generate an image, as shown below:

![raw_webui](images/raw_webui.jpg)

To enable OneDiff extension acceleration, select `onediff_diffusion_model` under `Script` and click the `Generate` button.

![onediff_script](images/onediff_script.jpg)
Binary file added onediff_sd_webui_extensions/images/onediff_script.jpg
Binary file added onediff_sd_webui_extensions/images/raw_webui.jpg
13 changes: 13 additions & 0 deletions onediff_sd_webui_extensions/install.py
import launch


def install():
    if not launch.is_installed("oneflow"):
        print("oneflow is not installed! Installing...")
        launch.run_pip("install --pre oneflow -f https://oneflow-pro.oss-cn-beijing.aliyuncs.com/branch/community/cu118")
    if not launch.is_installed("onediff"):
        print("onediff is not installed! Installing...")
        launch.run_pip("install git+https://github.com/Oneflow-Inc/onediff.git")


install()
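For reference, webui's `launch.is_installed` check amounts to asking whether Python can resolve an import for the package. A minimal stdlib sketch of that check (the helper here is illustrative, not webui's exact implementation):

```python
import importlib.util

def is_installed(package: str) -> bool:
    # A package counts as installed if Python can locate an import spec
    # for it; this mirrors the check install() relies on above.
    return importlib.util.find_spec(package) is not None

print(is_installed("json"))         # True: part of the standard library
print(is_installed("no_such_pkg"))  # False: nothing to import
```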
205 changes: 205 additions & 0 deletions onediff_sd_webui_extensions/scripts/onediff.py
import modules.scripts as scripts
from modules import script_callbacks
import modules.shared as shared
from modules.processing import process_images

import math
import torch
import oneflow as flow
from einops import rearrange
from oneflow import nn, einsum
from sgm.modules.attention import default, CrossAttention
from sgm.modules.diffusionmodules.util import GroupNorm32
from omegaconf import OmegaConf, ListConfig
from onediff.infer_compiler.transform.builtin_transform import torch2oflow
from onediff.infer_compiler import oneflow_compile, register


@torch2oflow.register
def _(mod, verbose=False) -> ListConfig:
    converted_list = [torch2oflow(item, verbose) for item in mod]
    return OmegaConf.create(converted_list)
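The handler above extends onediff's `torch2oflow` dispatcher for `ListConfig`: convert each element, then rebuild the container. The same pattern can be sketched with the stdlib's `functools.singledispatch` (the `convert` function is a stand-in for illustration, not onediff's API):

```python
from functools import singledispatch

@singledispatch
def convert(mod, verbose=False):
    # Default handler: leaf values pass through unchanged.
    return mod

@convert.register
def _(mod: list, verbose=False):
    # Container handler: convert each item, then rebuild the container,
    # just as the ListConfig handler rebuilds via OmegaConf.create.
    return [convert(item, verbose) for item in mod]

print(convert([1, "a", [2, 3]]))  # [1, 'a', [2, 3]]
```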


"""oneflow_compiled UNetModel"""
_compiled = None


# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_hijack_optimizations.py#L142
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_hijack_optimizations.py#L221
class CrossAttentionOflow(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        backend=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.backend = backend

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        h = self.heads

        q_in = self.to_q(x)
        context = default(context, x)

        # context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
        context_k, context_v = context, context
        k_in = self.to_k(context_k)
        v_in = self.to_v(context_v)

        dtype = q_in.dtype
        # from modules import shared
        # if shared.opts.upcast_attn:
        #     q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

        # with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = flow.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        # mem_free_total = get_available_vram()
        from modules import sd_hijack_optimizations
        mem_free_total = sd_hijack_optimizations.get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have: {mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            # s2 = s1.softmax(dim=-1, dtype=q.dtype)
            s2 = s1.softmax(dim=-1)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

        r1 = r1.to(dtype)

        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
        del r1

        return self.to_out(r2)
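The slicing logic in `forward` above picks the number of chunks by doubling until each per-chunk attention matrix fits in free VRAM. Isolated as a pure function (a sketch for illustration; the function name is not part of the extension):

```python
import math

def attention_slice_steps(mem_required: float, mem_free_total: float) -> int:
    # Same doubling heuristic as in forward(): split the query sequence
    # into 2**ceil(log2(required / free)) slices when the full attention
    # matrix would not fit in free memory.
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
    return steps

print(attention_slice_steps(1.0, 8.0))   # 1: fits, no slicing needed
print(attention_slice_steps(10.0, 4.0))  # 4: 10/4 = 2.5, next power of two
```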


# https://github.com/Stability-AI/generative-models/blob/e5963321482a091a78375f3aeb2c3867562c913f/sgm/modules/diffusionmodules/wrappers.py#L24
def forward_wrapper(self, x, t, c, **kwargs):
    x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
    with torch.autocast("cuda", enabled=False):
        with flow.autocast("cuda", enabled=False):
            # Assumes the conditioning dict carries "crossattn" and "vector"
            # entries; .half() would fail on a missing (None) key.
            return self.diffusion_model(
                x.half(),
                timesteps=t.half(),
                context=c.get("crossattn", None).half(),
                y=c.get("vector", None).half(),
                **kwargs,
            )


# https://github.com/Stability-AI/generative-models/blob/059d8e9cd9c55aea1ef2ece39abf605efb8b7cc9/sgm/modules/diffusionmodules/util.py#L274
class GroupNorm32Oflow(nn.GroupNorm):
    def forward(self, x):
        # return super().forward(x.float()).type(x.dtype)
        return super().forward(x).type(x.dtype)


# https://github.com/Stability-AI/generative-models/blob/e5963321482a091a78375f3aeb2c3867562c913f/sgm/modules/diffusionmodules/openaimodel.py#L983-L984
class TimeEmbedModule(nn.Module):
    def __init__(self, time_embed):
        super().__init__()
        self._time_embed_module = time_embed

    def forward(self, t_emb):
        # Cast the timestep embedding input to half precision to match
        # the compiled model's weights.
        return self._time_embed_module(t_emb.half())


torch2oflow_class_map = {
    CrossAttention: CrossAttentionOflow,
    GroupNorm32: GroupNorm32Oflow,
}
register(package_names=["sgm"], torch2oflow_class_map=torch2oflow_class_map)
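`register` maps source classes to their OneFlow re-implementations when the module tree is converted. The core lookup is just a type-keyed dict; a schematic stand-in (not onediff's internals, class names are invented for illustration):

```python
class SrcAttention:    # stands in for sgm's CrossAttention
    pass

class OflowAttention:  # stands in for CrossAttentionOflow
    pass

class_map = {SrcAttention: OflowAttention}

def replacement_class(obj):
    # Fall back to the object's own class when no mapping is registered.
    return class_map.get(type(obj), type(obj))

print(replacement_class(SrcAttention()).__name__)  # OflowAttention
print(replacement_class(42).__name__)              # int
```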


def compile(sd_model):
    unet_model = sd_model.model.diffusion_model
    full_name = f"{unet_model.__module__}.{unet_model.__class__.__name__}"
    if full_name != "sgm.modules.diffusionmodules.openaimodel.UNetModel":
        return
    global _compiled
    _compiled = oneflow_compile(sd_model.model.diffusion_model, use_graph=True)
    time_embed_wrapper = TimeEmbedModule(_compiled._deployable_module_model.oneflow_module.time_embed)
    # https://github.com/Stability-AI/generative-models/blob/e5963321482a091a78375f3aeb2c3867562c913f/sgm/modules/diffusionmodules/openaimodel.py#L984
    setattr(_compiled._deployable_module_model.oneflow_module, "time_embed", time_embed_wrapper)
    # for refiner model
    shared.sd_model.model.diffusion_model = _compiled
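`compile` identifies the SDXL UNet by its fully qualified class name rather than with `isinstance`, so no `sgm` import is needed at this point. The check reduces to (the helper name here is illustrative):

```python
def full_class_name(obj) -> str:
    # Module path plus class name, as compile() builds it for the UNet.
    return f"{type(obj).__module__}.{type(obj).__name__}"

print(full_class_name(ValueError()))  # builtins.ValueError
```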


class Script(scripts.Script):
    def title(self):
        return "onediff_diffusion_model"

    def show(self, is_img2img):
        return not is_img2img

    def run(self, p):
        global _compiled
Collaborator: Since there is a `self` here, don't use `global`. Store `_compiled` on the `Script` object?

Contributor Author: The `compile` function assigns the compiled result to `_compiled`, and `compile` is registered with `script_callbacks.on_model_loaded`, so compilation runs when the model is loaded.

Collaborator: `self._compiled` would be more natural; a global variable is not great.

Collaborator: This should be changeable here, right?

Contributor Author: The `compile` function cannot get hold of the concrete `Script` instance, so it cannot access its attributes.

Collaborator: With `compile(shared.sd_model, self)` you would get it, right?

Collaborator: Not changed here? Or can it not be changed?

Contributor Author: I replied on Slack earlier and did not sync it here. It still cannot be changed this way, because a function registered with `script_callbacks` takes only a single `model` argument.
        if _compiled is None:
            compile(shared.sd_model)
        original = shared.sd_model.model.diffusion_model
        from sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
        orig_forward = OpenAIWrapper.forward
        if _compiled is not None:
            shared.sd_model.model.diffusion_model = _compiled
            setattr(OpenAIWrapper, "forward", forward_wrapper)
        proc = process_images(p)
        # Restore the original model and forward so later runs are unaffected.
        shared.sd_model.model.diffusion_model = original
        setattr(OpenAIWrapper, "forward", orig_forward)
        return proc
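`run` swaps in the compiled UNet and the OneFlow-aware `forward`, then restores both after `process_images`. A `try`/`finally` context manager makes that swap-and-restore pattern robust even if generation raises; a sketch of the idea (not part of the extension):

```python
from contextlib import contextmanager

@contextmanager
def patched_attr(obj, name, value):
    # Temporarily replace obj.<name>, always restoring the original,
    # mirroring the manual swap/restore in run() above.
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)

class Model:
    forward = "original"

with patched_attr(Model, "forward", "patched"):
    print(Model.forward)  # patched
print(Model.forward)      # original
```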


script_callbacks.on_model_loaded(compile)