Skip to content

Commit

Permalink
add cache_dir arg to allow files to be stored anywhere (#159)
Browse files Browse the repository at this point in the history
* add cache_dir to a few models
  • Loading branch information
ieee8023 authored Sep 11, 2024
1 parent 543946b commit 5a8984c
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 53 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ jobs:
strategy:
max-parallel: 2
matrix:
python-version: ['3.9']
torch-version: [2.1.1]
python-version: ['3.11']
torch-version: [2.4.1]
os: [ubuntu-latest, macos-latest, windows-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest]

# Steps represent a sequence of tasks that will be executed as part of the job
Expand Down
35 changes: 12 additions & 23 deletions torchxrayvision/autoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
import requests
from . import utils


model_urls = {}
Expand Down Expand Up @@ -218,7 +219,7 @@ def ResNetAE101(**kwargs):
return _ResNetAE(Bottleneck, DeconvBottleneck, [3, 4, 23, 2], 1, **kwargs)


def ResNetAE(weights=None):
def ResNetAE(weights=None, cache_dir=None):
"""A ResNet based autoencoder.
Possible weights for this class include:
Expand All @@ -231,6 +232,11 @@ def ResNetAE(weights=None):
z = ae.encode(image)
image2 = ae.decode(z)
params:
weights (str): Weights to use. See above for options.
cache_dir (str): Override directory used to store cached weights (default: ~/.torchxrayvision/)
"""

if weights == None:
Expand All @@ -245,14 +251,17 @@ def ResNetAE(weights=None):
# load pretrained models
url = model_urls[weights]["weights_url"]
weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
if cache_dir is None:
weights_storage_folder = utils.get_cache_dir()
else:
weights_storage_folder = cache_dir
weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))

if not os.path.isfile(weights_filename_local):
print("Downloading weights...")
print("If this fails you can run `wget {} -O {}`".format(url, weights_filename_local))
pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
download(url, weights_filename_local)
utils.download(url, weights_filename_local)

try:
state_dict = torch.load(weights_filename_local, map_location='cpu')
Expand All @@ -268,23 +277,3 @@ def ResNetAE(weights=None):
ae.description = model_urls[weights]["description"]

return ae


# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
def download(url, filename):
with open(filename, 'wb') as f:
response = requests.get(url, stream=True)
total = response.headers.get('content-length')

if total is None:
f.write(response.content)
else:
downloaded = 0
total = int(total)
for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)):
downloaded += len(data)
f.write(data)
done = int(50 * downloaded / total)
sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done)))
sys.stdout.flush()
sys.stdout.write('\n')
12 changes: 10 additions & 2 deletions torchxrayvision/baseline_models/chestx_det/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torchvision

from .ptsemseg.pspnet import pspnet
from ... import utils


def _convert_state_dict(state_dict):
Expand Down Expand Up @@ -51,6 +52,10 @@ class PSPNet(nn.Module):
url = {https://arxiv.org/abs/2104.10326},
year = {2021}
}
params:
cache_dir (str): Override directory used to store cached weights (default: ~/.torchxrayvision/)
"""

targets: List[str] = [
Expand All @@ -62,7 +67,7 @@ class PSPNet(nn.Module):
]
""""""

def __init__(self):
def __init__(self, cache_dir:str = None):

super(PSPNet, self).__init__()

Expand All @@ -78,7 +83,10 @@ def __init__(self):
url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/pspnet_chestxray_best_model_4.pth"

weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
if cache_dir is None:
weights_storage_folder = utils.get_cache_dir()
else:
weights_storage_folder = cache_dir
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))

if not os.path.isfile(self.weights_filename_local):
Expand Down
39 changes: 13 additions & 26 deletions torchxrayvision/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from collections import OrderedDict
from . import datasets
from . import utils
import warnings
warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -191,6 +192,7 @@ class DenseNet(nn.Module):
model = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch") # MIMIC-CXR (MIT)
:param weights: Specify a weight name to load pre-trained weights
:param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/)
:param op_threshs: Specify a weight name to load pre-trained weights
:param apply_sigmoid: Apply a sigmoid
Expand Down Expand Up @@ -227,6 +229,7 @@ def __init__(self,
num_classes=len(datasets.default_pathologies),
in_channels=1,
weights=None,
cache_dir=None,
op_threshs=None,
apply_sigmoid=False
):
Expand Down Expand Up @@ -291,7 +294,7 @@ def __init__(self,
self.register_buffer('op_threshs', op_threshs)

if self.weights != None:
self.weights_filename_local = get_weights(weights)
self.weights_filename_local = get_weights(weights, cache_dir)

try:
savedmodel = torch.load(self.weights_filename_local, map_location='cpu')
Expand Down Expand Up @@ -355,6 +358,7 @@ class ResNet(nn.Module):
model = xrv.models.ResNet(weights="resnet50-res512-all")
:param weights: Specify a weight name to load pre-trained weights
:param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/)
:param op_threshs: Specify a weight name to load pre-trained weights
:param apply_sigmoid: Apply a sigmoid
Expand Down Expand Up @@ -382,7 +386,7 @@ class ResNet(nn.Module):
]
""""""

def __init__(self, weights: str = None, apply_sigmoid: bool = False):
def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir: str = None):
super(ResNet, self).__init__()

self.weights = weights
Expand All @@ -392,7 +396,7 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False):
possible_weights = [k for k in model_urls.keys() if k.startswith("resnet")]
raise Exception("Weights value must be in {}".format(possible_weights))

self.weights_filename_local = get_weights(weights)
self.weights_filename_local = get_weights(weights, cache_dir=cache_dir)
self.weights_dict = model_urls[weights]
self.targets = model_urls[weights]["labels"]
self.pathologies = self.targets # keep to be backward compatible
Expand Down Expand Up @@ -546,39 +550,22 @@ def get_model(weights: str, **kwargs):
raise Exception("Unknown model")


def get_weights(weights: str):
def get_weights(weights: str, cache_dir:str = None):
if not weights in model_urls:
raise Exception("Weights not found. Valid options: {}".format(list(model_urls.keys())))

url = model_urls[weights]["weights_url"]
weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
if cache_dir is None:
weights_storage_folder = utils.get_cache_dir()
else:
weights_storage_folder = cache_dir
weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))

if not os.path.isfile(weights_filename_local):
print("Downloading weights...")
print("If this fails you can run `wget {} -O {}`".format(url, weights_filename_local))
pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
download(url, weights_filename_local)
utils.download(url, weights_filename_local)

return weights_filename_local


# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
def download(url: str, filename: str):
with open(filename, 'wb') as f:
response = requests.get(url, stream=True)
total = response.headers.get('content-length')

if total is None:
f.write(response.content)
else:
downloaded = 0
total = int(total)
for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)):
downloaded += len(data)
f.write(data)
done = int(50 * downloaded / total)
sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done)))
sys.stdout.flush()
sys.stdout.write('\n')
4 changes: 4 additions & 0 deletions torchxrayvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
import numpy as np
import skimage
import torch
import os

from os import PathLike
from numpy import ndarray
import warnings
from tqdm.autonotebook import tqdm


def get_cache_dir():
return os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data/"))

def in_notebook():
try:
from IPython import get_ipython
Expand Down

0 comments on commit 5a8984c

Please sign in to comment.