diff --git a/py/image.py b/py/image.py index 5ba74fc..5625917 100644 --- a/py/image.py +++ b/py/image.py @@ -822,9 +822,10 @@ def remove(self, rem_mode, images, image_output, save_prefix, torchscript_jit=Fa if rem_mode == "RMBG-2.0": repo_id = REMBG_MODELS[rem_mode]['model_url'] model_path = os.path.join(REMBG_DIR, 'RMBG-2.0') - from huggingface_hub import snapshot_download + if not os.path.exists(model_path): + from huggingface_hub import snapshot_download + snapshot_download(repo_id=repo_id, local_dir=model_path, ignore_patterns=["*.md", "*.txt"]) from transformers import AutoModelForImageSegmentation - snapshot_download(repo_id=repo_id, local_dir=model_path, ignore_patterns=["*.md", "*.txt"]) model = AutoModelForImageSegmentation.from_pretrained(model_path, trust_remote_code=True) torch.set_float32_matmul_precision('high') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")