Skip to content

Commit

Permalink
Merge branch 'hyp-lora-support' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
DominikDoom committed Jan 24, 2023
2 parents e144f0d + 040be35 commit e418a86
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 45 deletions.
4 changes: 3 additions & 1 deletion javascript/_result.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ const ResultType = Object.freeze({
"embedding": 2,
"wildcardTag": 3,
"wildcardFile": 4,
"yamlWildcard": 5
"yamlWildcard": 5,
"hypernetwork": 6,
"lora": 7
});

// Class to hold result data and annotations to make it clearer to use
Expand Down
200 changes: 157 additions & 43 deletions javascript/tagAutocomplete.js
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ async function syncOptions() {
delayTime: opts["tac_delayTime"],
useWildcards: opts["tac_useWildcards"],
useEmbeddings: opts["tac_useEmbeddings"],
useHypernetworks: opts["tac_useHypernetworks"],
useLoras: opts["tac_useLoras"],
showWikiLinks: opts["tac_showWikiLinks"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
Expand All @@ -196,7 +198,9 @@ async function syncOptions() {
extra: {
extraFile: opts["tac_extra.extraFile"],
onlyAliasExtraFile: opts["tac_extra.onlyAliasExtraFile"]
}
},
// Settings not from tac but still used by the script
extraNetworksDefaultMultiplier: opts["extra_networks_default_multiplier"]
}

if (CFG && CFG.colors) {
Expand Down Expand Up @@ -314,11 +318,15 @@ function insertTextAtCursor(textArea, result, tagword) {
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
} else if (tagType === ResultType.embedding) {
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
} else if (tagType === ResultType.hypernetwork) {
sanitizedText = `<hypernet:${text}:${CFG.extraNetworksDefaultMultiplier}>`;
} else if(tagType === ResultType.lora) {
sanitizedText = `<lora:${text}:${CFG.extraNetworksDefaultMultiplier}>`;
} else {
sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
}

if (CFG.escapeParentheses) {
if (CFG.escapeParentheses && tagType === ResultType.tag) {
sanitizedText = sanitizedText
.replaceAll("(", "\\(")
.replaceAll(")", "\\)")
Expand Down Expand Up @@ -483,41 +491,39 @@ function addResultsToList(textArea, results, tagword, resetList) {

// Add post count & color if it's a tag
// Wildcards & Embeds have no tag category
if (![ResultType.wildcardFile, ResultType.wildcardTag, ResultType.embedding].includes(result.type)) {
if (result.category) {
// Set the color of the tag
let cat = result.category;
let colorGroup = tagColors[tagFileName];
// Default to danbooru scheme if no matching one is found
if (!colorGroup)
colorGroup = tagColors["danbooru"];

// Set tag type to invalid if not found
if (!colorGroup[cat])
cat = "-1";

flexDiv.style = `color: ${colorGroup[cat][mode]};`;
}
if (result.category) {
// Set the color of the tag
let cat = result.category;
let colorGroup = tagColors[tagFileName];
// Default to danbooru scheme if no matching one is found
if (!colorGroup)
colorGroup = tagColors["danbooru"];

// Set tag type to invalid if not found
if (!colorGroup[cat])
cat = "-1";

flexDiv.style = `color: ${colorGroup[cat][mode]};`;
}

// Post count
if (result.count && !isNaN(result.count)) {
let postCount = result.count;
let formatter;
// Post count
if (result.count && !isNaN(result.count)) {
let postCount = result.count;
let formatter;

// Danbooru formats numbers with a padded fraction for 1M or 1k, but not for 10/100k
if (postCount >= 1000000 || (postCount >= 1000 && postCount < 10000))
formatter = Intl.NumberFormat("en", { notation: "compact", minimumFractionDigits: 1, maximumFractionDigits: 1 });
else
formatter = Intl.NumberFormat("en", {notation: "compact"});

let formattedCount = formatter.format(postCount);

let countDiv = document.createElement("div");
countDiv.textContent = formattedCount;
countDiv.classList.add("acMetaText");
flexDiv.appendChild(countDiv);
}
} else if (result.meta) { // Check if it is an embedding we have version info for
// Danbooru formats numbers with a padded fraction for 1M or 1k, but not for 10/100k
if (postCount >= 1000000 || (postCount >= 1000 && postCount < 10000))
formatter = Intl.NumberFormat("en", { notation: "compact", minimumFractionDigits: 1, maximumFractionDigits: 1 });
else
formatter = Intl.NumberFormat("en", {notation: "compact"});

let formattedCount = formatter.format(postCount);

let countDiv = document.createElement("div");
countDiv.textContent = formattedCount;
countDiv.classList.add("acMetaText");
flexDiv.appendChild(countDiv);
} else if (result.meta) { // Check if there is meta info to display
let metaDiv = document.createElement("div");
metaDiv.textContent = result.meta;
metaDiv.classList.add("acMetaText");
Expand Down Expand Up @@ -568,6 +574,8 @@ var wildcardExtFiles = [];
var yamlWildcards = [];
var umiPreviousTags = [];
var embeddings = [];
var hypernetworks = [];
var loras = [];
var results = [];
var tagword = "";
var originalTagword = "";
Expand Down Expand Up @@ -831,11 +839,11 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
originalTagword = tagword;
tagword = "";
}
} else if (CFG.useEmbeddings && tagword.match(/<[^,> ]*>?/g)) {
} else if (CFG.useEmbeddings && tagword.match(/<e:[^,> ]*>?/g)) {
// Show embeddings
let tempResults = [];
if (tagword !== "<") {
let searchTerm = tagword.replace("<", "")
if (tagword !== "<e:") {
let searchTerm = tagword.replace("<e:", "")
let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
Expand All @@ -848,6 +856,73 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
} else {
tempResults = embeddings;
}

// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
results.push(result);
});
} else if(CFG.useHypernetworks && tagword.match(/<h:[^,> ]*>?/g)) {
// Show hypernetworks
let tempResults = [];
if (tagword !== "<h:") {
let searchTerm = tagword.replace("<h:", "")
tempResults = hypernetworks.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
} else {
tempResults = hypernetworks;
}

// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork)
result.meta = "Hypernetwork";
results.push(result);
});
} else if(CFG.useLoras && tagword.match(/<l:[^,> ]*>?/g)){
// Show lora
let tempResults = [];
if (tagword !== "<l:") {
let searchTerm = tagword.replace("<l:", "")
tempResults = loras.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
} else {
tempResults = loras;
}

// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lora)
result.meta = "Lora";
results.push(result);
});
} else if ((CFG.useEmbeddings || CFG.useHypernetworks || CFG.useLoras) && tagword.match(/<[^,> ]*>?/g)) {
// Embeddings, lora, wildcards all together with generic options
let tempEmbResults = [];
let tempHypResults = [];
let tempLoraResults = [];
if (tagword !== "<") {
let searchTerm = tagword.replace("<", "")

let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
searchTerm = searchTerm.slice(2);
}

if (versionString && CFG.useEmbeddings) {
// Version string is only for embeddings atm, so we don't search the other lists here.
tempEmbResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
} else {
tempEmbResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
tempHypResults = hypernetworks.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
tempLoraResults = loras.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
}
} else {
tempEmbResults = embeddings;
tempHypResults = hypernetworks;
tempLoraResults = loras;
}

