3329 compatibility with pathlike obj (#3332)
* update pathlike obj

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* support of pathlike obj

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* review path obj

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* autofix

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes unit test

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update based on comments

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes dep issue

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli authored Nov 16, 2021
1 parent 45d4a61 commit 4c2c1dc
Showing 27 changed files with 200 additions and 145 deletions.
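
The pattern repeated across the diffs below is to widen path parameters from str to the PathLike alias and to normalize them with pathlib.Path at the function boundary, replacing os.path calls with Path operators. A minimal sketch of that pattern (not part of the commit; find_datalist is a hypothetical helper, and PathLike is assumed to be a union of str and os.PathLike, with the real alias defined in monai.config.type_definitions):

    import os
    from pathlib import Path
    from typing import Union

    # Assumed shape of the alias; the real definition lives in monai.config.type_definitions.
    PathLike = Union[str, os.PathLike]

    def find_datalist(root_dir: PathLike) -> Path:
        """Hypothetical helper: accept str or os.PathLike, normalize once, then use Path operators."""
        root_dir = Path(root_dir)  # normalize at the boundary
        if not root_dir.is_dir():
            raise ValueError("Root directory root_dir must be a directory.")
        return root_dir / "dataset.json"  # replaces os.path.join(root_dir, "dataset.json")
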
42 changes: 23 additions & 19 deletions monai/apps/datasets.py
@@ -9,13 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

import numpy as np

from monai.apps.utils import download_and_extract
from monai.config.type_definitions import PathLike
from monai.data import (
CacheDataset,
load_decathlon_datalist,
@@ -64,7 +65,7 @@ class MedNISTDataset(Randomizable, CacheDataset):

def __init__(
self,
root_dir: str,
root_dir: PathLike,
section: str,
transform: Union[Sequence[Callable], Callable] = (),
download: bool = False,
@@ -75,19 +76,20 @@ def __init__(
cache_rate: float = 1.0,
num_workers: int = 0,
) -> None:
if not os.path.isdir(root_dir):
root_dir = Path(root_dir)
if not root_dir.is_dir():
raise ValueError("Root directory root_dir must be a directory.")
self.section = section
self.val_frac = val_frac
self.test_frac = test_frac
self.set_random_state(seed=seed)
tarfile_name = os.path.join(root_dir, self.compressed_file_name)
dataset_dir = os.path.join(root_dir, self.dataset_folder_name)
tarfile_name = root_dir / self.compressed_file_name
dataset_dir = root_dir / self.dataset_folder_name
self.num_class = 0
if download:
download_and_extract(self.resource, tarfile_name, root_dir, self.md5)

if not os.path.exists(dataset_dir):
if not dataset_dir.is_dir():
raise RuntimeError(
f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it."
)
@@ -105,19 +107,17 @@ def get_num_classes(self) -> int:
"""Get number of classes."""
return self.num_class

def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
"""
Raises:
ValueError: When ``section`` is not one of ["training", "validation", "test"].
"""
class_names = sorted(x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))
dataset_dir = Path(dataset_dir)
class_names = sorted(f"{x}" for x in dataset_dir.iterdir() if (dataset_dir / x).is_dir())
self.num_class = len(class_names)
image_files = [
[
os.path.join(dataset_dir, class_names[i], x)
for x in os.listdir(os.path.join(dataset_dir, class_names[i]))
]
[f"{dataset_dir.joinpath(class_names[i], x)}" for x in (dataset_dir / class_names[i]).iterdir()]
for i in range(self.num_class)
]
num_each = [len(image_files[i]) for i in range(self.num_class)]
@@ -146,6 +146,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
)

# the types of label and class name should be compatible with the pytorch dataloader
return [
{"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}
for i in section_indices
@@ -234,7 +235,7 @@ class DecathlonDataset(Randomizable, CacheDataset):

def __init__(
self,
root_dir: str,
root_dir: PathLike,
task: str,
section: str,
transform: Union[Sequence[Callable], Callable] = (),
@@ -245,19 +246,20 @@ def __init__(
cache_rate: float = 1.0,
num_workers: int = 0,
) -> None:
if not os.path.isdir(root_dir):
root_dir = Path(root_dir)
if not root_dir.is_dir():
raise ValueError("Root directory root_dir must be a directory.")
self.section = section
self.val_frac = val_frac
self.set_random_state(seed=seed)
if task not in self.resource:
raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.")
dataset_dir = os.path.join(root_dir, task)
dataset_dir = root_dir / task
tarfile_name = f"{dataset_dir}.tar"
if download:
download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task])

if not os.path.exists(dataset_dir):
if not dataset_dir.exists():
raise RuntimeError(
f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it."
)
@@ -275,7 +277,7 @@ def __init__(
"numTraining",
"numTest",
]
self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
self._properties = load_decathlon_properties(dataset_dir / "dataset.json", property_keys)
if transform == ():
transform = LoadImaged(["image", "label"])
CacheDataset.__init__(
@@ -304,9 +306,11 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None):
return {key: self._properties[key] for key in ensure_tuple(keys)}
return {}

def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
# the types of the item in data list should be compatible with the dataloader
dataset_dir = Path(dataset_dir)
section = "training" if self.section in ["training", "validation"] else "test"
datalist = load_decathlon_datalist(os.path.join(dataset_dir, "dataset.json"), True, section)
datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section)
return self._split_datalist(datalist)

def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
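
With root_dir typed as PathLike, MedNISTDataset and DecathlonDataset should now accept a pathlib.Path as well as a str. A usage sketch (not part of the commit; the data directory is a placeholder):

    from pathlib import Path

    from monai.apps import MedNISTDataset

    # root_dir may be a str or a Path; it is normalized with Path() inside __init__.
    root = Path("./workspace/data/medical")  # placeholder location, must exist
    train_ds = MedNISTDataset(root_dir=root, section="training", download=True)
    print(train_ds.get_num_classes())
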
29 changes: 16 additions & 13 deletions monai/apps/mmars/mmars.py
@@ -17,14 +17,15 @@
"""

import json
import os
import warnings
from typing import Mapping, Union
from pathlib import Path
from typing import Mapping, Optional, Union

import torch

import monai.networks.nets as monai_nets
from monai.apps.utils import download_and_extract, logger
from monai.config.type_definitions import PathLike
from monai.utils.module import optional_import

from .model_desc import MODEL_DESC
@@ -98,7 +99,9 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""):
return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}"


def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, version: int = -1):
def download_mmar(
item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = False, version: int = -1
):
"""
Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train.
@@ -128,10 +131,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,
if not mmar_dir:
get_dir, has_home = optional_import("torch.hub", name="get_dir")
if has_home:
mmar_dir = os.path.join(get_dir(), "mmars")
mmar_dir = Path(get_dir()) / "mmars"
else:
raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?")

mmar_dir = Path(mmar_dir)
if api:
model_dict = _get_all_ngc_models(item)
if len(model_dict) == 0:
@@ -140,10 +143,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,
for k, v in model_dict.items():
ver = v["latest"] if version == -1 else str(version)
download_url = _get_ngc_url(k, ver)
model_dir = os.path.join(mmar_dir, v["name"])
model_dir = mmar_dir / v["name"]
download_and_extract(
url=download_url,
filepath=os.path.join(mmar_dir, f'{v["name"]}_{ver}.zip'),
filepath=mmar_dir / f'{v["name"]}_{ver}.zip',
output_dir=model_dir,
hash_val=None,
hash_type="md5",
@@ -161,11 +164,11 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,
if version > 0:
ver = str(version)
model_fullname = f"{item[Keys.NAME]}_{ver}"
model_dir = os.path.join(mmar_dir, model_fullname)
model_dir = mmar_dir / model_fullname
model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/")
download_and_extract(
url=model_url,
filepath=os.path.join(mmar_dir, f"{model_fullname}.{item[Keys.FILE_TYPE]}"),
filepath=mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}",
output_dir=model_dir,
hash_val=item[Keys.HASH_VAL],
hash_type=item[Keys.HASH_TYPE],
@@ -178,7 +181,7 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,

def load_from_mmar(
item,
mmar_dir=None,
mmar_dir: Optional[PathLike] = None,
progress: bool = True,
version: int = -1,
map_location=None,
@@ -212,11 +215,11 @@ def load_from_mmar(
if not isinstance(item, Mapping):
item = get_model_spec(item)
model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version)
model_file = os.path.join(model_dir, item[Keys.MODEL_FILE])
model_file = model_dir / item[Keys.MODEL_FILE]
logger.info(f'\n*** "{item[Keys.ID]}" available at {model_dir}.')

# loading with `torch.jit.load`
if f"{model_file}".endswith(".ts"):
if model_file.name.endswith(".ts"):
if not pretrained:
warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.")
if weights_only:
@@ -232,7 +235,7 @@ def load_from_mmar(
model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
if not model_config:
# 2. search json CONFIG_FILE for model config spec.
json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json"))
json_path = model_dir / item.get(Keys.CONFIG_FILE, "config_train.json")
with open(json_path) as f:
conf_dict = json.load(f)
conf_dict = dict(conf_dict)
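
mmar_dir gets the same treatment: it may be a str, a Path, or None (falling back to torch.hub's cache directory plus "mmars"). A usage sketch, not from the commit; the model name is illustrative and assumed to exist on NGC:

    from pathlib import Path

    from monai.apps.mmars import download_mmar

    # mmar_dir may be a str or a Path; download_mmar returns the directory the archive was extracted to.
    model_dir = download_mmar("clara_pt_prostate_mri_segmentation", mmar_dir=Path("./mmars"), api=True)
    print(model_dir)
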
46 changes: 25 additions & 21 deletions monai/apps/utils.py
@@ -18,10 +18,12 @@
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from urllib.error import ContentTooShortError, HTTPError, URLError
from urllib.request import urlretrieve

from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import

gdown, has_gdown = optional_import("gdown", "3.6")
@@ -70,10 +72,10 @@ def get_logger(
__all__.append("logger")


def _basename(p):
def _basename(p: PathLike) -> str:
"""get the last part of the path (removing the trailing slash if it exists)"""
sep = os.path.sep + (os.path.altsep or "") + "/ "
return os.path.basename(p.rstrip(sep))
return Path(f"{p}".rstrip(sep)).name


def _download_with_progress(url, filepath, progress: bool = True):
@@ -111,7 +113,7 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None):
raise e


def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool:
def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool:
"""
Verify hash signature of specified file.
@@ -144,7 +146,7 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5")


def download_url(
url: str, filepath: str = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True
url: str, filepath: PathLike = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True
) -> None:
"""
Download file from specified URL link, support process bar and hash check.
@@ -170,9 +172,10 @@
"""
if not filepath:
filepath = os.path.abspath(os.path.join(".", _basename(url)))
filepath = Path(".", _basename(url)).resolve()
logger.info(f"Default downloading to '{filepath}'")
if os.path.exists(filepath):
filepath = Path(filepath)
if filepath.exists():
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
@@ -181,21 +184,21 @@
return

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_name = os.path.join(tmp_dir, f"{_basename(filepath)}")
tmp_name = Path(tmp_dir, _basename(filepath))
if url.startswith("https://drive.google.com"):
if not has_gdown:
raise RuntimeError("To download files from Google Drive, please install the gdown dependency.")
gdown.download(url, tmp_name, quiet=not progress)
gdown.download(url, f"{tmp_name}", quiet=not progress)
else:
_download_with_progress(url, tmp_name, progress=progress)
if not os.path.exists(tmp_name):
if not tmp_name.exists():
raise RuntimeError(
f"Download of file from {url} to {filepath} failed due to network issue or denied permission."
)
file_dir = os.path.dirname(filepath)
file_dir = filepath.parent
if file_dir:
os.makedirs(file_dir, exist_ok=True)
shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache.
shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache.
logger.info(f"Downloaded: {filepath}")
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
@@ -205,8 +208,8 @@


def extractall(
filepath: str,
output_dir: str = ".",
filepath: PathLike,
output_dir: PathLike = ".",
hash_val: Optional[str] = None,
hash_type: str = "md5",
file_type: str = "",
@@ -235,24 +238,25 @@
"""
if has_base:
# the extracted files will be in this folder
cache_dir = os.path.join(output_dir, _basename(filepath).split(".")[0])
cache_dir = Path(output_dir, _basename(filepath).split(".")[0])
else:
cache_dir = output_dir
if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0:
cache_dir = Path(output_dir)
if cache_dir.exists() and len(list(cache_dir.iterdir())) > 0:
logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
return
filepath = Path(filepath)
if hash_val and not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}."
)
logger.info(f"Writing into directory: {output_dir}.")
_file_type = file_type.lower().strip()
if filepath.endswith("zip") or _file_type == "zip":
if filepath.name.endswith("zip") or _file_type == "zip":
zip_file = zipfile.ZipFile(filepath)
zip_file.extractall(output_dir)
zip_file.close()
return
if filepath.endswith("tar") or filepath.endswith("tar.gz") or "tar" in _file_type:
if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type:
tar_file = tarfile.open(filepath)
tar_file.extractall(output_dir)
tar_file.close()
@@ -264,8 +268,8 @@

def download_and_extract(
url: str,
filepath: str = "",
output_dir: str = ".",
filepath: PathLike = "",
output_dir: PathLike = ".",
hash_val: Optional[str] = None,
hash_type: str = "md5",
file_type: str = "",
@@ -292,6 +296,6 @@
progress: whether to display progress bar.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
filename = filepath or os.path.join(tmp_dir, f"{_basename(url)}")
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
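
download_url, extractall, and download_and_extract now annotate filepath and output_dir as PathLike, so callers can pass Path objects directly. A sketch with placeholder URL and paths (not from the commit):

    from pathlib import Path

    from monai.apps.utils import download_and_extract

    # The URL and paths are placeholders; filepath and output_dir may now be str or Path.
    download_and_extract(
        url="https://example.com/archive.tar.gz",
        filepath=Path("./downloads/archive.tar.gz"),
        output_dir=Path("./data"),
        hash_val=None,
        file_type="tar",
    )
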
10 changes: 9 additions & 1 deletion monai/config/__init__.py
@@ -19,4 +19,12 @@
print_gpu_info,
print_system_info,
)
from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayOrTensor, NdarrayTensor, TensorOrList
from .type_definitions import (
DtypeLike,
IndexSelection,
KeysCollection,
NdarrayOrTensor,
NdarrayTensor,
PathLike,
TensorOrList,
)
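
With PathLike re-exported here, downstream code can import the alias from monai.config directly instead of reaching into type_definitions. A brief illustration (resolve_root is hypothetical):

    from pathlib import Path

    from monai.config import PathLike

    def resolve_root(p: PathLike) -> Path:
        """Hypothetical helper showing the re-exported alias in a signature."""
        return Path(p).expanduser().resolve()
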