Commit f34c734 (1 parent: 3f401cd)
Showing 6 changed files with 50 additions and 26 deletions.
Two files renamed without changes.
@@ -0,0 +1,34 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma

    def encode_with_transformers(self, tokens):
        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
        # layer to work with - you have to use the last

        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
        z = features['projection_state']

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.roberta.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)

        return embedded
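As a side note, the padding handling in encode_with_transformers is easy to check in isolation. Below is a minimal sketch; the token IDs and pad ID are made up for illustration, and only the masking expression itself is taken from the commit:

import torch

id_pad = 1  # assumed pad token id for this sketch; the real value comes from wrapped.config.pad_token_id
tokens = torch.tensor([[0, 9374, 83, 2, 1, 1]])  # one sequence, right-padded with id_pad

# non-pad positions get attention weight 1, pad positions get 0 - same expression as in the commit
attention_mask = (tokens != id_pad).to(device=tokens.device, dtype=torch.int64)
print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])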
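Similarly, encode_embedding_init_text just runs the tokenizer and looks the resulting IDs up in the model's input embedding table, skipping the transformer entirely. A rough standalone equivalent with a stock transformers XLM-R checkpoint might look like the sketch below; the checkpoint name and the direct use of word_embeddings are assumptions, since the webui goes through its own wrapped embedding layer instead:

import torch
from transformers import AutoModel, AutoTokenizer

# assumption: any XLM-R checkpoint illustrates the idea; the commit uses the model wrapped by the webui
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModel.from_pretrained("xlm-roberta-base")

nvpt = 8  # cap on the number of vectors for the new textual-inversion embedding
ids = tokenizer("a photo of", max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]

# look the ids up directly in the input embedding table - no attention, no projection
with torch.no_grad():
    embedded = model.embeddings.word_embeddings(ids).squeeze(0)
print(embedded.shape)  # (number_of_tokens, hidden_size), e.g. (3, 768) for xlm-roberta-base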
Comment on f34c734:
> I guess, this is the last commit......... as of 2022! Happy New Year!