Skip to content

Commit

Permalink
Show version info for embeddings
Browse files Browse the repository at this point in the history
Also allows searching by version to quickly find v1 or v2 model embeddings
Closes #97
  • Loading branch information
DominikDoom committed Jan 1, 2023
1 parent b57042e commit 6deefda
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
37 changes: 31 additions & 6 deletions javascript/tagAutocomplete.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ const styleColors = {
"--results-bg-odd": ["#111827", "#f9fafb"],
"--results-hover": ["#1f2937", "#f5f6f8"],
"--results-selected": ["#374151", "#e5e7eb"],
"--post-count-color": ["#6b6f7b", "#a2a9b4"]
"--post-count-color": ["#6b6f7b", "#a2a9b4"],
"--embedding-v1-color": ["lightsteelblue", "#2b5797"],
"--embedding-v2-color": ["skyblue", "#2d89ef"],
}
const browserVars = {
"--results-overflow-y": {
Expand Down Expand Up @@ -66,6 +68,12 @@ const autocompleteCSS = `
flex-grow: 1;
color: var(--post-count-color);
}
.acListItem.acEmbeddingV1 {
color: var(--embedding-v1-color);
}
.acListItem.acEmbeddingV2 {
color: var(--embedding-v2-color);
}
`;

// Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight.
Expand Down Expand Up @@ -364,7 +372,7 @@ function insertTextAtCursor(textArea, result, tagword) {
} else if (tagType === "yamlWildcard" && !yamlWildcards.includes(text)) {
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
} else if (tagType === "embedding") {
sanitizedText = `<${text.replace(/^.*?: /g, "")}>`;
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
} else {
sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
}
Expand Down Expand Up @@ -422,7 +430,6 @@ function insertTextAtCursor(textArea, result, tagword) {
let match = surrounding.match(new RegExp(escapeRegExp(`${tagword}`), "i"));
let insert = surrounding.replace(match, sanitizedText);

let modifiedTagword = prompt.substring(0, editStart) + insert + prompt.substring(editEnd);
let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)];

let umiTags = [];
Expand Down Expand Up @@ -549,6 +556,17 @@ function addResultsToList(textArea, results, tagword, resetList) {
countDiv.classList.add("acPostCount");
flexDiv.appendChild(countDiv);
}
} else if (result[1] === "embedding" && result[2]) { // Check if it is an embedding we have version info for
let versionDiv = document.createElement("div");
versionDiv.textContent = result[2];
versionDiv.classList.add("acPostCount");

if (result[2].startsWith("v1"))
itemText.classList.add("acEmbeddingV1");
else if (result[2].startsWith("v2"))
itemText.classList.add("acEmbeddingV2");

flexDiv.appendChild(versionDiv);
}

// Add listener
Expand Down Expand Up @@ -811,7 +829,14 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
// Show embeddings
let tempResults = [];
if (tagword !== "<") {
tempResults = embeddings.filter(x => x.toLowerCase().includes(tagword.replace("<", ""))) // Filter by tagword
let searchTerm = tagword.replace("<", "")
let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
searchTerm = searchTerm.slice(2);
}
let versionCondition = x => x[1] && x[1] === versionString;
tempResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm) && versionCondition(x)); // Filter by tagword
} else {
tempResults = embeddings;
}
Expand All @@ -825,7 +850,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
results = genericResults.concat(tempResults.map(x => ["Embeddings: " + x.trim(), "embedding"])); // Mark as embedding
results = tempResults.map(x => [x[0].trim(), "embedding", x[1] + " Embedding"]).concat(genericResults); // Mark as embedding
} else {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
Expand Down Expand Up @@ -1022,7 +1047,7 @@ async function setup() {
try {
embeddings = (await readFile(`${tagBasePath}/temp/emb.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.replace(".bin", "").replace(".pt", "").replace(".png", "")); // Remove file extensions
.map(x => x.trim().split(",")); // Split into name, version type pairs
} catch (e) {
console.error("Error loading embeddings.txt: " + e);
}
Expand Down
60 changes: 54 additions & 6 deletions scripts/tag_autocomplete_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import gradio as gr
from pathlib import Path
from modules import scripts, script_callbacks, shared
from modules import scripts, script_callbacks, shared, sd_hijack
import yaml
import time
import threading

# Webui root path
FILE_DIR = Path().absolute()
Expand Down Expand Up @@ -78,9 +80,54 @@ def get_ext_wildcard_tags():
output.append(f"{tag},{count}")
return output


def get_embeddings():
# NOTE(review): this listing comes from a rendered diff with indentation and +/- markers
# stripped. The next two lines look like the pre-change body (old docstring + early return)
# and the rest like the post-change body — confirm which lines are live before editing.
"""Returns a list of all embeddings"""
return [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png"}]
"""Write a list of all embeddings with their version"""
# Get a list of all embeddings in the folder
# (recursive glob under EMB_PATH, kept as paths relative to EMB_PATH, filtered by file suffix)
embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
# Remove file extensions
# (cuts each entry at its last '.'; every entry has one because of the suffix filter above)
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]

# Wait for all embeddings to be loaded
# Polls until the number of loaded plus skipped embeddings reaches the number of files found.
# NOTE(review): there is no timeout — if that combined count never reaches len(embs_in_dir),
# this loop does not exit. Confirm the counts are guaranteed to line up.
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
time.sleep(2) # Sleep for 2 seconds

# Get embedding dict from sd_hijack to separate v1/v2 embeddings
# emb_type_a is used as a mapping below (.items() / .keys()); emb_type_b is only iterated.
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
# Get the shape of the first item in the dict
# -1 is the sentinel for "no loaded embeddings"; it matches neither constant below.
# NOTE(review): `.shape` is compared against plain ints, so it is presumably a scalar
# vector width rather than a tuple — confirm against the embedding class.
emb_a_shape = -1
if (len(emb_type_a) > 0):
emb_a_shape = next(iter(emb_type_a.items()))[1].shape

# Add embeddings to the correct list
# presumably the embedding vector widths of v1 and v2 models — TODO confirm
V1_SHAPE = 768
V2_SHAPE = 1024
emb_v1 = []
emb_v2 = []

# The shape of the first loaded embedding decides which version the loaded set is;
# the skipped set is then assumed to be the other version.
# If the shape matches neither constant (including the -1 sentinel), both lists stay empty.
if (emb_a_shape == V1_SHAPE):
emb_v1 = list(emb_type_a.keys())
emb_v2 = list(emb_type_b)
elif (emb_a_shape == V2_SHAPE):
emb_v1 = list(emb_type_b)
emb_v2 = list(emb_type_a.keys())

# Create a new list to store the modified strings
# (each entry has the form "<name>,<version>")
results = []

# Iterate through each string in the big list
# Membership is tested against the name lists built above; entries found in neither
# list fall through to the default branch.
for string in embs_in_dir:
if string in emb_v1:
results.append(string + ",v1")
elif string in emb_v2:
results.append(string + ",v2")
# If the string is not in either, default to v1
# (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer)
else:
results.append(string + ",v1")

# Hand the "<name>,<version>" lines to write_to_temp_file under the name 'emb.txt'.
write_to_temp_file('emb.txt', results)


def write_tag_base_path():
Expand Down Expand Up @@ -143,9 +190,10 @@ def update_tag_files():

# Write embeddings to emb.txt if found
if EMB_PATH.exists():
embeddings = get_embeddings()
if embeddings:
write_to_temp_file('emb.txt', embeddings)
# We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads)
thread = threading.Thread(target=get_embeddings)
thread.start()


# Register autocomplete options
def on_ui_settings():
Expand Down

0 comments on commit 6deefda

Please sign in to comment.