-
Notifications
You must be signed in to change notification settings - Fork 684
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: implementation adapted from [s3prl](https://github.com/s3prl/s3prl/blob/master/s3prl/downstream/quesst14_dtw/dataset.py) modifying the s3prl downstream expert to [this](carolineechen/s3prl@adc91a5) using this dataset implementation produces the same results as using the original s3prl pipeline Pull Request resolved: #2290 Reviewed By: nateanl Differential Revision: D35692551 Pulled By: carolineechen fbshipit-source-id: 035ad161d4cbbd2072411cfdf89984b73a89868c
- Loading branch information
1 parent
86100e3
commit aebcf6a
Showing
4 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import os | ||
from collections import defaultdict | ||
from pathlib import Path | ||
|
||
from parameterized import parameterized | ||
from torchaudio.datasets import quesst14 | ||
from torchaudio_unittest.common_utils import ( | ||
TempDirMixin, | ||
TorchaudioTestCase, | ||
get_whitenoise, | ||
save_wav, | ||
) | ||
|
||
|
||
def _get_filename(folder, index): | ||
if folder == "Audio": | ||
return f"quesst14_{index:05d}.wav" | ||
elif folder == "dev_queries": | ||
return f"quesst14_dev_{index:04d}.wav" | ||
elif folder == "eval_queries": | ||
return f"quesst14_eval_{index:04d}.wav" | ||
return | ||
|
||
|
||
def _get_key(folder): | ||
folder_key_mapping = { | ||
"Audio": "utterances", | ||
"dev_queries": "dev", | ||
"eval_queries": "eval", | ||
} | ||
return folder_key_mapping[folder] | ||
|
||
|
||
def _save_sample(dataset_dir, folder, language, index, sample_rate, seed):
    """Create one mocked wav file and register it in the language-key list.

    Args:
        dataset_dir (str): root of the mocked ``quesst14Database`` directory.
        folder (str): dataset folder (``"Audio"``, ``"dev_queries"``, ``"eval_queries"``).
        language (str): language label written next to the file entry.
        index (int): sample index used to build the filename.
        sample_rate (int): sample rate of the generated white noise.
        seed (int): seed for deterministic white-noise generation.

    Returns:
        Tuple of (waveform tensor, filename stem) matching what the dataset yields.
    """
    # Create and save an audio sample to its folder.
    folder_path = os.path.join(dataset_dir, folder)
    os.makedirs(folder_path, exist_ok=True)
    filename = _get_filename(folder, index)
    file_path = os.path.join(folder_path, filename)

    data = get_whitenoise(
        sample_rate=sample_rate,
        duration=0.01,
        n_channels=1,
        seed=seed,
    )
    save_wav(file_path, data, sample_rate)

    sample = (data, Path(file_path).with_suffix("").name)

    # Register the file in the scoring language-key list the dataset reads.
    # Each line is "<archive-relative wav path> <language>".
    scoring_path = os.path.join(dataset_dir, "scoring")
    os.makedirs(scoring_path, exist_ok=True)
    # BUG FIX: the entry must reference the generated wav file; the line
    # previously contained a garbled placeholder instead of the filename,
    # so filter_audio_paths would have produced paths to nonexistent files.
    wav_file = f"quesst14Database/{folder}/{filename}"
    line = f"{wav_file} {language}"

    key = _get_key(folder)
    language_key_file = os.path.join(scoring_path, f"language_key_{key}.lst")
    with open(language_key_file, "a") as f:
        f.write(line + "\n")

    return sample
|
||
|
||
def _get_mocked_samples(dataset_dir, folder, sample_rate, seed):
    """Generate two mocked samples per language for ``folder``.

    Returns:
        Tuple of (dict mapping language -> list of samples, flat list of all
        samples in creation order).
    """
    per_language = 2

    by_language = defaultdict(list)
    flat = []

    index = 0
    for language in quesst14._LANGUAGES:
        for _ in range(per_language):
            sample = _save_sample(dataset_dir, folder, language, index, sample_rate, seed)
            by_language[language].append(sample)
            flat.append(sample)
            index += 1
    return by_language, flat
|
||
|
||
def get_mock_dataset(dataset_dir):
    """Build the mocked QUESST14 layout under ``dataset_dir``.

    Args:
        dataset_dir: directory to the mocked dataset

    Returns:
        Tuple of per-language dicts then flat lists of samples, in the order
        (utterances, dev, eval, utterances_all, dev_all, eval_all).
    """
    os.makedirs(dataset_dir, exist_ok=True)
    sample_rate = 8000

    # One distinct seed per folder keeps the three mocked subsets distinguishable.
    folder_seeds = (("Audio", 0), ("dev_queries", 1), ("eval_queries", 2))
    per_language = []
    flat = []
    for folder, seed in folder_seeds:
        mapped, all_samples = _get_mocked_samples(dataset_dir, folder, sample_rate, seed)
        per_language.append(mapped)
        flat.append(all_samples)

    return tuple(per_language) + tuple(flat)
|
||
|
||
class TestQuesst14(TempDirMixin, TorchaudioTestCase):
    """Tests the QUESST14 dataset against a mocked on-disk layout."""

    root_dir = None
    backend = "default"

    utterances = {}
    dev_samples = {}
    eval_samples = {}
    utterances_all = []
    dev_samples_all = []
    eval_samples_all = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        dataset_dir = os.path.join(cls.root_dir, "quesst14Database")
        mocked = get_mock_dataset(dataset_dir)
        (
            cls.utterances,
            cls.dev_samples,
            cls.eval_samples,
            cls.utterances_all,
            cls.dev_samples_all,
            cls.eval_samples_all,
        ) = mocked

    def _testQuesst14(self, dataset, expected_samples):
        """Assert the dataset yields exactly ``expected_samples``, in order."""
        n_seen = 0
        for idx, (waveform, name) in enumerate(dataset):
            expected_waveform, expected_name = expected_samples[idx]
            self.assertEqual(waveform, expected_waveform)
            assert name == expected_name
            n_seen += 1

        assert n_seen == len(expected_samples)

    def testQuesst14SubsetDocs(self):
        dataset = quesst14.QUESST14(self.root_dir, language=None, subset="docs")
        self._testQuesst14(dataset, self.utterances_all)

    def testQuesst14SubsetDev(self):
        dataset = quesst14.QUESST14(self.root_dir, language=None, subset="dev")
        self._testQuesst14(dataset, self.dev_samples_all)

    def testQuesst14SubsetEval(self):
        dataset = quesst14.QUESST14(self.root_dir, language=None, subset="eval")
        self._testQuesst14(dataset, self.eval_samples_all)

    @parameterized.expand(quesst14._LANGUAGES)
    def testQuesst14DocsSingleLanguage(self, language):
        dataset = quesst14.QUESST14(self.root_dir, language=language, subset="docs")
        self._testQuesst14(dataset, self.utterances[language])

    @parameterized.expand(quesst14._LANGUAGES)
    def testQuesst14DevSingleLanguage(self, language):
        dataset = quesst14.QUESST14(self.root_dir, language=language, subset="dev")
        self._testQuesst14(dataset, self.dev_samples[language])

    @parameterized.expand(quesst14._LANGUAGES)
    def testQuesst14EvalSingleLanguage(self, language):
        dataset = quesst14.QUESST14(self.root_dir, language=language, subset="eval")
        self._testQuesst14(dataset, self.eval_samples[language])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import os | ||
import re | ||
from pathlib import Path | ||
from typing import Tuple, Union, Optional | ||
|
||
import torch | ||
import torchaudio | ||
from torch.hub import download_url_to_file | ||
from torch.utils.data import Dataset | ||
from torchaudio.datasets.utils import extract_archive | ||
|
||
|
||
URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz" | ||
_CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4" | ||
_LANGUAGES = [ | ||
"albanian", | ||
"basque", | ||
"czech", | ||
"nnenglish", | ||
"romanian", | ||
"slovak", | ||
] | ||
|
||
|
||
class QUESST14(Dataset):
    """Create QUESST14 Dataset

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        language (str or None, optional): Language to get dataset for.
            Options: [None, ``albanian``, ``basque``, ``czech``, ``nnenglish``, ``romanian``, ``slovak``].
            (default: ``"nnenglish"``)
        subset (str): subset of the dataset to use. Options: ["docs", "dev", "eval"].
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        language: Optional[str] = "nnenglish",
        subset: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # NOTE(review): subset=None passes this assert but no branch below then
        # sets self.data, so __len__/__getitem__ raise AttributeError later.
        # Consider requiring subset explicitly — confirm intended behavior.
        assert subset is None or subset in ["docs", "dev", "eval"], "`subset` must be one of ['docs', 'dev', 'eval']"

        assert language is None or language in _LANGUAGES, f"`language` must be None or one of {str(_LANGUAGES)}"

        # Get string representation of 'root'
        root = os.fspath(root)

        # Path of the downloaded archive, e.g. <root>/quesst14Database.tgz.
        basename = os.path.basename(URL)
        archive = os.path.join(root, basename)

        # Strip ".tgz" (two suffix components) to get the extracted directory name.
        basename = basename.rsplit(".", 2)[0]
        self._path = os.path.join(root, basename)

        # Download (if allowed) and extract when the dataset is not present.
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                download_url_to_file(URL, archive, hash_prefix=_CHECKSUM)
            extract_archive(archive, root)

        # Each subset is driven by its language-key list under scoring/.
        if subset == "docs":
            self.data = filter_audio_paths(self._path, language, "language_key_utterances.lst")
        elif subset == "dev":
            self.data = filter_audio_paths(self._path, language, "language_key_dev.lst")
        elif subset == "eval":
            self.data = filter_audio_paths(self._path, language, "language_key_eval.lst")

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, str]:
        """Load the n-th wav file and return (waveform, filename stem)."""
        audio_path = self.data[n]
        wav, _ = torchaudio.load(audio_path)
        return wav, audio_path.with_suffix("").name

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, str]:
        """Return (waveform, file name without extension) for sample ``n``."""
        return self._load_sample(n)

    def __len__(self) -> int:
        """Number of audio files in the selected subset/language."""
        return len(self.data)
|
||
|
||
def filter_audio_paths(
    path: str,
    language: Optional[str],
    lst_name: str,
):
    """Extract audio paths for the given language.

    Args:
        path (str): path to the extracted ``quesst14Database`` directory.
        language (str or None): language to keep; ``None`` keeps all entries.
            (Annotation fixed: the body explicitly handles ``None``, and the
            dataset class calls this with ``language=None``.)
        lst_name (str): name of the language-key file inside ``scoring/``.

    Returns:
        List of ``pathlib.Path`` objects for the matching wav files.
    """
    audio_paths = []

    path = Path(path)
    with open(path / "scoring" / lst_name) as f:
        for line in f:
            audio_path, lang = line.strip().split()
            if language is not None and lang != language:
                continue
            # Entries are archive-relative ("quesst14Database/<folder>/<file>.wav");
            # drop the leading top-level directory so paths resolve under `path`.
            audio_path = re.sub(r"^.*?\/", "", audio_path)
            audio_paths.append(path / audio_path)

    return audio_paths