Merge pull request #265 from Layer-norm/main
Fix custom model path, add custom temp path. Fix(#264)
Fannovel16 authored Mar 9, 2024
2 parents 9365529 + e655821 commit 738aee6
Showing 3 changed files with 49 additions and 14 deletions.
9 changes: 7 additions & 2 deletions config.example.yaml
@@ -1,8 +1,13 @@
 # this is an example config.yaml file; you can rename it to config.yaml if you want to use it
 # ###############################################################################################
-# you can also use absolute paths like: "/root/ComfyUI/custom_nodes/comfyui_controlnet_aux/ckpts" or "D:\\comfyui\\custom_nodes\\comfyui_controlnet_aux\\ckpts"
+# This path is the base folder for custom preprocessor models. Default is "./ckpts".
+# you can also use absolute paths like: "/root/ComfyUI/custom_nodes/comfyui_controlnet_aux/ckpts" or "D:\\ComfyUI\\custom_nodes\\comfyui_controlnet_aux\\ckpts"
 annotator_ckpts_path: "./ckpts"
+# ###############################################################################################
+# This path is for downloading temporary files.
+# You SHOULD use an absolute path here, like "D:\\temp"; DO NOT use relative paths. None for default.
+custom_temp_path: None
 # ###############################################################################################
 # if you have already downloaded ckpts via huggingface hub into the default cache path (e.g. ~/.cache/huggingface/hub), you can set this to True to use symlinks and save space
 USE_SYMLINKS: False
 # ###############################################################################################
@@ -12,4 +17,4 @@ USE_SYMLINKS: False
 # an empty list, or keeping only ["CPUExecutionProvider"], means cv2.dnn.readNetFromONNX is used to load onnx models
 # if your onnx models can only run on the CPU or have other issues, we recommend using a pt model instead.
 # default value is ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
-EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
+EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
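A note on the new `custom_temp_path: None` default: PyYAML (the loader used in utils.py below) only maps `null`, `~`, or an empty value to Python's `None`; a bare `None` is parsed as the string `"None"`. A quick check of that behavior:

```python
# PyYAML null handling: `None` is not a YAML null token, so
# config["custom_temp_path"] comes back as the string 'None' unless the
# value is written as `null`, `~`, or left empty.
import yaml

for text in ("custom_temp_path: None",
             "custom_temp_path: null",
             "custom_temp_path: ~",
             "custom_temp_path:"):
    value = yaml.load(text, Loader=yaml.FullLoader)["custom_temp_path"]
    print(f"{text!r:28} -> {value!r}")

# 'custom_temp_path: None'     -> 'None'
# 'custom_temp_path: null'     -> None
# 'custom_temp_path: ~'        -> None
# 'custom_temp_path:'          -> None
```

Since the loader in utils.py compares the value against Python's `None`, a literal `None` carried over from this example file would be treated as a directory name rather than as "use the default"; writing `null` or leaving the value empty sidesteps that.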
39 changes: 28 additions & 11 deletions src/controlnet_aux/util.py
@@ -38,6 +38,15 @@
warnings.warn("USE_SYMLINKS not set successfully. Using default value: False to download models.")
pass

try:
temp_dir = os.environ['AUX_TEMP_DIR']
if len(temp_dir) >= 60:
warnings.warn(f"custom temp dir is too long. Using default")
temp_dir = tempfile.gettempdir()
except:
warnings.warn(f"custom temp dir not set successfully")
pass

here = Path(__file__).parent.resolve()

def HWC3(x):
@@ -240,12 +249,12 @@ def check_hash_from_torch_hub(file_path, filename):
     curr_hash = sha256sum(file_path)
     return curr_hash[:len(ref_hash)] == ref_hash

-def custom_torch_download(filename, cache_dir=annotator_ckpts_path):
+def custom_torch_download(filename, ckpts_dir=annotator_ckpts_path):
     local_dir = os.path.join(get_dir(), 'checkpoints')
     model_path = os.path.join(local_dir, filename)

     if not os.path.exists(model_path):
-        local_dir = os.path.join(cache_dir, "torch")
+        local_dir = os.path.join(ckpts_dir, "torch")
         if not os.path.exists(local_dir):
             os.mkdir(local_dir)
@@ -260,13 +269,19 @@ def custom_torch_download(filename, cache_dir=annotator_ckpts_path):
         download_url_to_file(url=model_url, dst=model_path)
         assert check_hash_from_torch_hub(model_path, filename), f"Hash check failed as file {filename} is corrupted"
         print("Hash check passed")

+    print(f"model_path is {model_path}")
     return model_path

-def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"):
-    if use_symlinks:
-        cache_dir = annotator_ckpts_path
-    local_dir = os.path.join(cache_dir, pretrained_model_or_path)
+def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, ckpts_dir=annotator_ckpts_path, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"):
+
+    print(f"cache folder is {cache_dir}; you can set it via custom_temp_path in config.yaml")
+
+    local_dir = os.path.join(ckpts_dir, pretrained_model_or_path)
     model_path = os.path.join(local_dir, *subfolder.split('/'), filename)

+    if len(str(model_path)) >= 255:
+        warnings.warn(f"Path {model_path} is too long; please change annotator_ckpts_path in config.yaml")
+
     if not os.path.exists(model_path):
         print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
@@ -283,8 +298,8 @@ def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"):
                 if not os.path.exists(cache_dir_d):
                     os.makedirs(cache_dir_d)
                 open(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), "w")
-                os.link(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(cache_dir, f"linktest_{filename}.txt"))
-                os.remove(os.path.join(cache_dir, f"linktest_{filename}.txt"))
+                os.link(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
+                os.remove(os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
                 os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt"))
                 print("Using symlinks to download models. \n",
                       "Make sure you have enough space on your cache folder. \n",
@@ -294,9 +309,9 @@ def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"):
             except:
                 print("Symlink creation may not be possible. Disabling symlinks.")
                 use_symlinks = False
-                cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache")
+                cache_dir_d = os.path.join(cache_dir, "aux", pretrained_model_or_path)
         else:
-            cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache")
+            cache_dir_d = os.path.join(cache_dir, "aux", pretrained_model_or_path)

         model_path = hf_hub_download(repo_id=pretrained_model_or_path,
                                      cache_dir=cache_dir_d,
@@ -314,5 +329,7 @@ def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"):
             shutil.rmtree(cache_dir_d)
         except Exception as e:
             print(e)

+    print(f"model_path is {model_path}")
+
     return model_path
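For orientation, a minimal usage sketch of the updated helper; the repo id, filename, and printed path below are illustrative, and the import path assumes the packaged `src` layout. After this commit, the downloaded checkpoint lands permanently under `ckpts_dir`, while `hf_hub_download`'s working cache lives under `<cache_dir>/aux/<repo_id>` and is cleaned up afterwards.

```python
# Hypothetical call site; repo id, filename, and printed path are illustrative.
from controlnet_aux.util import custom_hf_download

model_path = custom_hf_download(
    "lllyasviel/Annotators",   # HF repo id (example)
    "body_pose_model.pth",     # file inside the repo (example)
    # cache_dir defaults to temp_dir (resolved from AUX_TEMP_DIR);
    # ckpts_dir defaults to annotator_ckpts_path from config.yaml.
)
print(model_path)  # e.g. ckpts/lllyasviel/Annotators/body_pose_model.pth
```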
15 changes: 14 additions & 1 deletion utils.py
@@ -9,6 +9,7 @@
 import subprocess
 import threading
 import comfy
+import tempfile

 here = Path(__file__).parent.resolve()
@@ -18,13 +19,23 @@
     config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)

     annotator_ckpts_path = str(Path(here, config["annotator_ckpts_path"]))
+    TEMP_DIR = config["custom_temp_path"]
     USE_SYMLINKS = config["USE_SYMLINKS"]
     ORT_PROVIDERS = config["EP_list"]

     if USE_SYMLINKS is None or type(USE_SYMLINKS) != bool:
         log.error("USE_SYMLINKS must be a boolean. Using False by default.")
         USE_SYMLINKS = False

+    if TEMP_DIR is None:
+        TEMP_DIR = tempfile.gettempdir()
+    elif not os.path.isdir(TEMP_DIR):
+        try:
+            os.makedirs(TEMP_DIR)
+        except OSError:
+            log.error("Failed to create custom temp directory. Using the system default instead.")
+            TEMP_DIR = tempfile.gettempdir()
+
     if not os.path.isdir(annotator_ckpts_path):
         try:
             os.makedirs(annotator_ckpts_path)
@@ -33,11 +44,13 @@
             annotator_ckpts_path = str(Path(here, "./ckpts"))
 else:
     annotator_ckpts_path = str(Path(here, "./ckpts"))
+    TEMP_DIR = tempfile.gettempdir()
     USE_SYMLINKS = False
     ORT_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider", "CoreMLExecutionProvider"]

-os.environ['AUX_USE_SYMLINKS'] = str(USE_SYMLINKS)
 os.environ['AUX_ANNOTATOR_CKPTS_PATH'] = annotator_ckpts_path
+os.environ['AUX_TEMP_DIR'] = str(TEMP_DIR)
+os.environ['AUX_USE_SYMLINKS'] = str(USE_SYMLINKS)
 os.environ['AUX_ORT_PROVIDERS'] = str(",".join(ORT_PROVIDERS))

 log.info(f"Using ckpts path: {annotator_ckpts_path}")
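Taken together, utils.py (run when ComfyUI imports the node pack) publishes the resolved paths through environment variables, and src/controlnet_aux/util.py reads them back at its own import time. A minimal sketch of that handshake, with an illustrative temp path, assuming utils.py runs first:

```python
# Sketch of the AUX_TEMP_DIR handshake introduced by this commit.
import os
import tempfile

# Producer side (utils.py): publish the resolved temp dir.
TEMP_DIR = "/data/aux_tmp"              # illustrative custom_temp_path value
os.environ['AUX_TEMP_DIR'] = str(TEMP_DIR)

# Consumer side (src/controlnet_aux/util.py): read it back, with fallbacks.
try:
    temp_dir = os.environ['AUX_TEMP_DIR']
    if len(temp_dir) >= 60:             # guard against over-long paths
        temp_dir = tempfile.gettempdir()
except KeyError:                        # package used outside ComfyUI
    temp_dir = tempfile.gettempdir()

print(temp_dir)                         # -> /data/aux_tmp
```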