diff --git a/.eslintrc.js b/.eslintrc.js index f33aca09fa0..4777c276e9b 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -87,5 +87,11 @@ module.exports = { modalNextImage: "readonly", // token-counters.js setupTokenCounters: "readonly", + // localStorage.js + localSet: "readonly", + localGet: "readonly", + localRemove: "readonly", + // resizeHandle.js + setupResizeHandle: "writable" } }; diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index d80b24e2bde..cf6a2be86fa 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -26,7 +26,7 @@ body: id: steps attributes: label: Steps to reproduce the problem - description: Please provide us with precise step by step information on how to reproduce the bug + description: Please provide us with precise step by step instructions on how to reproduce the bug value: | 1. Go to .... 2. Press .... @@ -37,64 +37,14 @@ body: id: what-should attributes: label: What should have happened? - description: Tell what you think the normal behavior should be + description: Tell us what you think the normal behavior should be validations: required: true - - type: input - id: commit - attributes: - label: Version or Commit where the problem happens - description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)" - validations: - required: true - - type: dropdown - id: py-version - attributes: - label: What Python version are you running on ? - multiple: false - options: - - Python 3.10.x - - Python 3.11.x (above, no supported yet) - - Python 3.9.x (below, no recommended) - - type: dropdown - id: platforms - attributes: - label: What platforms do you use to access the UI ? - multiple: true - options: - - Windows - - Linux - - MacOS - - iOS - - Android - - Other/Cloud - - type: dropdown - id: device - attributes: - label: What device are you running WebUI on? - multiple: true - options: - - Nvidia GPUs (RTX 20 above) - - Nvidia GPUs (GTX 16 below) - - AMD GPUs (RX 6000 above) - - AMD GPUs (RX 5000 below) - - CPU - - Other GPUs - - type: dropdown - id: cross_attention_opt + - type: textarea + id: sysinfo attributes: - label: Cross attention optimization - description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization - multiple: false - options: - - Automatic - - xformers - - sdp-no-mem - - sdp - - Doggettx - - V1 - - InvokeAI - - "None " + label: Sysinfo + description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file. validations: required: true - type: dropdown @@ -108,21 +58,7 @@ body: - Brave - Apple Safari - Microsoft Edge - - type: textarea - id: cmdargs - attributes: - label: Command Line Arguments - description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise. - render: Shell - validations: - required: true - - type: textarea - id: extensions - attributes: - label: List of extensions - description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise. - validations: - required: true + - Other - type: textarea id: logs attributes: diff --git a/CHANGELOG.md b/CHANGELOG.md index b18c6867348..1cd3572c8e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,158 @@ +## 1.6.0 + +### Features: + * refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371) + * add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards + * add style editor dialog + * hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181)) + * option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227)) + * new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542)) + * rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers: + * makes all of them work with img2img + * makes prompt composition posssible (AND) + * makes them available for SDXL + * always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808)) + * use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599)) + * textual inversion inference support for SDXL + * extra networks UI: show metadata for SD checkpoints + * checkpoint merger: add metadata support + * prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177)) + * VAE: allow selecting own VAE for each checkpoint (in user metadata editor) + * VAE: add selected VAE to infotext + * options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551)) + * add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723)) + * change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it + * show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707)) + * add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models + * prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457)) + +### Minor: + * img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515)) + * postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479)) + * XYZ: in the axis labels, remove pathnames from model filenames + * XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298)) + * XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491)) + * add gradio version warning + * sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297)) + * use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326)) + * move some settings to their own section: img2img, VAE + * add checkbox to show/hide dirs for extra networks + * Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311)) + * gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355)) + * sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521)) + * update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352)) + * option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338)) + * enable cond cache by default + * git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230)) + * allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379)) + * automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254)) + * put commonly used samplers on top, make DPM++ 2M Karras the default choice + * zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727)) + * option to cache Lora networks in memory + * rework hires fix UI to use accordion + * face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back + * change quicksettings items to have variable width + * Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503)) + * Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console + * support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510)) + * add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564)) + * support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584)) + * make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634)) + * configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648)) + * make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645)) + * more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639)) + * make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635)) + * make progress bar work independently from live preview display which results in it being updated a lot more often + * forbid Full live preview method for medvram and add a setting to undo the forbidding + * make it possible to localize tooltips and placeholders + * add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818)) + * Restore faces and Tiling generation parameters have been moved to settings out of main UI + * if you want to put them back into main UI, use `Options in main UI` setting on the UI page. + +### Extensions and API: + * gradio 3.41.2 + * also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd + * support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world') + * properly clear the total console progressbar when using txt2img and img2img from API + * add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294)) + * shared.py and webui.py split into many files + * add --loglevel commandline argument for logging + * add a custom UI element that combines accordion and checkbox + * avoid importing gradio in tests because it spams warnings + * put infotext label for setting into OptionInfo definition rather than in a separate list + * make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470)) + * option to make scripts UI without gr.Group + * add a way for scripts to register a callback for before/after just a single component's creation + * use dataclass for StableDiffusionProcessing + * store patches for Lora in a specialized module instead of inside torch + * support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698)) + * add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616)) + * dump current stack traces when exiting with SIGINT + * add type annotations for extra fields of shared.sd_model + +### Bug Fixes: + * Don't crash if out of local storage quota for javascriot localStorage + * XYZ plot do not fail if an exception occurs + * fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269)) + * localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307)) + * fix sdxl model invalid configuration after the hijack + * correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304)) + * open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318)) + * prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319)) + * add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327)) + * fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387)) + * fix options in main UI misbehaving when there's just one element + * make it possible to use a sampler from infotext even if it's hidden in the dropdown + * fix styles missing from the prompt in infotext when making a grid of batch of multiplie images + * prevent bogus progress output in console when calculating hires fix dimensions + * fix --use-textbox-seed + * fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466)) + * properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463)) + * MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526)) + * add second_order to samplers that mistakenly didn't have it + * when refreshing cards in extra networks UI, do not discard user's custom resolution + * fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509)) + * fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588)) + * fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586)) + * fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596)) + * auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603)) + * fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607)) + * fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638)) + * fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633)) + * attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630)) + * implement missing undo hijack for SDXL + * fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684)) + * fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689)) + * fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685)) + * fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667)) + * create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717)) + * prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version + * set devices.dtype_unet correctly + * run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745)) + * fix error that causes some extra networks to be disabled if both and are present in the prompt + * fix defaults settings page breaking when any of main UI tabs are hidden + * fix incorrect save/display of new values in Defaults page in settings + * fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working + * fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * fix style editing dialog breaking if it's opened in both img2img and txt2img tabs + * fix a bug allowing users to bypass gradio and API authentication (reported by vysecurity) + * fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834)) + * honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832)) + * don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855)) + * do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854)) + * get progressbar to display correctly in extensions tab + + +## 1.5.2 + +### Bug Fixes: + * fix memory leak when generation fails + * update doggettx cross attention optimization to not use an unreasonable amount of memory in some edge cases -- suggestion by MorkTheOrk + + ## 1.5.1 ### Minor: diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000000..2c781aff450 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,7 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: + - given-names: AUTOMATIC1111 +title: "Stable Diffusion Web UI" +date-released: 2022-08-22 +url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui" diff --git a/README.md b/README.md index b796d150041..4e08344008c 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Clip skip - Hypernetworks - Loras (same as Hypernetworks but more pretty) -- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt +- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt - Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API @@ -88,12 +88,15 @@ A browser interface based on Gradio library for Stable Diffusion. - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions - Now without any bad letters! - Load checkpoints in safetensors format -- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64 +- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64 - Now with a license! - Reorder elements in the UI from settings screen ## Installation and Running -Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. +Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for: +- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) +- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. +- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page) Alternatively, use online services (like Google Colab): @@ -115,7 +118,7 @@ Alternatively, use online services (like Google Colab): 1. Install the dependencies: ```bash # Debian-based: -sudo apt install wget git python3 python3-venv +sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 # Red Hat-based: sudo dnf install wget git python3 # Arch-based: @@ -123,7 +126,7 @@ sudo pacman -S wget git python3 ``` 2. Navigate to the directory you would like the webui to be installed and execute the following command: ```bash -bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh) +wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh ``` 3. Run `webui.sh`. 4. Check `webui-user.sh` for options. @@ -169,5 +172,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - LyCORIS - KohakuBlueleaf +- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ba2945c6fe1..005ff32cbe3 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def __init__(self): super().__init__('lora') + self.errors = {} + """mapping of network names to the number of errors the network had during operation""" + def activate(self, p, params_list): additional = shared.opts.sd_lora + self.errors.clear() + if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -56,4 +61,7 @@ def activate(self, p, params_list): p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) def deactivate(self, p): - pass + if self.errors: + p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items())) + + self.errors.clear() diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py new file mode 100644 index 00000000000..b394d8e9ed4 --- /dev/null +++ b/extensions-builtin/Lora/lora_patches.py @@ -0,0 +1,31 @@ +import torch + +import networks +from modules import patches + + +class LoraPatches: + def __init__(self): + self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) + self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict) + self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward) + self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict) + self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward) + self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict) + self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward) + self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict) + self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward) + self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict) + + def undo(self): + self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') + self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') + self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') + self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') + self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') + self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') + self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') + self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') + self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') + self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') + diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 0a18d69eb26..d8e8dfb7ff0 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -133,7 +133,7 @@ def calc_scale(self): return 1.0 - def finalize_updown(self, updown, orig_weight, output_shape): + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) @@ -145,7 +145,10 @@ def finalize_updown(self, updown, orig_weight, output_shape): if orig_weight.size().numel() == updown.size().numel(): updown = updown.reshape(orig_weight.shape) - return updown * self.calc_scale() * self.multiplier() + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + + return updown * self.calc_scale() * self.multiplier(), ex_bias def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index 109b4c2c594..bf6930e96c0 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -14,9 +14,14 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) self.weight = weights.w.get("diff") + self.ex_bias = weights.w.get("diff_b") def calc_updown(self, orig_weight): output_shape = self.weight.shape updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + if self.ex_bias is not None: + ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None - return self.finalize_updown(updown, orig_weight, output_shape) + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py new file mode 100644 index 00000000000..ce450158068 --- /dev/null +++ b/extensions-builtin/Lora/network_norm.py @@ -0,0 +1,28 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 17cbe1bb7fe..96f935b236f 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,12 +1,15 @@ +import logging import os import re +import lora_patches import network import network_lora import network_hada import network_ia3 import network_lokr import network_full +import network_norm import torch from typing import Union @@ -19,6 +22,7 @@ network_ia3.ModuleTypeIa3(), network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), + network_norm.ModuleTypeNorm(), ] @@ -31,6 +35,8 @@ "resnets": { "conv1": "in_layers_2", "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", "time_emb_proj": "emb_layers_1", "conv_shortcut": "skip_connection", } @@ -190,11 +196,19 @@ def load_network(name, network_on_disk): net.modules[key] = net_module if keys_failed_to_match: - print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") return net +def purge_networks_from_memory(): + while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0: + name = next(iter(networks_in_memory)) + networks_in_memory.pop(name, None) + + devices.torch_gc() + + def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): already_loaded = {} @@ -212,15 +226,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No failed_to_load_networks = [] - for i, name in enumerate(names): + for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): net = already_loaded.get(name, None) - network_on_disk = networks_on_disk[i] - if network_on_disk is not None: + if net is None: + net = networks_in_memory.get(name) + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: try: net = load_network(name, network_on_disk) + + networks_in_memory.pop(name, None) + networks_in_memory[name] = net except Exception as e: errors.display(e, f"loading network {network_on_disk.filename}") continue @@ -231,7 +249,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if net is None: failed_to_load_networks.append(name) - print(f"Couldn't find network with name {name}") + logging.info(f"Couldn't find network with name {name}") continue net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 @@ -240,23 +258,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No loaded_networks.append(net) if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + purge_networks_from_memory() -def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): weights_backup = getattr(self, "network_weights_backup", None) + bias_backup = getattr(self, "network_bias_backup", None) - if weights_backup is None: + if weights_backup is None and bias_backup is None: return - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) + if weights_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + if bias_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.out_proj.bias.copy_(bias_backup) + else: + self.bias.copy_(bias_backup) else: - self.weight.copy_(weights_backup) + if isinstance(self, torch.nn.MultiheadAttention): + self.out_proj.bias = None + else: + self.bias = None -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): """ Applies the currently selected set of networks to the weights of torch layer self. If weights already have this particular set of networks applied, does nothing. @@ -271,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) weights_backup = getattr(self, "network_weights_backup", None) - if weights_backup is None: + if weights_backup is None and wanted_names != (): + if current_names != (): + raise RuntimeError("no backup weights found and current weights are not unchanged") + if isinstance(self, torch.nn.MultiheadAttention): weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) else: @@ -279,21 +315,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_weights_backup = weights_backup + bias_backup = getattr(self, "network_bias_backup", None) + if bias_backup is None: + if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: + bias_backup = self.out_proj.bias.to(devices.cpu, copy=True) + elif getattr(self, 'bias', None) is not None: + bias_backup = self.bias.to(devices.cpu, copy=True) + else: + bias_backup = None + self.network_bias_backup = bias_backup + if current_names != wanted_names: network_restore_weights_from_backup(self) for net in loaded_networks: module = net.modules.get(network_layer_name, None) if module is not None and hasattr(self, 'weight'): - with torch.no_grad(): - updown = module.calc_updown(self.weight) - - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: - # inpainting model. zero pad updown to make channel[1] 4 to 9 - updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + try: + with torch.no_grad(): + updown, ex_bias = module.calc_updown(self.weight) + + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model. zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + + self.weight += updown + if ex_bias is not None and hasattr(self, 'bias'): + if self.bias is None: + self.bias = torch.nn.Parameter(ex_bias) + else: + self.bias += ex_bias + except RuntimeError as e: + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 - self.weight += updown - continue + continue module_q = net.modules.get(network_layer_name + "_q_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None) @@ -301,21 +357,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn module_out = net.modules.get(network_layer_name + "_out_proj", None) if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: - with torch.no_grad(): - updown_q = module_q.calc_updown(self.in_proj_weight) - updown_k = module_k.calc_updown(self.in_proj_weight) - updown_v = module_v.calc_updown(self.in_proj_weight) - updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - updown_out = module_out.calc_updown(self.out_proj.weight) - - self.in_proj_weight += updown_qkv - self.out_proj.weight += updown_out - continue + try: + with torch.no_grad(): + updown_q, _ = module_q.calc_updown(self.in_proj_weight) + updown_k, _ = module_k.calc_updown(self.in_proj_weight) + updown_v, _ = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += updown_out + if ex_bias is not None: + if self.out_proj.bias is None: + self.out_proj.bias = torch.nn.Parameter(ex_bias) + else: + self.out_proj.bias += ex_bias + + except RuntimeError as e: + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + + continue if module is None: continue - print(f'failed to calculate network weights for layer {network_layer_name}') + logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 self.network_current_names = wanted_names @@ -342,7 +410,7 @@ def network_forward(module, input, original_forward): if module is None: continue - y = module.forward(y, input) + y = module.forward(input, y) return y @@ -354,44 +422,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def network_Linear_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Linear_forward_before_network) + return network_forward(self, input, originals.Linear_forward) network_apply_weights(self) - return torch.nn.Linear_forward_before_network(self, input) + return originals.Linear_forward(self, input) def network_Linear_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + return originals.Linear_load_state_dict(self, *args, **kwargs) def network_Conv2d_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + return network_forward(self, input, originals.Conv2d_forward) network_apply_weights(self) - return torch.nn.Conv2d_forward_before_network(self, input) + return originals.Conv2d_forward(self, input) def network_Conv2d_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + return originals.Conv2d_load_state_dict(self, *args, **kwargs) + + +def network_GroupNorm_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, originals.GroupNorm_forward) + + network_apply_weights(self) + + return originals.GroupNorm_forward(self, input) + + +def network_GroupNorm_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return originals.GroupNorm_load_state_dict(self, *args, **kwargs) + + +def network_LayerNorm_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, originals.LayerNorm_forward) + + network_apply_weights(self) + + return originals.LayerNorm_forward(self, input) + + +def network_LayerNorm_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return originals.LayerNorm_load_state_dict(self, *args, **kwargs) def network_MultiheadAttention_forward(self, *args, **kwargs): network_apply_weights(self) - return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_forward(self, *args, **kwargs) def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs) def list_available_networks(): @@ -459,9 +557,14 @@ def infotext_pasted(infotext, params): params["Prompt"] += "\n" + "".join(added) +originals: lora_patches.LoraPatches = None + +extra_network_lora = None + available_networks = {} available_network_aliases = {} loaded_networks = [] +networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index cd28afc92e7..ef23968c563 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,57 +1,30 @@ import re -import torch import gradio as gr from fastapi import FastAPI import network import networks import lora # noqa:F401 +import lora_patches import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared + def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_network - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network - torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network - torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network + networks.originals.undo() def before_ui(): ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - extra_network = extra_networks_lora.ExtraNetworkLora() - extra_networks.register_extra_network(extra_network) - extra_networks.register_extra_network_alias(extra_network, "lyco") - - -if not hasattr(torch.nn, 'Linear_forward_before_network'): - torch.nn.Linear_forward_before_network = torch.nn.Linear.forward - -if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'): - torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict + networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora() + extra_networks.register_extra_network(networks.extra_network_lora) + extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco") -if not hasattr(torch.nn, 'Conv2d_forward_before_network'): - torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward -if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): - torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict - -if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): - torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward - -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'): - torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict - -torch.nn.Linear.forward = networks.network_Linear_forward -torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict -torch.nn.Conv2d.forward = networks.network_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict -torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict +networks.originals = lora_patches.LoraPatches() script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -65,6 +38,7 @@ def before_ui(): "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), + "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), })) @@ -121,3 +95,5 @@ def network_replacement(m): script_callbacks.on_infotext_pasted(infotext_pasted) + +shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2ca997f7ce9..c7011909055 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -70,6 +70,7 @@ def get_metadata_table(self, name): metadata = item.get("metadata") or {} keys = { + 'ss_output_name': "Output name:", 'ss_sd_model_name': "Model:", 'ss_clip_skip': "Clip skip:", 'ss_network_module': "Kohya module:", @@ -167,7 +168,7 @@ def create_editor(self): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) with gr.Column(scale=1, min_width=120): - generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg") + generate_random_prompt = gr.Button('Generate', size="lg", scale=1) self.edit_notes = gr.TextArea(label='Notes', lines=4) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 3629e5c0cf2..55409a7829d 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -25,9 +25,10 @@ def create_item(self, name, index=None, enable_filter=True): item = { "name": name, "filename": lora_on_disk.filename, + "shorthash": lora_on_disk.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename), + "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 30199dcd60a..45c7600ac5f 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -12,8 +12,22 @@ onUiLoaded(async() => { "Sketch": elementIDs.sketch }; + // Helper functions // Get active tab + + /** + * Waits for an element to be present in the DOM. + */ + const waitForElement = (id) => new Promise(resolve => { + const checkForElement = () => { + const element = document.querySelector(id); + if (element) return resolve(element); + setTimeout(checkForElement, 100); + }; + checkForElement(); + }); + function getActiveTab(elements, all = false) { const tabs = elements.img2imgTabs.querySelectorAll("button"); @@ -34,7 +48,7 @@ onUiLoaded(async() => { // Wait until opts loaded async function waitForOpts() { - for (;;) { + for (; ;) { if (window.opts && Object.keys(window.opts).length) { return window.opts; } @@ -42,6 +56,11 @@ onUiLoaded(async() => { } } + // Detect whether the element has a horizontal scroll bar + function hasHorizontalScrollbar(element) { + return element.scrollWidth > element.clientWidth; + } + // Function for defining the "Ctrl", "Shift" and "Alt" keys function isModifierKey(event, key) { switch (key) { @@ -201,7 +220,8 @@ onUiLoaded(async() => { canvas_hotkey_overlap: "KeyO", canvas_disabled_functions: [], canvas_show_tooltip: true, - canvas_blur_prompt: false + canvas_auto_expand: true, + canvas_blur_prompt: false, }; const functionMap = { @@ -249,7 +269,7 @@ onUiLoaded(async() => { input?.addEventListener("input", () => restoreImgRedMask(elements)); } - function applyZoomAndPan(elemId) { + function applyZoomAndPan(elemId, isExtension = true) { const targetElement = gradioApp().querySelector(elemId); if (!targetElement) { @@ -361,6 +381,12 @@ onUiLoaded(async() => { panY: 0 }; + if (isExtension) { + targetElement.style.overflow = "hidden"; + } + + targetElement.isZoomed = false; + fixCanvas(); targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`; @@ -371,8 +397,27 @@ onUiLoaded(async() => { toggleOverlap("off"); fullScreenMode = false; + const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']"); + if (closeBtn) { + closeBtn.addEventListener("click", resetZoom); + } + + if (canvas && isExtension) { + const parentElement = targetElement.closest('[id^="component-"]'); + if ( + canvas && + parseFloat(canvas.style.width) > parentElement.offsetWidth && + parseFloat(targetElement.style.width) > parentElement.offsetWidth + ) { + fitToElement(); + return; + } + + } + if ( canvas && + !isExtension && parseFloat(canvas.style.width) > 865 && parseFloat(targetElement.style.width) > 865 ) { @@ -381,9 +426,6 @@ onUiLoaded(async() => { } targetElement.style.width = ""; - if (canvas) { - targetElement.style.height = canvas.style.height; - } } // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements @@ -439,7 +481,7 @@ onUiLoaded(async() => { // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables function updateZoom(newZoomLevel, mouseX, mouseY) { - newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15)); + newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15)); elemData[elemId].panX += mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel; @@ -450,6 +492,10 @@ onUiLoaded(async() => { targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`; toggleOverlap("on"); + if (isExtension) { + targetElement.style.overflow = "visible"; + } + return newZoomLevel; } @@ -472,10 +518,12 @@ onUiLoaded(async() => { fullScreenMode = false; elemData[elemId].zoomLevel = updateZoom( elemData[elemId].zoomLevel + - (operation === "+" ? delta : -delta), + (operation === "+" ? delta : -delta), zoomPosX - targetElement.getBoundingClientRect().left, zoomPosY - targetElement.getBoundingClientRect().top ); + + targetElement.isZoomed = true; } } @@ -489,10 +537,19 @@ onUiLoaded(async() => { //Reset Zoom targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + let parentElement; + + if (isExtension) { + parentElement = targetElement.closest('[id^="component-"]'); + } else { + parentElement = targetElement.parentElement; + } + + // Get element and screen dimensions const elementWidth = targetElement.offsetWidth; const elementHeight = targetElement.offsetHeight; - const parentElement = targetElement.parentElement; + const screenWidth = parentElement.clientWidth; const screenHeight = parentElement.clientHeight; @@ -545,8 +602,12 @@ onUiLoaded(async() => { if (!canvas) return; - if (canvas.offsetWidth > 862) { - targetElement.style.width = canvas.offsetWidth + "px"; + if (canvas.offsetWidth > 862 || isExtension) { + targetElement.style.width = (canvas.offsetWidth + 2) + "px"; + } + + if (isExtension) { + targetElement.style.overflow = "visible"; } if (fullScreenMode) { @@ -648,8 +709,48 @@ onUiLoaded(async() => { mouseY = e.offsetY; } + // Simulation of the function to put a long image into the screen. + // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element. + // We hide the image and show it to the user when it is ready. + + targetElement.isExpanded = false; + function autoExpand() { + const canvas = document.querySelector(`${elemId} canvas[key="interface"]`); + if (canvas) { + if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) { + targetElement.style.visibility = "hidden"; + setTimeout(() => { + fitToScreen(); + resetZoom(); + targetElement.style.visibility = "visible"; + targetElement.isExpanded = true; + }, 10); + } + } + } + targetElement.addEventListener("mousemove", getMousePosition); + //observers + // Creating an observer with a callback function to handle DOM changes + const observer = new MutationObserver((mutationsList, observer) => { + for (let mutation of mutationsList) { + // If the style attribute of the canvas has changed, by observation it happens only when the picture changes + if (mutation.type === 'attributes' && mutation.attributeName === 'style' && + mutation.target.tagName.toLowerCase() === 'canvas') { + targetElement.isExpanded = false; + setTimeout(resetZoom, 10); + } + } + }); + + // Apply auto expand if enabled + if (hotkeysConfig.canvas_auto_expand) { + targetElement.addEventListener("mousemove", autoExpand); + // Set up an observer to track attribute changes + observer.observe(targetElement, {attributes: true, childList: true, subtree: true}); + } + // Handle events only inside the targetElement let isKeyDownHandlerAttached = false; @@ -754,6 +855,11 @@ onUiLoaded(async() => { if (isMoving && elemId === activeElement) { updatePanPosition(e.movementX, e.movementY); targetElement.style.pointerEvents = "none"; + + if (isExtension) { + targetElement.style.overflow = "visible"; + } + } else { targetElement.style.pointerEvents = "auto"; } @@ -764,13 +870,93 @@ onUiLoaded(async() => { isMoving = false; }; + // Checks for extension + function checkForOutBox() { + const parentElement = targetElement.closest('[id^="component-"]'); + if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) { + resetZoom(); + targetElement.isExpanded = true; + } + + if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) { + resetZoom(); + } + + if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) { + resetZoom(); + } + } + + if (isExtension) { + targetElement.addEventListener("mousemove", checkForOutBox); + } + + + window.addEventListener('resize', (e) => { + resetZoom(); + + if (isExtension) { + targetElement.isExpanded = false; + targetElement.isZoomed = false; + } + }); + gradioApp().addEventListener("mousemove", handleMoveByKey); + + } - applyZoomAndPan(elementIDs.sketch); - applyZoomAndPan(elementIDs.inpaint); - applyZoomAndPan(elementIDs.inpaintSketch); + applyZoomAndPan(elementIDs.sketch, false); + applyZoomAndPan(elementIDs.inpaint, false); + applyZoomAndPan(elementIDs.inpaintSketch, false); // Make the function global so that other extensions can take advantage of this solution - window.applyZoomAndPan = applyZoomAndPan; + const applyZoomAndPanIntegration = async(id, elementIDs) => { + const mainEl = document.querySelector(id); + if (id.toLocaleLowerCase() === "none") { + for (const elementID of elementIDs) { + const el = await waitForElement(elementID); + if (!el) break; + applyZoomAndPan(elementID); + } + return; + } + + if (!mainEl) return; + mainEl.addEventListener("click", async() => { + for (const elementID of elementIDs) { + const el = await waitForElement(elementID); + if (!el) break; + applyZoomAndPan(elementID); + } + }, {once: true}); + }; + + window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image") + + window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension + + /* + The function `applyZoomAndPanIntegration` takes two arguments: + + 1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click. + If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event. + + 2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument. + If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event. + + Example usage: + applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); + In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet". + */ + + // More examples + // Add integration with ControlNet txt2img One TAB + // applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); + + // Add integration with ControlNet txt2img Tabs + // applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`)); + + // Add integration with Inpaint Anything + // applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]); }); diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 380176ce26c..2d8d2d1c014 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -9,6 +9,7 @@ "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), + "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"), "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"), "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), })) diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css index 6bcc9570c45..5d8054e6519 100644 --- a/extensions-builtin/canvas-zoom-and-pan/style.css +++ b/extensions-builtin/canvas-zoom-and-pan/style.css @@ -61,3 +61,6 @@ to {opacity: 1;} } +.styler { + overflow:inherit !important; +} \ No newline at end of file diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a05e10d865a..983f87ff033 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,5 +1,7 @@ +import math + import gradio as gr -from modules import scripts, shared, ui_components, ui_settings +from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste from modules.ui_components import FormColumn @@ -19,18 +21,38 @@ def show(self, is_img2img): def ui(self, is_img2img): self.comps = [] self.setting_names = [] + self.infotext_fields = [] + extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + + mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): - for setting_name in shared.opts.extra_options: - with FormColumn(): - comp = ui_settings.create_setting_component(setting_name) + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + + row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) + + for row in range(row_count): + with gr.Row(): + for col in range(shared.opts.extra_options_cols): + index = row * shared.opts.extra_options_cols + col + if index >= len(extra_options): + break + + setting_name = extra_options[index] - self.comps.append(comp) - self.setting_names.append(setting_name) + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + setting_infotext_name = mapping.get(setting_name) + if setting_infotext_name is not None: + self.infotext_fields.append((comp, setting_infotext_name)) def get_settings_values(): - return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + res = [ui_settings.get_value_for_setting(key) for key in self.setting_names] + return res[0] if len(res) == 1 else res interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) @@ -43,6 +65,10 @@ def before_process(self, p, *args): shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), - "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") + "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), + "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() })) + + diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js index 12cae4b7576..652f07ac7ec 100644 --- a/extensions-builtin/mobile/javascript/mobile.js +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -20,7 +20,13 @@ function reportWindowSize() { var button = gradioApp().getElementById(tab + '_generate_box'); var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column'); target.insertBefore(button, target.firstElementChild); + + gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile); } } window.addEventListener("resize", reportWindowSize); + +onUiLoaded(function() { + reportWindowSize(); +}); diff --git a/javascript/extensions.js b/javascript/extensions.js index 1f7254c5dfe..312131b76eb 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -33,7 +33,7 @@ function extensions_check() { var id = randomId(); - requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() { + requestProgress(id, gradioApp().getElementById('extensions_installed_html'), null, function() { }); diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 5582a6e5d3b..493f31af28a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,20 +1,38 @@ +function toggleCss(key, css, enable) { + var style = document.getElementById(key); + if (enable && !style) { + style = document.createElement('style'); + style.id = key; + style.type = 'text/css'; + document.head.appendChild(style); + } + if (style && !enable) { + document.head.removeChild(style); + } + if (style) { + style.innerHTML == ''; + style.appendChild(document.createTextNode(css)); + } +} + function setupExtraNetworksForTab(tabname) { gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); + var search = searchDiv.querySelector('textarea'); var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); + var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); - search.classList.add('search'); - sort.classList.add('sort'); - sortOrder.classList.add('sortorder'); sort.dataset.sortkey = 'sortDefault'; - tabs.appendChild(search); + tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); + tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -80,6 +98,15 @@ function setupExtraNetworksForTab(tabname) { }); extraNetworksApplyFilter[tabname] = applyFilter; + + var showDirsUpdate = function() { + var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; + toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); + localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); + }; + showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; + showDirs.addEventListener("change", showDirsUpdate); + showDirsUpdate(); } function applyExtraNetworkFilter(tabname) { @@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksSearchButton(tabs_id, event) { - var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea'); + var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); @@ -222,6 +249,15 @@ function popup(contents) { globalPopup.style.display = "flex"; } +var storedPopupIds = {}; +function popupId(id) { + if (!storedPopupIds[id]) { + storedPopupIds[id] = gradioApp().getElementById(id); + } + + popup(storedPopupIds[id]); +} + function extraNetworksShowMetadata(text) { var elem = document.createElement('pre'); elem.classList.add('popup-metadata'); @@ -305,7 +341,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) { newDiv.innerHTML = data.html; var newCard = newDiv.firstElementChild; - newCard.style = ''; + newCard.style.display = ''; card.parentElement.insertBefore(newCard, card); card.parentElement.removeChild(card); } diff --git a/javascript/hints.js b/javascript/hints.js index 4167cb28b7c..6de9372e8ea 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) { tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000); } }); + +onUiLoaded(function() { + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip && comp.props.elem_id) { + var elem = gradioApp().getElementById(comp.props.elem_id); + if (elem) { + elem.title = comp.props.webui_tooltip; + } + } + } +}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 677e95c1bc7..c21d396eefd 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -136,6 +136,11 @@ function setupImageForLightbox(e) { var event = isFirefox ? 'mousedown' : 'click'; e.addEventListener(event, function(evt) { + if (evt.button == 1) { + open(evt.target.src); + evt.preventDefault(); + return; + } if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js new file mode 100644 index 00000000000..f2839852ee7 --- /dev/null +++ b/javascript/inputAccordion.js @@ -0,0 +1,37 @@ +var observerAccordionOpen = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + var elem = mutationRecord.target; + var open = elem.classList.contains('open'); + + var accordion = elem.parentNode; + accordion.classList.toggle('input-accordion-open', open); + + var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); + checkbox.checked = open; + updateInput(checkbox); + + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + if (extra) { + extra.style.display = open ? "" : "none"; + } + }); +}); + +function inputAccordionChecked(id, checked) { + var label = gradioApp().querySelector('#' + id + " .label-wrap"); + if (label.classList.contains('open') != checked) { + label.click(); + } +} + +onUiLoaded(function() { + for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { + var labelWrap = accordion.querySelector('.label-wrap'); + observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); + + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + if (extra) { + labelWrap.insertBefore(extra, labelWrap.lastElementChild); + } + } +}); diff --git a/javascript/localStorage.js b/javascript/localStorage.js new file mode 100644 index 00000000000..dc1a36c3287 --- /dev/null +++ b/javascript/localStorage.js @@ -0,0 +1,26 @@ + +function localSet(k, v) { + try { + localStorage.setItem(k, v); + } catch (e) { + console.warn(`Failed to save ${k} to localStorage: ${e}`); + } +} + +function localGet(k, def) { + try { + return localStorage.getItem(k); + } catch (e) { + console.warn(`Failed to load ${k} from localStorage: ${e}`); + } + + return def; +} + +function localRemove(k) { + try { + return localStorage.removeItem(k); + } catch (e) { + console.warn(`Failed to remove ${k} from localStorage: ${e}`); + } +} diff --git a/javascript/localization.js b/javascript/localization.js index eb22b8a7e99..8f00c186860 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,11 +11,11 @@ var ignore_ids_for_localization = { train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', img2img_styles: 'OPTION', - setting_random_artist_categories: 'SPAN', - setting_face_restoration_model: 'SPAN', - setting_realesrgan_enabled_models: 'SPAN', - extras_upscaler_1: 'SPAN', - extras_upscaler_2: 'SPAN', + setting_random_artist_categories: 'OPTION', + setting_face_restoration_model: 'OPTION', + setting_realesrgan_enabled_models: 'OPTION', + extras_upscaler_1: 'OPTION', + extras_upscaler_2: 'OPTION', }; var re_num = /^[.\d]+$/; @@ -107,12 +107,41 @@ function processNode(node) { }); } +function localizeWholePage() { + processNode(gradioApp()); + + function elem(comp) { + var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id; + return gradioApp().getElementById(elem_id); + } + + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip) { + let e = elem(comp); + + let tl = e ? getTranslation(e.title) : undefined; + if (tl !== undefined) { + e.title = tl; + } + } + if (comp.props.placeholder) { + let e = elem(comp); + let textbox = e ? e.querySelector('[placeholder]') : null; + + let tl = textbox ? getTranslation(textbox.placeholder) : undefined; + if (tl !== undefined) { + textbox.placeholder = tl; + } + } + } +} + function dumpTranslations() { if (!hasLocalization()) { // If we don't have any localization, // we will not have traversed the app to find // original_lines, so do that now. - processNode(gradioApp()); + localizeWholePage(); } var dumped = {}; if (localization.rtl) { @@ -154,7 +183,7 @@ document.addEventListener("DOMContentLoaded", function() { }); }); - processNode(gradioApp()); + localizeWholePage(); if (localization.rtl) { // if the language is from right to left, (new MutationObserver((mutations, observer) => { // wait for the style to load diff --git a/javascript/notification.js b/javascript/notification.js index 76c5715dab4..6d79956125c 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -15,7 +15,7 @@ onAfterUiUpdate(function() { } } - const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img'); + const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"] div[id$="_results"] .thumbnail-item > img'); if (galleryPreviews == null) return; diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 29299787e30..777614954b2 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var dateStart = new Date(); var wasEverActive = false; var parentProgressbar = progressbarContainer.parentNode; - var parentGallery = gallery ? gallery.parentNode : null; var divProgress = document.createElement('div'); divProgress.className = 'progressDiv'; @@ -80,32 +79,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre divProgress.appendChild(divInner); parentProgressbar.insertBefore(divProgress, progressbarContainer); - if (parentGallery) { - var livePreview = document.createElement('div'); - livePreview.className = 'livePreview'; - parentGallery.insertBefore(livePreview, gallery); - } + var livePreview = null; var removeProgressBar = function() { + if (!divProgress) return; + setTitle(""); parentProgressbar.removeChild(divProgress); - if (parentGallery) parentGallery.removeChild(livePreview); + if (gallery && livePreview) gallery.removeChild(livePreview); atEnd(); + + divProgress = null; }; - var fun = function(id_task, id_live_preview) { - request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) { + var funProgress = function(id_task) { + request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) { if (res.completed) { removeProgressBar(); return; } - var rect = progressbarContainer.getBoundingClientRect(); - - if (rect.width) { - divProgress.style.width = rect.width + "px"; - } - let progressText = ""; divInner.style.width = ((res.progress || 0) * 100.0) + '%'; @@ -119,7 +112,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre progressText += " ETA: " + formatTime(res.eta); } - setTitle(progressText); if (res.textinfo && res.textinfo.indexOf("\n") == -1) { @@ -142,16 +134,33 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre return; } + if (onProgress) { + onProgress(res); + } - if (res.live_preview && gallery) { - rect = gallery.getBoundingClientRect(); - if (rect.width) { - livePreview.style.width = rect.width + "px"; - livePreview.style.height = rect.height + "px"; - } + setTimeout(() => { + funProgress(id_task, res.id_live_preview); + }, opts.live_preview_refresh_period || 500); + }, function() { + removeProgressBar(); + }); + }; + var funLivePreview = function(id_task, id_live_preview) { + request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) { + if (!divProgress) { + return; + } + + if (res.live_preview && gallery) { var img = new Image(); img.onload = function() { + if (!livePreview) { + livePreview = document.createElement('div'); + livePreview.className = 'livePreview'; + gallery.insertBefore(livePreview, gallery.firstElementChild); + } + livePreview.appendChild(img); if (livePreview.childElementCount > 2) { livePreview.removeChild(livePreview.firstElementChild); @@ -160,18 +169,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre img.src = res.live_preview; } - - if (onProgress) { - onProgress(res); - } - setTimeout(() => { - fun(id_task, res.id_live_preview); + funLivePreview(id_task, res.id_live_preview); }, opts.live_preview_refresh_period || 500); }, function() { removeProgressBar(); }); }; - fun(id_task, 0); + funProgress(id_task, 0); + + if (gallery) { + funLivePreview(id_task, 0); + } + } diff --git a/javascript/resizeHandle.js b/javascript/resizeHandle.js new file mode 100644 index 00000000000..8c5c5169210 --- /dev/null +++ b/javascript/resizeHandle.js @@ -0,0 +1,141 @@ +(function() { + const GRADIO_MIN_WIDTH = 320; + const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr'; + const PAD = 16; + const DEBOUNCE_TIME = 100; + + const R = { + tracking: false, + parent: null, + parentWidth: null, + leftCol: null, + leftColStartWidth: null, + screenX: null, + }; + + let resizeTimer; + let parents = []; + + function setLeftColGridTemplate(el, width) { + el.style.gridTemplateColumns = `${width}px 16px 1fr`; + } + + function displayResizeHandle(parent) { + if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) { + parent.style.display = 'flex'; + if (R.handle != null) { + R.handle.style.opacity = '0'; + } + return false; + } else { + parent.style.display = 'grid'; + if (R.handle != null) { + R.handle.style.opacity = '100'; + } + return true; + } + } + + function afterResize(parent) { + if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) { + const oldParentWidth = R.parentWidth; + const newParentWidth = parent.offsetWidth; + const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]); + + const ratio = newParentWidth / oldParentWidth; + + const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH); + setLeftColGridTemplate(parent, newWidthL); + + R.parentWidth = newParentWidth; + } + } + + function setup(parent) { + const leftCol = parent.firstElementChild; + const rightCol = parent.lastElementChild; + + parents.push(parent); + + parent.style.display = 'grid'; + parent.style.gap = '0'; + parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS; + + const resizeHandle = document.createElement('div'); + resizeHandle.classList.add('resize-handle'); + parent.insertBefore(resizeHandle, rightCol); + + resizeHandle.addEventListener('mousedown', (evt) => { + if (evt.button !== 0) return; + + evt.preventDefault(); + evt.stopPropagation(); + + document.body.classList.add('resizing'); + + R.tracking = true; + R.parent = parent; + R.parentWidth = parent.offsetWidth; + R.handle = resizeHandle; + R.leftCol = leftCol; + R.leftColStartWidth = leftCol.offsetWidth; + R.screenX = evt.screenX; + }); + + resizeHandle.addEventListener('dblclick', (evt) => { + evt.preventDefault(); + evt.stopPropagation(); + + parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS; + }); + + afterResize(parent); + } + + window.addEventListener('mousemove', (evt) => { + if (evt.button !== 0) return; + + if (R.tracking) { + evt.preventDefault(); + evt.stopPropagation(); + + const delta = R.screenX - evt.screenX; + const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH); + setLeftColGridTemplate(R.parent, leftColWidth); + } + }); + + window.addEventListener('mouseup', (evt) => { + if (evt.button !== 0) return; + + if (R.tracking) { + evt.preventDefault(); + evt.stopPropagation(); + + R.tracking = false; + + document.body.classList.remove('resizing'); + } + }); + + + window.addEventListener('resize', () => { + clearTimeout(resizeTimer); + + resizeTimer = setTimeout(function() { + for (const parent of parents) { + afterResize(parent); + } + }, DEBOUNCE_TIME); + }); + + setupResizeHandle = setup; +})(); + +onUiLoaded(function() { + for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) { + if (!elem.querySelector('.resize-handle')) { + setupResizeHandle(elem); + } + } +}); diff --git a/javascript/ui.js b/javascript/ui.js index d70a681bff7..bedcbf3e211 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -19,28 +19,11 @@ function all_gallery_buttons() { } function selected_gallery_button() { - var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected'); - var visibleCurrentButton = null; - allCurrentButtons.forEach(function(elem) { - if (elem.parentElement.offsetParent) { - visibleCurrentButton = elem; - } - }); - return visibleCurrentButton; + return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null; } function selected_gallery_index() { - var buttons = all_gallery_buttons(); - var button = selected_gallery_button(); - - var result = -1; - buttons.forEach(function(v, i) { - if (v == button) { - result = i; - } - }); - - return result; + return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected')); } function extract_image_from_gallery(gallery) { @@ -152,11 +135,11 @@ function submit() { showSubmitButtons('txt2img', false); var id = randomId(); - localStorage.setItem("txt2img_task_id", id); + localSet("txt2img_task_id", id); requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); - localStorage.removeItem("txt2img_task_id"); + localRemove("txt2img_task_id"); showRestoreProgressButton('txt2img', false); }); @@ -171,11 +154,11 @@ function submit_img2img() { showSubmitButtons('img2img', false); var id = randomId(); - localStorage.setItem("img2img_task_id", id); + localSet("img2img_task_id", id); requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); - localStorage.removeItem("img2img_task_id"); + localRemove("img2img_task_id"); showRestoreProgressButton('img2img', false); }); @@ -189,9 +172,7 @@ function submit_img2img() { function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); - var id = localStorage.getItem("txt2img_task_id"); - - id = localStorage.getItem("txt2img_task_id"); + var id = localGet("txt2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { @@ -205,7 +186,7 @@ function restoreProgressTxt2img() { function restoreProgressImg2img() { showRestoreProgressButton("img2img", false); - var id = localStorage.getItem("img2img_task_id"); + var id = localGet("img2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { @@ -218,8 +199,8 @@ function restoreProgressImg2img() { onUiLoaded(function() { - showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id")); - showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id")); + showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); + showRestoreProgressButton('img2img', localGet("img2img_task_id")); }); diff --git a/launch.py b/launch.py index 1dbc4c6e33e..f83820d2534 100644 --- a/launch.py +++ b/launch.py @@ -1,6 +1,5 @@ from modules import launch_utils - args = launch_utils.args python = launch_utils.python git = launch_utils.git @@ -26,8 +25,18 @@ def main(): - if not args.skip_prepare_environment: - prepare_environment() + if args.dump_sysinfo: + filename = launch_utils.dump_sysinfo() + + print(f"Sysinfo saved as {filename}. Exiting...") + + exit(0) + + launch_utils.startup_timer.record("initial startup") + + with launch_utils.startup_timer.subcategory("prepare environment"): + if not args.skip_prepare_environment: + prepare_environment() if args.test_server: configure_for_tests() diff --git a/modules/api/api.py b/modules/api/api.py index 606db179d4c..e6edffe7144 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,6 +4,8 @@ import time import datetime import uvicorn +import ipaddress +import requests import gradio as gr from threading import Lock from io import BytesIO @@ -15,7 +17,7 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -23,8 +25,7 @@ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases -from modules.sd_vae import vae_dict +from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -56,7 +57,41 @@ def setUpscalers(req: dict): return reqDict +def verify_url(url): + """Returns True if the url refers to a global resource.""" + + import socket + from urllib.parse import urlparse + try: + parsed_url = urlparse(url) + domain_name = parsed_url.netloc + host = socket.gethostbyname_ex(domain_name) + for ip in host[2]: + ip_addr = ipaddress.ip_address(ip) + if not ip_addr.is_global: + return False + except Exception: + return False + + return True + + def decode_base64_to_image(encoding): + if encoding.startswith("http://") or encoding.startswith("https://"): + if not opts.api_enable_requests: + raise HTTPException(status_code=500, detail="Requests not allowed") + + if opts.api_forbid_local_requests and not verify_url(encoding): + raise HTTPException(status_code=500, detail="Request to local resource not allowed") + + headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} + response = requests.get(encoding, timeout=30, headers=headers) + try: + image = Image.open(BytesIO(response.content)) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid image url") from e + if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] try: @@ -197,6 +232,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) @@ -329,6 +365,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): with self.queue_lock: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: + p.is_api = True p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples @@ -343,6 +380,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): processed = process_images(p) finally: shared.state.end() + shared.total_tqdm.clear() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -388,6 +426,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] + p.is_api = True p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples @@ -402,6 +441,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): processed = process_images(p) finally: shared.state.end() + shared.total_tqdm.clear() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -530,7 +570,7 @@ def set_config(self, req: Dict[str, Any]): raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): - shared.opts.set(k, v) + shared.opts.set(k, v, is_api=True) shared.opts.save(shared.config_filename) return @@ -562,10 +602,12 @@ def get_latent_upscale_modes(self): ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + import modules.sd_models as sd_models + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] def get_sd_vaes(self): - return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + import modules.sd_vae as sd_vae + return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -608,6 +650,10 @@ def refresh_checkpoints(self): with self.queue_lock: shared.refresh_checkpoints() + def refresh_vae(self): + with self.queue_lock: + shared_items.refresh_vae_list() + def create_embedding(self, args: dict): try: shared.state.begin(job="create_embedding") diff --git a/modules/api/models.py b/modules/api/models.py index 800c9b93f14..6a574771c33 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -50,10 +50,12 @@ def __init__( additional_fields = None, ): def field_type_generator(k, v): - # field_type = str if not overrides.get(k) else overrides[k]["type"] - # print(k, v.annotation, v.default) field_type = v.annotation + if field_type == 'Image': + # images are sent as base64 strings via API + field_type = 'str' + return Optional[field_type] def merge_class_params(class_): @@ -63,7 +65,6 @@ def merge_class_params(class_): parameters = {**parameters, **inspect.signature(classes.__init__).parameters} return parameters - self._model_name = model_name self._class_data = merge_class_params(class_instance) @@ -72,7 +73,7 @@ def merge_class_params(class_): field=underscore(k), field_alias=k, field_type=field_type_generator(k, v), - field_value=v.default + field_value=None if isinstance(v.default, property) else v.default ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] diff --git a/modules/cache.py b/modules/cache.py index 71fe6302134..ff26a2132d9 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,11 +1,12 @@ import json +import os import os.path import threading import time from modules.paths import data_path, script_path -cache_filename = os.path.join(data_path, "cache.json") +cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json")) cache_data = None cache_lock = threading.Lock() @@ -29,9 +30,12 @@ def thread_func(): time.sleep(1) with cache_lock: - with open(cache_filename, "w", encoding="utf8") as file: + cache_filename_tmp = cache_filename + "-" + with open(cache_filename_tmp, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) + os.replace(cache_filename_tmp, cache_filename) + dump_cache_after = None dump_cache_thread = None diff --git a/modules/call_queue.py b/modules/call_queue.py index 61aa240fb32..ddf0d57383c 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,11 +1,10 @@ from functools import wraps import html -import threading import time -from modules import shared, progress, errors +from modules import shared, progress, errors, devices, fifo_lock -queue_lock = threading.Lock() +queue_lock = fifo_lock.FIFOLock() def wrap_queued_call(func): @@ -75,6 +74,8 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): error_message = f'{type(e).__name__}: {e}' res = extra_outputs_array + [f"
{html.escape(error_message)}
"] + devices.torch_gc() + shared.state.skipped = False shared.state.interrupted = False shared.state.job_count = 0 diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e401f6413a4..aab62286e24 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -13,8 +13,11 @@ parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") +parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") +parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit") +parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None) parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) @@ -33,9 +36,10 @@ parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") +parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") -parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") +parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") @@ -66,6 +70,7 @@ parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -78,7 +83,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") -parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it") +parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path]) parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) @@ -110,3 +115,5 @@ parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') +parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) diff --git a/modules/config_states.py b/modules/config_states.py index 6f1ab53fc59..b766aef11d8 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -8,14 +8,12 @@ import tqdm from datetime import datetime -from collections import OrderedDict import git from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir - -all_config_states = OrderedDict() +all_config_states = {} def list_config_states(): @@ -28,10 +26,14 @@ def list_config_states(): for filename in os.listdir(config_states_dir): if filename.endswith(".json"): path = os.path.join(config_states_dir, filename) - with open(path, "r", encoding="utf-8") as f: - j = json.load(f) - j["filepath"] = path - config_states.append(j) + try: + with open(path, "r", encoding="utf-8") as f: + j = json.load(f) + assert "created_at" in j, '"created_at" does not exist' + j["filepath"] = path + config_states.append(j) + except Exception as e: + print(f'[ERROR]: Config states {path}, {e}') config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True) diff --git a/modules/devices.py b/modules/devices.py index 57e51da30e2..c01f06024b4 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ from functools import lru_cache import torch -from modules import errors +from modules import errors, shared if sys.platform == "darwin": from modules import mac_specific @@ -17,8 +17,6 @@ def has_mps() -> bool: def get_cuda_device_string(): - from modules import shared - if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -40,8 +38,6 @@ def get_optimal_device(): def get_device_for(task): - from modules import shared - if task in shared.cmd_opts.use_cpu: return cpu @@ -71,14 +67,17 @@ def enable_tf32(): torch.backends.cudnn.allow_tf32 = True - errors.run(enable_tf32, "Enabling TF32") -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 -dtype_vae = torch.float16 -dtype_unet = torch.float16 +cpu: torch.device = torch.device("cpu") +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None +device_codeformer: torch.device = None +dtype: torch.dtype = torch.float16 +dtype_vae: torch.dtype = torch.float16 +dtype_unet: torch.dtype = torch.float16 unet_needs_upcast = False @@ -90,26 +89,10 @@ def cond_cast_float(input): return input.float() if unet_needs_upcast else input -def randn(seed, shape): - from modules.shared import opts - - torch.manual_seed(seed) - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - return torch.randn(shape, device=device) - - -def randn_without_seed(shape): - from modules.shared import opts - - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - return torch.randn(shape, device=device) +nv_rng = None def autocast(disable=False): - from modules import shared - if disable: return contextlib.nullcontext() @@ -128,8 +111,6 @@ class NansException(Exception): def test_for_nans(x, where): - from modules import shared - if shared.cmd_opts.disable_nan_check: return @@ -169,3 +150,4 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + diff --git a/modules/errors.py b/modules/errors.py index 5271a9fe1de..8c339464d46 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -14,7 +14,8 @@ def record_exception(): if exception_records and exception_records[-1] == e: return - exception_records.append((e, tb)) + from modules import sysinfo + exception_records.append(sysinfo.format_exception(e, tb)) if len(exception_records) > 5: exception_records.pop(0) @@ -83,3 +84,53 @@ def run(code, task): code() except Exception as e: display(task, e) + + +def check_versions(): + from packaging import version + from modules import shared + + import torch + import gradio + + expected_torch_version = "2.0.0" + expected_xformers_version = "0.0.20" + expected_gradio_version = "3.41.2" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if gradio.__version__ != expected_gradio_version: + print_error_explanation(f""" +You are running gradio {gradio.__version__}. +The program is designed to work with gradio {expected_gradio_version}. +Using a different version of gradio is extremely likely to break the program. + +Reasons why you have the mismatched gradio version can be: + - you use --skip-install flag. + - you use webui.py to start the program instead of launch.py. + - an extension installs the incompatible gradio version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + diff --git a/modules/extensions.py b/modules/extensions.py index 3ad5ed53160..bf9a1878f5d 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os import threading -from modules import shared, errors, cache +from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -11,9 +11,9 @@ def active(): - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all": return [] - elif shared.opts.disable_all_extensions == "extra": + elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -90,8 +90,6 @@ def do_read_info_from_repo(self): self.have_info_from_repo = True def list_files(self, subdir, extension): - from modules import scripts - dirpath = os.path.join(self.path, subdir) if not os.path.isdir(dirpath): return [] @@ -141,8 +139,12 @@ def list_extensions(): if not os.path.isdir(extensions_dir): return - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions: + print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") + elif shared.opts.disable_all_extensions == "all": print("*** \"Disable all extensions\" option was set, will not load any extensions ***") + elif shared.cmd_opts.disable_extra_extensions: + print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***") elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91b1c..b9533677887 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,4 +1,7 @@ +import json +import os import re +import logging from collections import defaultdict from modules import errors @@ -84,27 +87,55 @@ def deactivate(self, p): raise NotImplementedError -def activate(p, extra_network_data): - """call activate for extra networks in extra_network_data in specified order, then call - activate for all remaining registered networks with an empty argument list""" +def lookup_extra_networks(extra_network_data): + """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks. - activated = [] + Example input: + { + 'lora': [], + 'lyco': [], + 'hypernet': [] + } + + Example output: + + { + : [, ], + : [] + } + """ + + res = {} - for extra_network_name, extra_network_args in extra_network_data.items(): + for extra_network_name, extra_network_args in list(extra_network_data.items()): extra_network = extra_network_registry.get(extra_network_name, None) + alias = extra_network_aliases.get(extra_network_name, None) - if extra_network is None: - extra_network = extra_network_aliases.get(extra_network_name, None) + if alias is not None and extra_network is None: + extra_network = alias if extra_network is None: - print(f"Skipping unknown extra network: {extra_network_name}") + logging.info(f"Skipping unknown extra network: {extra_network_name}") continue + res.setdefault(extra_network, []).extend(extra_network_args) + + return res + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + activated = [] + + for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items(): + try: extra_network.activate(p, extra_network_args) activated.append(extra_network) except Exception as e: - errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): if extra_network in activated: @@ -123,19 +154,16 @@ def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks""" - for extra_network_name in extra_network_data: - extra_network = extra_network_registry.get(extra_network_name, None) - if extra_network is None: - continue + data = lookup_extra_networks(extra_network_data) + for extra_network in data: try: extra_network.deactivate(p) except Exception as e: - errors.display(e, f"deactivating extra network {extra_network_name}") + errors.display(e, f"deactivating extra network {extra_network.name}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: + if extra_network in data: continue try: @@ -177,3 +205,20 @@ def parse_prompts(prompts): return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/extras.py b/modules/extras.py index e9c0263ec7d..2a310ae3f25 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import torch import tqdm -from modules import shared, images, sd_models, sd_vae, sd_models_config +from modules import shared, images, sd_models, sd_vae, sd_models_config, errors from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -72,7 +72,20 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): +def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name): + metadata = {} + + for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]: + checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None) + if checkpoint_info is None: + continue + + metadata.update(checkpoint_info.metadata) + + return json.dumps(metadata, indent=4, ensure_ascii=False) + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json): shared.state.begin(job="model-merge") def fail(message): @@ -241,11 +254,25 @@ def filename_nothing(): shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") - metadata = None + metadata = {} + + if save_metadata and copy_metadata_fields: + if primary_model_info: + metadata.update(primary_model_info.metadata) + if secondary_model_info: + metadata.update(secondary_model_info.metadata) + if tertiary_model_info: + metadata.update(tertiary_model_info.metadata) if save_metadata: - metadata = {"format": "pt"} + try: + metadata.update(json.loads(metadata_json)) + except Exception as e: + errors.display(e, "readin metadata from json") + + metadata["format"] = "pt" + if save_metadata and add_merge_recipe: merge_recipe = { "type": "webui", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, @@ -261,7 +288,6 @@ def filename_nothing(): "is_inpainting": result_is_inpainting_model, "is_instruct_pix2pix": result_is_instruct_pix2pix_model } - metadata["sd_merge_recipe"] = json.dumps(merge_recipe) sd_merge_models = {} @@ -281,11 +307,12 @@ def add_model_metadata(checkpoint_info): if tertiary_model_info: add_model_metadata(tertiary_model_info) + metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_models"] = json.dumps(sd_merge_models) _, extension = os.path.splitext(output_modelname) if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata) + safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None) else: torch.save(theta_0, output_modelname) diff --git a/modules/fifo_lock.py b/modules/fifo_lock.py new file mode 100644 index 00000000000..c35b3ae25a3 --- /dev/null +++ b/modules/fifo_lock.py @@ -0,0 +1,37 @@ +import threading +import collections + + +# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a +class FIFOLock(object): + def __init__(self): + self._lock = threading.Lock() + self._inner_lock = threading.Lock() + self._pending_threads = collections.deque() + + def acquire(self, blocking=True): + with self._inner_lock: + lock_acquired = self._lock.acquire(False) + if lock_acquired: + return True + elif not blocking: + return False + + release_event = threading.Event() + self._pending_threads.append(release_event) + + release_event.wait() + return self._lock.acquire() + + def release(self): + with self._inner_lock: + if self._pending_threads: + release_event = self._pending_threads.popleft() + release_event.set() + + self._lock.release() + + __enter__ = acquire + + def __exit__(self, t, v, tb): + self.release() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a3448be9db8..d39f2ebac36 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,10 +6,10 @@ import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks +from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") @@ -32,6 +32,7 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima def reset(): paste_fields.clear() + registered_param_bindings.clear() def quote(text): @@ -198,7 +199,6 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - from modules import processing firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width @@ -280,6 +280,9 @@ def parse_generation_parameters(x: str): if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" @@ -304,32 +307,28 @@ def parse_generation_parameters(x: str): if "Schedule rho" not in res: res["Schedule rho"] = 0 + if "VAE Encoder" not in res: + res["VAE Encoder"] = "Full" + + if "VAE Decoder" not in res: + res["VAE Decoder"] = "Full" + return res infotext_to_setting_name_mapping = [ - ('Clip skip', 'CLIP_stop_at_last_layers', ), + +] +"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead. +Example content: + +infotext_to_setting_name_mapping = [ ('Conditional mask weight', 'inpainting_mask_weight'), ('Model hash', 'sd_model_checkpoint'), ('ENSD', 'eta_noise_seed_delta'), ('Schedule type', 'k_sched_type'), - ('Schedule max sigma', 'sigma_max'), - ('Schedule min sigma', 'sigma_min'), - ('Schedule rho', 'rho'), - ('Noise multiplier', 'initial_noise_multiplier'), - ('Eta', 'eta_ancestral'), - ('Eta DDIM', 'eta_ddim'), - ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), - ('UniPC variant', 'uni_pc_variant'), - ('UniPC skip type', 'uni_pc_skip_type'), - ('UniPC order', 'uni_pc_order'), - ('UniPC lower order final', 'uni_pc_lower_order_final'), - ('Token merging ratio', 'token_merging_ratio'), - ('Token merging ratio hr', 'token_merging_ratio_hr'), - ('RNG', 'randn_source'), - ('NGMS', 's_min_uncond'), - ('Pad conds', 'pad_cond_uncond'), ] +""" def create_override_settings_dict(text_pairs): @@ -350,7 +349,8 @@ def create_override_settings_dict(text_pairs): params[k] = v.strip() - for param_name, setting_name in infotext_to_setting_name_mapping: + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: value = params.get(param_name, None) if value is None: @@ -399,10 +399,16 @@ def paste_func(prompt): return res if override_settings_component is not None: + already_handled_fields = {key: 1 for _, key in paste_fields} + def paste_settings(params): vals = {} - for param_name, setting_name in infotext_to_setting_name_mapping: + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in already_handled_fields: + continue + v = params.get(param_name, None) if v is None: continue diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py new file mode 100644 index 00000000000..e6b6835adcc --- /dev/null +++ b/modules/gradio_extensons.py @@ -0,0 +1,73 @@ +import gradio as gr + +from modules import scripts, ui_tempdir, patches + + +def add_classes_to_gradio_component(comp): + """ + this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others + """ + + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] + + if getattr(comp, 'multiselect', False): + comp.elem_classes.append('multiselect') + + +def IOComponent_init(self, *args, **kwargs): + self.webui_tooltip = kwargs.pop('tooltip', None) + + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Block_get_config(self): + config = original_Block_get_config(self) + + webui_tooltip = getattr(self, 'webui_tooltip', None) + if webui_tooltip: + config["webui_tooltip"] = webui_tooltip + + config.pop('example_inputs', None) + + return config + + +def BlockContext_init(self, *args, **kwargs): + res = original_BlockContext_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + return res + + +def Blocks_get_config_file(self, *args, **kwargs): + config = original_Blocks_get_config_file(self, *args, **kwargs) + + for comp_config in config["components"]: + if "example_inputs" in comp_config: + comp_config["example_inputs"] = {"serialized": []} + + return config + + +original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init) +original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) +original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) +original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) + + +ui_tempdir.install_ui_tempdir_override() diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c4821d21a7e..70f1cbd26b6 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images allows training previews to have infotext. Importing it at the top causes a circular import problem. - from modules import images + from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 diff --git a/modules/images.py b/modules/images.py index 38aa933d6e5..eb644733898 100644 --- a/modules/images.py +++ b/modules/images.py @@ -21,8 +21,6 @@ from modules.paths_internal import roboto_ttf_file from modules.shared import opts -import modules.sd_vae as sd_vae - LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -318,7 +316,7 @@ def resize(im, w, h): return res -invalid_filename_chars = '<>:"/\\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') @@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True): class FilenameGenerator: - def get_vae_filename(self): #get the name of the VAE file. - if sd_vae.loaded_vae_file is None: - return "NoneType" - file_name = os.path.basename(sd_vae.loaded_vae_file) - split_file_name = file_name.split('.') - if len(split_file_name) > 1 and split_file_name[0] == '': - return split_file_name[1] # if the first character of the filename is "." then [1] is obtained. - else: - return split_file_name[0] - replacements = { 'seed': lambda self: self.seed if self.seed is not None else '', 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0], @@ -367,7 +355,9 @@ def get_vae_filename(self): #get the name of the VAE file. 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime