Skip to content

Commit

Permalink
hide cards for networks of incompatible stable diffusion version in L…
Browse files Browse the repository at this point in the history
…ora extra networks interface
  • Loading branch information
AUTOMATIC1111 committed Jul 17, 2023
1 parent f97e359 commit 699108b
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 15 deletions.
20 changes: 20 additions & 0 deletions extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from collections import namedtuple
import enum

from modules import sd_models, cache, errors, hashes, shared

Expand All @@ -8,6 +9,13 @@
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
Unknown = 1
SD1 = 2
SD2 = 3
SDXL = 4


class NetworkOnDisk:
def __init__(self, name, filename):
self.name = name
Expand Down Expand Up @@ -44,6 +52,18 @@ def read_metadata():
''
)

self.sd_version = self.detect_version()

def detect_version(self):
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
return SdVersion.SDXL
elif str(self.metadata.get('ss_v2', "")) == "True":
return SdVersion.SD2
elif len(self.metadata):
return SdVersion.SD1

return SdVersion.Unknown

def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]
Expand Down
2 changes: 2 additions & 0 deletions extensions-builtin/Lora/scripts/lora_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def before_ui():
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
}))


Expand Down
20 changes: 15 additions & 5 deletions extensions-builtin/Lora/ui_edit_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,17 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
def __init__(self, ui, tabname, page):
super().__init__(ui, tabname, page)

self.select_sd_version = None

self.taginfo = None
self.edit_activation_text = None
self.slider_preferred_weight = None
self.edit_notes = None

def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes):
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
user_metadata["notes"] = notes
Expand Down Expand Up @@ -112,11 +115,11 @@ def put_values_into_components(self, name):
gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]

return [
*values[0:4],
*values[0:5],
item.get("sd_version", "Unknown"),
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
user_metadata.get('notes', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
Expand All @@ -141,10 +144,15 @@ def generate_random_prompt_from_tags(self, tags):

return ", ".join(sorted(res))

def create_extra_default_items_in_left_column(self):

# this would be a lot better as gr.Radio but I can't make it work
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)

def create_editor(self):
self.create_default_editor_elems()

self.taginfo = gr.HighlightedText(label="Tags")
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)

Expand Down Expand Up @@ -178,10 +186,11 @@ def select_tag(activation_text, evt: gr.SelectData):
self.edit_description,
self.html_filedata,
self.html_preview,
self.edit_notes,
self.select_sd_version,
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_notes,
row_random_prompt,
random_prompt,
]
Expand All @@ -192,6 +201,7 @@ def select_tag(activation_text, evt: gr.SelectData):

edited_components = [
self.edit_description,
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_notes,
Expand Down
34 changes: 29 additions & 5 deletions extensions-builtin/Lora/ui_extra_networks_lora.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os

import network
import networks

from modules import shared, ui_extra_networks
from modules import shared, ui_extra_networks, paths
from modules.ui_extra_networks import quote_js
from ui_edit_user_metadata import LoraUserMetadataEditor

Expand All @@ -13,14 +15,13 @@ def __init__(self):
def refresh(self):
networks.list_available_networks()

def create_item(self, name, index=None):
def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name)

path, ext = os.path.splitext(lora_on_disk.filename)

alias = lora_on_disk.get_alias()

# in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string
item = {
"name": name,
"filename": lora_on_disk.filename,
Expand All @@ -30,6 +31,7 @@ def create_item(self, name, index=None):
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
"sd_version": lora_on_disk.sd_version.name,
}

self.read_user_metadata(item)
Expand All @@ -40,15 +42,37 @@ def create_item(self, name, index=None):
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)

sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
sd_version = network.SdVersion[sd_version]
else:
sd_version = lora_on_disk.sd_version

if shared.opts.lora_show_all or not enable_filter:
pass
elif sd_version == network.SdVersion.Unknown:
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
return None
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
return None
elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
return None
elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
return None

return item

def list_items(self):
for index, name in enumerate(networks.available_networks):
item = self.create_item(name, index)
yield item

if item is not None:
yield item

def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir]
return [shared.cmd_opts.lora_dir, os.path.join(paths.models_path, "LyCORIS")]

def create_user_metadata_editor(self, ui, tabname):
return LoraUserMetadataEditor(ui, tabname, self)
2 changes: 1 addition & 1 deletion html/extra-networks-card.html
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
{background_image}
<div class="button-row">
{edit_button}
{metadata_button}
{edit_button}
</div>
<div class='actions'>
<div class='additional'>
Expand Down
2 changes: 1 addition & 1 deletion javascript/extraNetworks.js
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ function popup(contents) {
globalPopupInner.classList.add('global-popup-inner');
globalPopup.appendChild(globalPopupInner);

gradioApp().appendChild(globalPopup);
gradioApp().querySelector('.main').appendChild(globalPopup);
}

globalPopupInner.innerHTML = '';
Expand Down
3 changes: 3 additions & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2

if model.is_sdxl:
sd_models_xl.extend_sdxl(model)

Expand Down
3 changes: 2 additions & 1 deletion modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
page = next(iter([x for x in extra_pages if x.name == page]), None)

try:
item = page.create_item(name)
item = page.create_item(name, enable_filter=False)
page.items[name] = item
except Exception as e:
errors.display(e, "creating item for extra network")
item = page.items.get(name)
Expand Down
7 changes: 6 additions & 1 deletion modules/ui_extra_networks_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,18 @@ def get_user_metadata(self, name):

return user_metadata

def create_extra_default_items_in_left_column(self):
pass

def create_default_editor_elems(self):
with gr.Row():
with gr.Column(scale=2):
self.edit_name = gr.HTML(elem_classes="extra-network-name")
self.edit_description = gr.Textbox(label="Description", lines=4)
self.html_filedata = gr.HTML()

self.create_extra_default_items_in_left_column()

with gr.Column(scale=1, min_width=0):
self.html_preview = gr.HTML()

Expand Down Expand Up @@ -111,7 +116,7 @@ def put_values_into_components(self, name):

table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'

return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''),
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')

def write_user_metadata(self, name, metadata):
item = self.page.items.get(name, {})
Expand Down
6 changes: 5 additions & 1 deletion style.css
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ footer {

.extra-network-cards .card .card-button {
text-shadow: 2px 2px 3px black;
padding: 0.25em;
padding: 0.25em 0.1em;
font-size: 200%;
width: 1.5em;
}
Expand Down Expand Up @@ -957,6 +957,10 @@ div.block.gradio-box.edit-user-metadata {
text-align: left;
}

.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{
padding: 0.3em 1em;
}

.edit-user-metadata .wrap.translucent{
background: var(--body-background-fill);
}
Expand Down

0 comments on commit 699108b

Please sign in to comment.