diff --git a/test/torchaudio_unittest/datasets/yesno_test.py b/test/torchaudio_unittest/datasets/yesno_test.py index a8a8a04276..6b094aaece 100644 --- a/test/torchaudio_unittest/datasets/yesno_test.py +++ b/test/torchaudio_unittest/datasets/yesno_test.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from torchaudio.datasets import yesno @@ -36,8 +37,7 @@ def setUpClass(cls): save_wav(path, data, 8000) cls.data.append(normalize_wav(data)) - def test_yesno(self): - dataset = yesno.YESNO(self.root_dir) + def _test_yesno(self, dataset): n_ite = 0 for i, (waveform, sample_rate, label) in enumerate(dataset): expected_label = self.labels[i] @@ -47,3 +47,11 @@ def test_yesno(self): assert label == expected_label n_ite += 1 assert n_ite == len(self.data) + + def test_yesno_str(self): + dataset = yesno.YESNO(self.root_dir) + self._test_yesno(dataset) + + def test_yesno_path(self): + dataset = yesno.YESNO(Path(self.root_dir)) + self._test_yesno(dataset) diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py index 5a20539e28..182d224ba4 100644 --- a/torchaudio/datasets/yesno.py +++ b/torchaudio/datasets/yesno.py @@ -1,6 +1,7 @@ import os import warnings -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union +from pathlib import Path import torchaudio from torch import Tensor @@ -34,7 +35,7 @@ class YESNO(Dataset): """Create a Dataset for YesNo. Args: - root (str): Path to the directory where the dataset is found or downloaded. + root (str or Path): Path to the directory where the dataset is found or downloaded. url (str, optional): The URL to download the dataset from. (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``) folder_in_archive (str, optional): @@ -48,7 +49,7 @@ class YESNO(Dataset): _ext_audio = ".wav" def __init__(self, - root: str, + root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False, @@ -65,6 +66,9 @@ def __init__(self, self.transform = transform self.target_transform = target_transform + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + archive = os.path.basename(url) archive = os.path.join(root, archive) self._path = os.path.join(root, folder_in_archive)