Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset
from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
from .utils import check_hash, download_and_extract, download_url, extractall
from .utils import check_hash, download_and_extract, download_url, extractall, get_logger, logger
14 changes: 7 additions & 7 deletions monai/apps/mmars/mmars.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch

import monai.networks.nets as monai_nets
from monai.apps.utils import download_and_extract
from monai.apps.utils import download_and_extract, logger
from monai.utils.module import optional_import

from .model_desc import MODEL_DESC
Expand All @@ -42,7 +42,7 @@ def get_model_spec(idx: Union[int, str]):
for cand in MODEL_DESC:
if str(cand[Keys.ID]).strip().lower() == key:
return cand
print(f"Available specs are: {MODEL_DESC}.")
logger.info(f"Available specs are: {MODEL_DESC}.")
raise ValueError(f"Unknown MODEL_DESC request: {idx}")


Expand Down Expand Up @@ -213,7 +213,7 @@ def load_from_mmar(
item = get_model_spec(item)
model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version)
model_file = os.path.join(model_dir, item[Keys.MODEL_FILE])
print(f'\n*** "{item[Keys.ID]}" available at {model_dir}.')
logger.info(f'\n*** "{item[Keys.ID]}" available at {model_dir}.')

# loading with `torch.jit.load`
if f"{model_file}".endswith(".ts"):
Expand Down Expand Up @@ -264,18 +264,18 @@ def load_from_mmar(
else:
raise ValueError(f"Could not load model config {model_config}.")

print(f"*** Model: {model_cls}")
logger.info(f"*** Model: {model_cls}")
model_kwargs = model_config.get("args", None)
if model_kwargs:
model_inst = model_cls(**model_kwargs)
print(f"*** Model params: {model_kwargs}")
logger.info(f"*** Model params: {model_kwargs}")
else:
model_inst = model_cls()
if pretrained:
model_inst.load_state_dict(model_dict.get(model_key, model_dict))
print("\n---")
logger.info("\n---")
doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:")
print(f"For more information, please visit {doc_url}\n")
logger.info(f"For more information, please visit {doc_url}\n")
return model_inst


Expand Down
57 changes: 46 additions & 11 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
# limitations under the License.

import hashlib
import logging
import os
import shutil
import sys
import tarfile
import tempfile
import warnings
Expand All @@ -31,7 +33,40 @@
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

__all__ = ["check_hash", "download_url", "extractall", "download_and_extract"]
__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger"]

DEFAULT_FMT = "%(asctime)s - %(levelname)s - %(message)s"


def get_logger(
module_name: str = "monai.apps",
fmt: str = DEFAULT_FMT,
datefmt: Optional[str] = None,
logger_handler: Optional[logging.Handler] = None,
):
"""
Get a `module_name` logger with the specified format and date format.
By default, the logger will print to `stdout` at the INFO level.
If `module_name` is `None`, return the root logger.
`fmt` and `datafmt` are passed to a `logging.Formatter` object
(https://docs.python.org/3/library/logging.html#formatter-objects).
`logger_handler` can be used to add an additional handler.
"""
logger = logging.getLogger(module_name)
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
handler.setFormatter(formatter)
logger.addHandler(handler)
if logger_handler is not None:
logger.addHandler(logger_handler)
return logger


# apps module-level default logger
logger = get_logger("monai.apps")
__all__.append("logger")


def _basename(p):
Expand Down Expand Up @@ -71,7 +106,7 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None):
warnings.warn("tqdm is not installed, will not show the downloading progress bar.")
urlretrieve(url, filepath)
except (URLError, HTTPError, ContentTooShortError, OSError) as e:
print(f"Download failed from {url} to {filepath}.")
logger.error(f"Download failed from {url} to {filepath}.")
raise e


Expand All @@ -86,7 +121,7 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5")

"""
if val is None:
print(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
return True
if hash_type.lower() == "md5":
actual_hash = hashlib.md5()
Expand All @@ -99,13 +134,13 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5")
for chunk in iter(lambda: f.read(1024 * 1024), b""):
actual_hash.update(chunk)
except Exception as e:
print(f"Exception in check_hash: {e}")
logger.error(f"Exception in check_hash: {e}")
return False
if val != actual_hash.hexdigest():
print(f"check_hash failed {actual_hash.hexdigest()}.")
logger.error(f"check_hash failed {actual_hash.hexdigest()}.")
return False

print(f"Verified '{_basename(filepath)}', {hash_type}: {val}.")
logger.info(f"Verified '{_basename(filepath)}', {hash_type}: {val}.")
return True


Expand Down Expand Up @@ -137,13 +172,13 @@ def download_url(
"""
if not filepath:
filepath = os.path.abspath(os.path.join(".", _basename(url)))
print(f"Default downloading to '{filepath}'")
logger.info(f"Default downloading to '{filepath}'")
if os.path.exists(filepath):
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
)
print(f"File exists: {filepath}, skipped downloading.")
logger.info(f"File exists: {filepath}, skipped downloading.")
return

with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -162,7 +197,7 @@ def download_url(
if file_dir:
os.makedirs(file_dir, exist_ok=True)
shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache.
print(f"Downloaded: {filepath}")
logger.info(f"Downloaded: {filepath}")
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of downloaded file failed: URL={url}, "
Expand Down Expand Up @@ -205,13 +240,13 @@ def extractall(
else:
cache_dir = output_dir
if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0:
print(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
return
if hash_val and not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}."
)
print(f"Writing into directory: {output_dir}.")
logger.info(f"Writing into directory: {output_dir}.")
_file_type = file_type.lower().strip()
if filepath.endswith("zip") or _file_type == "zip":
zip_file = zipfile.ZipFile(filepath)
Expand Down
17 changes: 9 additions & 8 deletions tests/test_download_and_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ def test_actions(self):
return # skipping this test due the network connection errors

wrong_md5 = "0"
try:
download_url(url, filepath, wrong_md5)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors
with self.assertLogs(logger="monai.apps", level="ERROR"):
try:
download_url(url, filepath, wrong_md5)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

try:
extractall(filepath, output_dir, wrong_md5)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_mmar_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def test_download(self, idx):
# test model specification
cand = get_model_spec(idx)
self.assertEqual(cand[RemoteMMARKeys.ID], idx)
download_mmar(idx)
with self.assertLogs(level="INFO", logger="monai.apps"):
download_mmar(idx)
download_mmar(idx, progress=False) # repeated to check caching
with tempfile.TemporaryDirectory() as tmp_dir:
download_mmar(idx, mmar_dir=tmp_dir, progress=False)
Expand Down