optim: VRAM optimisation via offloading
ref:
ZHO-ZHO-ZHO#87

Signed-off-by: Fu Lin <river@vvl.me>
time-river committed Mar 23, 2024
1 parent cfdbfc7 commit 87d99ce
Showing 3 changed files with 180 additions and 6 deletions.
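
The commit threads a new `vram_optimisation` option from the ComfyUI node down to the pipeline's `__call__`, where it selects one of diffusers' two CPU-offloading modes. The sketch below is illustrative only (the checkpoint names and the plain SDXL ControlNet pipeline stand in for the InstantID pipeline); it shows which diffusers calls the two levels correspond to.

```python
# Illustrative sketch of the two offloading levels, not code from this commit.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

level = "level_1"  # mirrors the node option: 'off' | 'level_1' | 'level_2'
if level == "level_1":
    # Whole sub-models (UNet, text encoders, VAE) live on the CPU and are moved
    # to the GPU one at a time while they run: good savings, small slowdown.
    pipe.enable_model_cpu_offload()
elif level == "level_2":
    # accelerate offloads at the leaf-module level: minimal VRAM, much slower.
    pipe.enable_sequential_cpu_offload()
else:
    # 'off': keep everything resident on the GPU.
    pipe.to("cuda")
```
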
6 changes: 4 additions & 2 deletions InstantIDNode.py
@@ -237,6 +237,7 @@ def INPUT_TYPES(cls):
"guidance_scale": ("FLOAT", {"default": 5, "min": 0, "max": 10, "display": "slider"}),
"enhance_face_region": ("BOOLEAN", {"default": True}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"vram_optimisation": (['off', 'level_1', 'level_2'], {"default": "level_1"}),
},
"optional": {
"pose_image_optional": ("IMAGE",),
@@ -247,7 +248,7 @@ def INPUT_TYPES(cls):
FUNCTION = "id_generate_image"
CATEGORY = "📷InstantID"

def id_generate_image(self, insightface, positive, negative, face_image, pipe, ip_adapter_scale, controlnet_conditioning_scale, steps, guidance_scale, seed, enhance_face_region, pose_image_optional=None):
def id_generate_image(self, insightface, positive, negative, face_image, pipe, ip_adapter_scale, controlnet_conditioning_scale, steps, guidance_scale, seed, enhance_face_region, pose_image_optional=None, vram_optimisation="level_1"):

face_image = resize_img(face_image)

@@ -297,7 +298,8 @@ def id_generate_image(self, insightface, positive, negative, face_image, pipe, i
guidance_scale=guidance_scale,
width=width,
height=height,
return_dict=False
return_dict=False,
vram_optimisation=vram_optimisation
)

        # Check the output type and handle it accordingly
146 changes: 145 additions & 1 deletion ip_adapter/attention_processor.py
@@ -305,4 +305,148 @@ def forward(

hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
r"""
    Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
The context length of the image features.
"""

def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()

if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens

self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

def forward(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
        with torch.no_grad():
            # attention map over the image tokens: softmax(Q @ K_ip^T)
            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask

hidden_states = hidden_states + self.scale * ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
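
The new `IPAttnProcessor2_0` only takes effect once it is registered on the UNet's attention layers; that wiring is outside this diff. The helper below is a hedged sketch of how IP-Adapter-style code usually installs it (the function name and import path are assumptions): cross-attention layers receive `IPAttnProcessor2_0` with the extra image-token projections, while self-attention layers keep the plain `AttnProcessor2_0`.

```python
# Sketch only - not part of this commit. Assumes both processor classes are
# importable from the module, as the pipeline's conditional import suggests.
from ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

def install_ip_attn_processors(unet, num_tokens=16, scale=1.0):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" layers are self-attention; they have no cross_attention_dim.
        is_self_attn = name.endswith("attn1.processor")
        cross_attention_dim = None if is_self_attn else unet.config.cross_attention_dim

        # Recover the hidden size of the block this processor belongs to.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = IPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            ).to(unet.device, dtype=unet.dtype)

    unet.set_attn_processor(attn_procs)
```
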

34 changes: 31 additions & 3 deletions pipeline_stable_diffusion_xl_instantid.py
@@ -42,7 +42,10 @@
from .ip_adapter.resampler import Resampler
from .ip_adapter.utils import is_torch2_available

from .ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
if is_torch2_available():
from .ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
from .ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -132,7 +135,7 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2

class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):

def cuda(self, dtype=torch.float16, use_xformers=False):
def cuda(self, dtype=torch.float16, use_xformers=True):
self.to('cuda', dtype)

if hasattr(self, 'image_proj_model'):
Expand Down Expand Up @@ -232,6 +235,20 @@ def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifie
prompt_image_emb = self.image_proj_model(prompt_image_emb)
return prompt_image_emb

def free_model_hooks(self):
r"""
Function that offloads all components, removes all model hooks that were added when using
`enable_model_cpu_offload` and then DOESN'T applies them again.
"""
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
            # `enable_model_cpu_offload` has not been called, so silently do nothing
return

for hook in self._all_hooks:
# offload model and remove hook from model
hook.offload()
hook.remove()

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -256,6 +273,7 @@ def __call__(
image_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
vram_optimisation: Optional[str] = "off",
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
@@ -457,6 +475,16 @@ def __call__(
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs

self.free_model_hooks()
if vram_optimisation == 'level_1':
            # Level 1 VRAM optimisation: offload whole sub-models to the CPU and
            # move one at a time to the GPU - faster, but needs more VRAM
self.enable_model_cpu_offload()
elif vram_optimisation == 'level_2' and self.device != torch.device("meta"):
            # Level 2 VRAM optimisation: offload at the submodule level
            # (sequential CPU offload) - slower, but needs minimal VRAM
self.enable_sequential_cpu_offload()

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -745,7 +773,7 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()
self.free_model_hooks()

if not return_dict:
return (image,)
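
From the caller's side the new keyword is just another `__call__` argument, so the offloading level can be chosen per generation. A hedged usage sketch follows (`pipe`, `face_emb`, and `face_kps` are placeholders for the usual InstantID setup and preprocessing outputs; the other values are illustrative):

```python
# Illustrative call - assumes `pipe`, `face_emb` and `face_kps` were prepared
# the usual InstantID way (insightface embedding + keypoint image).
images = pipe(
    prompt="analog film photo of a man, natural lighting",
    negative_prompt="lowres, blurry, bad anatomy",
    image_embeds=face_emb,                 # identity embedding for the IP-Adapter path
    image=face_kps,                        # keypoint image consumed by the ControlNet
    controlnet_conditioning_scale=0.8,
    num_inference_steps=30,
    guidance_scale=5.0,
    vram_optimisation="level_2",           # 'off' | 'level_1' | 'level_2'
    return_dict=False,
)[0]
```
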
