Skip to content

Commit

Permalink
Merge pull request #16164 from AUTOMATIC1111/sd3_textual_inversion
Browse files Browse the repository at this point in the history
sd3 TI support
  • Loading branch information
AUTOMATIC1111 authored Jul 13, 2024
2 parents 93c00b2 + 11cfe0d commit b4d62a0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
8 changes: 5 additions & 3 deletions modules/models/sd3/other_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast

from modules import sd_hijack


#################################################################################################
### Core/Utility
Expand Down Expand Up @@ -110,9 +112,9 @@ def forward(self, x, mask=None, intermediate_output=None):


class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

def forward(self, input_tokens):
Expand All @@ -127,7 +129,7 @@ def __init__(self, config_dict, dtype, device):
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

Expand Down
6 changes: 5 additions & 1 deletion modules/models/sd3/sd3_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __getitem__(self, key):
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
"textual_inversion_key": "clip_g",
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
Expand Down Expand Up @@ -204,7 +205,10 @@ def before_load_weights(self, state_dict):
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX
return self.model_lg.encode_embedding_init_text(init_text, nvpt)

def tokenize(self, texts):
return self.model_lg.tokenize(texts)

def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]
Expand Down
17 changes: 16 additions & 1 deletion modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,28 @@ def forward(self, input_ids):
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)

vecs.append(tensor)

return torch.stack(vecs)


class TextualInversionEmbeddings(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)

self.embeddings = model_hijack
self.textual_inversion_key = textual_inversion_key

@property
def wrapped(self):
return super().forward

def forward(self, input_ids):
return EmbeddingsWithFixes.forward(self, input_ids)


def add_circular_option_to_conv_2d():
conv2d_constructor = torch.nn.Conv2d.__init__

Expand Down

0 comments on commit b4d62a0

Please sign in to comment.