diff --git a/InstantIDNode.py b/InstantIDNode.py
index c8cce5c..4236ec6 100644
--- a/InstantIDNode.py
+++ b/InstantIDNode.py
@@ -237,6 +237,7 @@ def INPUT_TYPES(cls):
                 "guidance_scale": ("FLOAT", {"default": 5, "min": 0, "max": 10, "display": "slider"}),
                 "enhance_face_region": ("BOOLEAN", {"default": True}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                "vram_optimisation": (['off', 'level_1', 'level_2'], {"default": "level_1"}),
             },
             "optional": {
                 "pose_image_optional": ("IMAGE",),
@@ -247,7 +248,7 @@ def INPUT_TYPES(cls):
     FUNCTION = "id_generate_image"
     CATEGORY = "📷InstantID"
 
-    def id_generate_image(self, insightface, positive, negative, face_image, pipe, ip_adapter_scale, controlnet_conditioning_scale, steps, guidance_scale, seed, enhance_face_region, pose_image_optional=None):
+    def id_generate_image(self, insightface, positive, negative, face_image, pipe, ip_adapter_scale, controlnet_conditioning_scale, steps, guidance_scale, seed, enhance_face_region, pose_image_optional=None, vram_optimisation="level_1"):
 
         face_image = resize_img(face_image)
 
@@ -297,7 +298,8 @@ def id_generate_image(self, insightface, positive, negative, face_image, pipe, i
             guidance_scale=guidance_scale,
             width=width,
             height=height,
-            return_dict=False
+            return_dict=False,
+            vram_optimisation=vram_optimisation
         )
 
         # Check the output type and handle it accordingly
diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py
index 5d9bc4d..919ca5d 100644
--- a/ip_adapter/attention_processor.py
+++ b/ip_adapter/attention_processor.py
@@ -305,4 +305,148 @@ def forward(
 
         hidden_states = hidden_states / attn.rescale_output_factor
 
-        return hidden_states
\ No newline at end of file
+        return hidden_states
+
+
+class IPAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter for PyTorch 2.0.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def forward(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        ip_hidden_states = F.scaled_dot_product_attention(
+            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+        with torch.no_grad():
+            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
+            #print(self.attn_map.shape)
+
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+        # region control
+        if len(region_control.prompt_image_conditioning) == 1:
+            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
+            if region_mask is not None:
+                query = query.reshape([-1, query.shape[-2], query.shape[-1]])
+                h, w = region_mask.shape[:2]
+                ratio = (h * w / query.shape[1]) ** 0.5
+                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
+            else:
+                mask = torch.ones_like(ip_hidden_states)
+            ip_hidden_states = ip_hidden_states * mask
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
diff --git a/pipeline_stable_diffusion_xl_instantid.py b/pipeline_stable_diffusion_xl_instantid.py
index 89e6e7d..446cb86 100644
--- a/pipeline_stable_diffusion_xl_instantid.py
+++ b/pipeline_stable_diffusion_xl_instantid.py
@@ -42,7 +42,10 @@
 from .ip_adapter.resampler import Resampler
 from .ip_adapter.utils import is_torch2_available
 
-from .ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
+if is_torch2_available():
+    from .ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+else:
+    from .ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -132,7 +135,7 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
 
 class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
 
-    def cuda(self, dtype=torch.float16, use_xformers=False):
+    def cuda(self, dtype=torch.float16, use_xformers=True):
         self.to('cuda', dtype)
 
         if hasattr(self, 'image_proj_model'):
@@ -232,6 +235,20 @@ def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifie
         prompt_image_emb = self.image_proj_model(prompt_image_emb)
         return prompt_image_emb
 
+    def free_model_hooks(self):
+        r"""
+        Function that offloads all components and removes all model hooks that were added when using
+        `enable_model_cpu_offload`, and does NOT apply them again afterwards.
+        """
+        if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
+            # `enable_model_cpu_offload` has not been called, so silently do nothing
+            return
+
+        for hook in self._all_hooks:
+            # offload model and remove hook from model
+            hook.offload()
+            hook.remove()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -256,6 +273,7 @@ def __call__(
         image_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        vram_optimisation: Optional[str] = "off",
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
@@ -457,6 +475,16 @@ def __call__(
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
 
+        self.free_model_hooks()
+        if vram_optimisation == 'level_1':
+            # Level 1 VRAM optimisation
+            # Keep only one sub-model on the GPU at a time - faster, but needs more VRAM
+            self.enable_model_cpu_offload()
+        elif vram_optimisation == 'level_2' and self.device != torch.device("meta"):
+            # Level 2 VRAM optimisation
+            # Slower, but needs minimal VRAM
+            self.enable_sequential_cpu_offload()
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -745,7 +773,7 @@ def __call__(
         image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
-        self.maybe_free_model_hooks()
+        self.free_model_hooks()
 
         if not return_dict:
             return (image,)
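
Usage note (not part of the patch): the three `vram_optimisation` levels map straight onto the standard diffusers offload helpers, so their effect can be sketched outside ComfyUI roughly as below. `apply_vram_optimisation` is an illustrative helper name and `pipe` is assumed to be any diffusers pipeline (for example the StableDiffusionXLInstantIDPipeline built by the node); this is a sketch of the intended behaviour, not code from the repository.

import torch

def apply_vram_optimisation(pipe, vram_optimisation="level_1"):
    # 'off'     : keep everything on the GPU; device placement is left to the caller (e.g. pipe.cuda())
    # 'level_1' : enable_model_cpu_offload() - one whole sub-model on the GPU at a time
    # 'level_2' : enable_sequential_cpu_offload() - weights streamed in module by module;
    #             slowest option, but with the smallest VRAM footprint
    if vram_optimisation == "level_1":
        pipe.enable_model_cpu_offload()
    elif vram_optimisation == "level_2" and pipe.device != torch.device("meta"):
        pipe.enable_sequential_cpu_offload()
    return pipe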