Skip to content

Commit

Permalink
use torchaudio quesst14
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Apr 15, 2022
1 parent e2db27b commit adc91a5
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions s3prl/downstream/quesst14_dtw/expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from .dataset import QUESST14Dataset
from torchaudio.datasets import QUESST14 as QUESST14Dataset
import torchaudio.functional as F
import torchaudio.transforms as T
from torchaudio.sox_effects import apply_effects_tensor


class DownstreamExpert(nn.Module):
Expand Down Expand Up @@ -46,18 +49,33 @@ def __init__(

# Interface
def get_dataloader(self, mode):
    """Build the dataloader over QUESST14 queries followed by docs.

    Args:
        mode: ``"dev"`` selects the ``"dev"`` query subset; any other
            value selects the ``"eval"`` query subset.

    Returns:
        A non-shuffled ``DataLoader`` over the query subset concatenated
        with the ``"docs"`` subset. Each batch is a pair
        ``(wavs, audio_names)``: a tuple of preprocessed waveforms and a
        tuple of the matching names.

    Side effects:
        Sets ``self.docs_dataset``, ``self.queries_dataset`` and
        ``self.test_dataset``.
    """
    root = Path(self.datarc["dataset_root"])
    query_subset = "dev" if mode == "dev" else "eval"

    self.docs_dataset = QUESST14Dataset(root=root, subset="docs")
    self.queries_dataset = QUESST14Dataset(root=root, subset=query_subset)
    # Queries must come first (and the loader must not shuffle):
    # log_records splits the collected records at len(self.queries_dataset).
    self.test_dataset = torch.utils.data.ConcatDataset(
        [self.queries_dataset, self.docs_dataset]
    )

    # sox effect chain applied to every waveform: one channel,
    # 16000 Hz output rate, -3.0 gain.
    effects = [["channels", "1"], ["rate", "16000"], ["gain", "-3.0"]]

    def collate_fn(samples):
        """Collate a mini-batch of (wav, audio_name) samples."""
        wavs, audio_names = zip(*samples)
        # The input rate is declared to sox as 8000 Hz.
        # NOTE(review): assumes every QUESST14 recording is 8 kHz and that
        # the dataset yields exactly (wav, name) pairs -- confirm against
        # the installed torchaudio version.
        preprocessed_wavs = tuple(
            apply_effects_tensor(wav, sample_rate=8000, effects=effects)[0].squeeze(0)
            for wav in wavs
        )
        return preprocessed_wavs, audio_names

    return DataLoader(
        self.test_dataset,
        shuffle=False,
        batch_size=self.datarc["batch_size"],
        drop_last=False,
        num_workers=self.datarc["num_workers"],
        collate_fn=collate_fn,
    )

# Interface
Expand All @@ -81,10 +99,11 @@ def log_records(self, mode, records, **kwargs):
"""Perform DTW and save results."""

# Get precomputed queries & docs
queries = records["features"][: self.test_dataset.n_queries]
docs = records["features"][self.test_dataset.n_queries :]
query_names = records["audio_names"][: self.test_dataset.n_queries]
doc_names = records["audio_names"][self.test_dataset.n_queries :]
n_queries = len(self.queries_dataset)
queries = records["features"][: n_queries]
docs = records["features"][n_queries :]
query_names = records["audio_names"][: n_queries]
doc_names = records["audio_names"][n_queries :]

# Normalize upstream features
feature_mean, feature_std = 0.0, 1.0
Expand Down

0 comments on commit adc91a5

Please sign in to comment.