Skip to content

Commit

Permalink
static batch work
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Jan 23, 2024
1 parent c641b53 commit d02fc87
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions wenet/dataset/processor_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random

import torch
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torch.nn.functional as F
Expand Down Expand Up @@ -312,3 +313,53 @@ def spec_trim(sample, max_t=20):
y = x.clone().detach()[:max_frames - length]
sample['feat'] = y
return sample


def padding(data):
""" Padding the data into training data
Args:
data: List[{key, feat, label}
Returns:
Tuple(keys, feats, labels, feats lengths, label lengths)
"""
sample = data
assert isinstance(sample, list)
feats_length = torch.tensor([x['feat'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(feats_length, descending=True)
feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
dtype=torch.int32)
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
sorted_labels = [
torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
]
sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
dtype=torch.int32)
wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
dtype=torch.int32)
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
padding_labels = pad_sequence(sorted_labels,
batch_first=True,
padding_value=-1)
padded_wavs = pad_sequence(sorted_wavs, batch_first=True, padding_value=0)

batch = {
"keys": sorted_keys,
"feats": padded_feats,
"target": padding_labels,
"feats_lengths": feats_lengths,
"target_lengths": label_lengths,
"pcm": padded_wavs,
"pcm_length": wav_lengths,
}
if 'speaker' in sample[0]:
speaker = torch.tensor([sample[i]['speaker'] for i in order],
dtype=torch.int32)
batch['speaker'] = speaker
return batch

0 comments on commit d02fc87

Please sign in to comment.