From a61b90c2db462db51fea85fa03315665e430510b Mon Sep 17 00:00:00 2001 From: Sean Kim Date: Wed, 1 Jun 2022 13:55:39 -0700 Subject: [PATCH] Raising RuntimeErrors when datasets missing (#2430) Summary: Checks download flag and raises error when dataset is missing given download flag exists. Unit tested manually. edit: Changed path to check as well as comment that is returned. Pull Request resolved: https://github.com/pytorch/audio/pull/2430 Reviewed By: carolineechen Differential Revision: D36815729 Pulled By: skim0514 fbshipit-source-id: f062db7919271665b88ec9754d85cfa83b4f6fa3 --- torchaudio/datasets/cmuarctic.py | 7 ++++++- torchaudio/datasets/libritts.py | 6 ++++++ torchaudio/datasets/ljspeech.py | 6 ++++++ torchaudio/datasets/speechcommands.py | 6 ++++++ torchaudio/datasets/tedlium.py | 6 ++++++ 5 files changed, 30 insertions(+), 1 deletion(-) diff --git a/torchaudio/datasets/cmuarctic.py b/torchaudio/datasets/cmuarctic.py index 3bec6bdf80..6a1227b015 100644 --- a/torchaudio/datasets/cmuarctic.py +++ b/torchaudio/datasets/cmuarctic.py @@ -120,7 +120,12 @@ def __init__( checksum = _CHECKSUMS.get(url, None) download_url_to_file(url, archive, hash_prefix=checksum) extract_archive(archive) - + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) self._text = os.path.join(self._path, self._folder_text, self._file_text) with open(self._text, "r") as text: diff --git a/torchaudio/datasets/libritts.py b/torchaudio/datasets/libritts.py index 1ddab81a1e..f7e10cedc4 100644 --- a/torchaudio/datasets/libritts.py +++ b/torchaudio/datasets/libritts.py @@ -122,6 +122,12 @@ def __init__( checksum = _CHECKSUMS.get(url, None) download_url_to_file(url, archive, hash_prefix=checksum) extract_archive(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio)) diff --git a/torchaudio/datasets/ljspeech.py b/torchaudio/datasets/ljspeech.py index db360c5262..e8421b639f 100644 --- a/torchaudio/datasets/ljspeech.py +++ b/torchaudio/datasets/ljspeech.py @@ -60,6 +60,12 @@ def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, downloa checksum = _RELEASE_CONFIGS["release1"]["checksum"] download_url_to_file(url, archive, hash_prefix=checksum) extract_archive(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) with open(self._metadata_path, "r", newline="") as metadata: flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE) diff --git a/torchaudio/datasets/speechcommands.py b/torchaudio/datasets/speechcommands.py index 4c85b1ee40..6b9872662f 100644 --- a/torchaudio/datasets/speechcommands.py +++ b/torchaudio/datasets/speechcommands.py @@ -109,6 +109,12 @@ def __init__( checksum = _CHECKSUMS.get(url, None) download_url_to_file(url, archive, hash_prefix=checksum) extract_archive(archive, self._path) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) if subset == "validation": self._walker = _load_list(self._path, "validation_list.txt") diff --git a/torchaudio/datasets/tedlium.py b/torchaudio/datasets/tedlium.py index ee60b4f71b..d7478ca7be 100644 --- a/torchaudio/datasets/tedlium.py +++ b/torchaudio/datasets/tedlium.py @@ -108,6 +108,12 @@ def __init__( checksum = _RELEASE_CONFIGS[release]["checksum"] download_url_to_file(url, archive, hash_prefix=checksum) extract_archive(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) # Create list for all samples self._filelist = []