Skip to content

Commit 7a6e992

Browse files
authored
Collect pretrained weight binary files in one place (#3656)
* Update cache dir * Update pytorchcv cache root
1 parent d7538cb commit 7a6e992

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/otx/algo/classification/backbones/efficientnet.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import math
77
import os
8+
from pathlib import Path
89

910
import torch
1011
import torch.nn.functional as F
@@ -661,5 +662,6 @@ def init_weights(self, pretrained: bool | str | None = None):
661662
load_checkpoint_to_model(self, checkpoint)
662663
print(f"init weight - {pretrained}")
663664
elif pretrained is not None:
664-
download_model(net=self, model_name=self.model_name)
665+
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
666+
download_model(net=self, model_name=self.model_name, local_model_store_dir_path=str(cache_dir))
665667
print(f"init weight - {pretrained_urls[self.model_name]}")

src/otx/algo/detection/backbones/pytorchcv_backbones.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _build_pytorchcv_model(
128128
**kwargs,
129129
) -> nn.Module:
130130
"""Build pytorchcv model."""
131-
models_cache_root = kwargs.get("root", Path.home() / ".torch" / "models")
131+
models_cache_root = kwargs.get("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints")
132132
is_pretrained = kwargs.get("pretrained", False)
133133
print(
134134
f"Init model {type}, pretrained={is_pretrained}, models cache {models_cache_root}",

0 commit comments

Comments
 (0)