Skip to content

Commit

Permalink
Added option to use cudnn as backend for pytorch, this should help fi…
Browse files Browse the repository at this point in the history
…xing an issue with nvidia 16xx cards getting a black or green square instead of a proper image.
  • Loading branch information
ZeroCool940711 committed Dec 3, 2022
1 parent 175e5a1 commit 9283bb8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
1 change: 1 addition & 0 deletions configs/webui/webui_streamlit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ general:
no_half: False
use_float16: False
precision: "autocast"
use_cudnn: False
optimized: False
optimized_turbo: False
optimized_config: "optimizedSD/v1-inference.yaml"
Expand Down
8 changes: 4 additions & 4 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ channels:
dependencies:
- conda-forge::nodejs=18.11.0
- yarn=1.22.19
- cudatoolkit=11.3
- cudatoolkit=11.7
- git
- numpy=1.22.3
- numpy=1.23.3
- pip=20.3
- python=3.8.5
- pytorch=1.11.0
- pytorch=1.13.0
- scikit-image=0.19.2
- torchvision=0.12.0
- torchvision=0.14.0
- pip:
- -r requirements.txt
10 changes: 10 additions & 0 deletions scripts/sd_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@
# remove all the annoying python warnings.
shutup.please()

# the following lines should help fixing an issue with nvidia 16xx cards.
if "defaults" in st.session_state:
if st.session_state["defaults"].general.use_cudnn:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
Expand Down Expand Up @@ -1613,6 +1619,10 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
#
@retry(tries=5)
def generation_callback(img, i=0):

# try to do garbage collection before decoding the image
torch_gc()

if "update_preview_frequency" not in st.session_state:
raise StopException

Expand Down

0 comments on commit 9283bb8

Please sign in to comment.