// Since some tags are kaomoji, we have to still get the normal results first.
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
Expand All @@ -860,11 +935,32 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);

// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
results.push(result);
});
let mixedResults = [];
if (CFG.useEmbeddings) {
tempEmbResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
mixedResults.push(result);
});
}
if (CFG.useHypernetworks) {
tempHypResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork)
result.meta = "Hypernetwork";
mixedResults.push(result);
});
}
if (CFG.useLoras) {
tempLoraResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lora)
result.meta = "Lora";
mixedResults.push(result);
});
}

// Add all mixed results to the final results, sorted by name so that they aren't after one another.
results = mixedResults.sort((a, b) => a.text.localeCompare(b.text));

genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
Expand Down Expand Up @@ -1080,6 +1176,24 @@ async function setup() {
console.error("Error loading embeddings.txt: " + e);
}
}
// Load hypernetworks
if (hypernetworks.length === 0) {
try {
hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) //Remove empty lines
} catch (e) {
console.error("Error loading hypernetworks.txt: " + e);
}
}
// Load lora
if (loras.length === 0) {
try {
loras = (await readFile(`${tagBasePath}/temp/lora.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
} catch (e) {
console.error("Error loading lora.txt: " + e);
}
}

// Find all textareas
let textAreas = getTextAreas();
Expand Down
33 changes: 32 additions & 1 deletion scripts/tag_autocomplete_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
LORA_PATH = Path(shared.cmd_opts.lora_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)


def find_ext_wildcard_paths():
Expand Down Expand Up @@ -137,6 +139,22 @@ def get_embeddings(sd_model):

write_to_temp_file('emb.txt', results)

def get_hypernetworks():
"""Write a list of all hypernetworks"""

# Get a list of all hypernetworks in the folder
all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
# Remove file extensions
return [h[:h.rfind('.')] for h in all_hypernetworks]

def get_lora():
"""Write a list of all lora"""

# Get a list of all lora in the folder
all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
return [l[:l.rfind('.')] for l in all_lora]


def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
Expand Down Expand Up @@ -178,6 +196,8 @@ def update_tag_files():
write_to_temp_file('wc.txt', [])
write_to_temp_file('wce.txt', [])
write_to_temp_file('wcet.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
Expand All @@ -202,7 +222,16 @@ def update_tag_files():
if EMB_PATH.exists():
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)


if HYP_PATH.exists():
hypernets = get_hypernetworks()
if hypernets:
write_to_temp_file('hyp.txt', hypernets)

if LORA_PATH.exists():
lora = get_lora()
if lora:
write_to_temp_file('lora.txt', lora)

# Register autocomplete options
def on_ui_settings():
Expand All @@ -224,6 +253,8 @@ def on_ui_settings():
shared.opts.add_option("tac_delayTime", shared.OptionInfo(100, "Time in ms to wait before triggering completion again (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_useWildcards", shared.OptionInfo(True, "Search for wildcards", section=TAC_SECTION))
shared.opts.add_option("tac_useEmbeddings", shared.OptionInfo(True, "Search for embeddings", section=TAC_SECTION))
shared.opts.add_option("tac_useHypernetworks", shared.OptionInfo(True, "Search for hypernetworks", section=TAC_SECTION))
shared.opts.add_option("tac_useLoras", shared.OptionInfo(True, "Search for Loras", section=TAC_SECTION))
shared.opts.add_option("tac_showWikiLinks", shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page (Warning: This is an external site and very likely contains NSFW examples!)", section=TAC_SECTION))
# Insertion related settings
shared.opts.add_option("tac_replaceUnderscores", shared.OptionInfo(True, "Replace underscores with spaces on insertion", section=TAC_SECTION))
Expand Down

0 comments on commit e418a86

Please sign in to comment.