Skip to content

Commit

Permalink
YesNo Dataset Pathlib change (#1015)
Browse files Browse the repository at this point in the history

Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
  • Loading branch information
bhargavkathivarapu and vincentqb authored Nov 13, 2020
1 parent 5630fe3 commit b9ee013
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
12 changes: 10 additions & 2 deletions test/torchaudio_unittest/datasets/yesno_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path

from torchaudio.datasets import yesno

Expand Down Expand Up @@ -36,8 +37,7 @@ def setUpClass(cls):
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))

def test_yesno(self):
dataset = yesno.YESNO(self.root_dir)
def _test_yesno(self, dataset):
n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset):
expected_label = self.labels[i]
Expand All @@ -47,3 +47,11 @@ def test_yesno(self):
assert label == expected_label
n_ite += 1
assert n_ite == len(self.data)

def test_yesno_str(self):
dataset = yesno.YESNO(self.root_dir)
self._test_yesno(dataset)

def test_yesno_path(self):
dataset = yesno.YESNO(Path(self.root_dir))
self._test_yesno(dataset)
10 changes: 7 additions & 3 deletions torchaudio/datasets/yesno.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import warnings
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Union
from pathlib import Path

import torchaudio
from torch import Tensor
Expand Down Expand Up @@ -34,7 +35,7 @@ class YESNO(Dataset):
"""Create a Dataset for YesNo.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
folder_in_archive (str, optional):
Expand All @@ -48,7 +49,7 @@ class YESNO(Dataset):
_ext_audio = ".wav"

def __init__(self,
root: str,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
Expand All @@ -65,6 +66,9 @@ def __init__(self,
self.transform = transform
self.target_transform = target_transform

# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)

archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
Expand Down

0 comments on commit b9ee013

Please sign in to comment.