Commit

fix
nateanl committed Mar 30, 2022
1 parent e13c2c7 commit e9ca4b8
Showing 4 changed files with 31 additions and 58 deletions.
44 changes: 6 additions & 38 deletions examples/hubert/dataset/hubert_dataset.py
@@ -5,39 +5,7 @@
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import BatchSampler, Dataset, DistributedSampler


class DistributedBatchSampler(BatchSampler):
"""`BatchSampler` wrapper that distributes across each batch multiple workers.
Note: The code is forked from PyTorch-NLP, you can find the license in
https://github.com/PetrochukM/PyTorch-NLP/blob/master/LICENSE
Args:
batch_sampler (torch.utils.data.sampler.BatchSampler)
num_replicas (int, optional): Number of processes participating in distributed training.
rank (int, optional): Rank of the current process within num_replicas.
Example:
>>> from torch.utils.data.sampler import BatchSampler
>>> from torch.utils.data.sampler import SequentialSampler
>>> sampler = SequentialSampler(list(range(12)))
>>> batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
>>>
>>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0))
[[0, 2], [4, 6], [8, 10]]
>>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=1))
[[1, 3], [5, 7], [9, 11]]
"""

def __init__(self, batch_sampler, **kwargs):
self.batch_sampler = batch_sampler
self.kwargs = kwargs

def __iter__(self):
for batch in self.batch_sampler:
yield list(DistributedSampler(batch, **self.kwargs))

def __len__(self):
return len(self.batch_sampler)
from torch.utils.data import BatchSampler, Dataset


class BucketizeBatchSampler(BatchSampler):
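Note: the docstring of the `DistributedBatchSampler` removed above describes splitting every batch across ranks in a round-robin fashion. A minimal sketch of that documented behaviour (a hypothetical stand-in, not the forked PyTorch-NLP implementation, which delegates to `DistributedSampler`):

```python
from typing import Iterator, List

from torch.utils.data import BatchSampler, SequentialSampler


class RoundRobinBatchSplitter:
    """Hypothetical sampler reproducing the [[0, 2], [4, 6], [8, 10]] example above."""

    def __init__(self, batch_sampler: BatchSampler, num_replicas: int, rank: int) -> None:
        self.batch_sampler = batch_sampler
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self) -> Iterator[List[int]]:
        for batch in self.batch_sampler:
            # rank 0 keeps elements 0, 2, 4, ...; rank 1 keeps elements 1, 3, 5, ...
            yield batch[self.rank::self.num_replicas]

    def __len__(self) -> int:
        return len(self.batch_sampler)


sampler = SequentialSampler(list(range(12)))
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(RoundRobinBatchSplitter(batch_sampler, num_replicas=2, rank=0)))  # [[0, 2], [4, 6], [8, 10]]
print(list(RoundRobinBatchSplitter(batch_sampler, num_replicas=2, rank=1)))  # [[1, 3], [5, 7], [9, 11]]
```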
@@ -169,20 +137,20 @@ class HuBERTDataSet(Dataset):
"""Create a Dataset for HuBERT model training and fine-tuning.
Args:
root_dir (str or Path): The root directory that contains ``tsv`` and ``label`` directories.
exp_dir (str or Path): The root directory of the ``.tsv`` file list.
dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
subset (str): The subset of the dataset. Options: [``train``, ``valid``].
"""

def __init__(
self,
root_dir: Union[str, Path],
exp_dir: Union[str, Path],
dataset: str,
subset: str,
) -> None:
self.root_dir = Path(root_dir)
tsv_dir = self.root_dir / "tsv"
label_dir = self.root_dir / "label"
self.exp_dir = Path(exp_dir)
tsv_dir = self.exp_dir / "tsv"
label_dir = self.exp_dir / "label"
f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset)
self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
self.labels = self._load_labels(label_dir, dataset, subset)
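The constructor now resolves the `tsv` and `label` directories under `exp_dir` rather than `root_dir`. A minimal usage sketch, assuming an `exp_dir` produced by the example's preprocessing step (the `"./exp"` path is a placeholder):

```python
from dataset import HuBERTDataSet  # assumed import path inside examples/hubert

# "./exp" must contain the ``tsv`` and ``label`` sub-directories built during preprocessing.
dataset = HuBERTDataSet(exp_dir="./exp", dataset="librispeech", subset="train")
print(len(dataset.f_list))    # number of utterances in the file list
print(dataset.len_list[:5])   # per-utterance lengths consumed by BucketizeBatchSampler
```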
30 changes: 21 additions & 9 deletions examples/hubert/lightning.py
@@ -2,7 +2,7 @@

import torch
import torchaudio
from dataset import BucketizeSampler, DistributedBatchSampler, HuBERTDataSet, CollateFnHubert
from dataset import BucketizeBatchSampler, DistributedBatchSampler, HuBERTDataSet, CollateFnHubert
from loss import hubert_loss
from pytorch_lightning import LightningModule
from torch import Tensor
@@ -76,7 +76,7 @@ def __init__(
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch

def _step(self, batch, batch_idx, step_type):
def _step(self, batch: Batch, batch_idx, step_type):
if batch is None:
return None
waveforms, labels, audio_lengths = batch
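The unpacking above implies that `Batch` is a three-field tuple of waveforms, labels, and audio lengths. A hypothetical definition consistent with that usage (the actual one is defined elsewhere in `lightning.py`):

```python
from typing import NamedTuple

from torch import Tensor


class Batch(NamedTuple):
    waveforms: Tensor      # padded/cropped audio, shape (batch, time)
    labels: Tensor         # frame-level cluster labels
    audio_lengths: Tensor  # valid length of each waveform
```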
@@ -103,31 +103,43 @@ def configure_optimizers(self):
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")

def validation_step(self, batch, batch_idx):
def validation_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "val")

def train_dataloader(self):
dataset = HuBERTDataSet(self.root_path, self.dataset, "train")
sampler = BucketizeSampler(dataset, num_buckets=1000, max_token_count=self.seconds_per_batch * 16000)
sampler = DistributedBatchSampler(sampler)
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=10000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.current_epoch)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
pin_memory=True,
)
return dataloader

def val_dataloader(self):
dataset = HuBERTDataSet(self.root_path, self.dataset, "valid")
sampler = BucketizeSampler(dataset, num_buckets=1000, max_token_count=self.seconds_per_batch * 16000)
sampler = DistributedBatchSampler(sampler)
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
pin_memory=True,
)
return dataloader
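Both dataloaders now budget batches by total audio duration rather than a fixed batch size: at a 16 kHz sample rate, `max_token_count = seconds_per_batch * 16000` is the number of waveform samples allowed per batch, and `min_len`/`max_len` bound individual utterances. A quick arithmetic sketch with an assumed `seconds_per_batch` value:

```python
sample_rate = 16_000
seconds_per_batch = 87.5                             # illustrative; the recipe takes it from the CLI
max_token_count = int(seconds_per_batch * sample_rate)

print(max_token_count)                               # 1400000 waveform samples per batch
print(32_000 / sample_rate, 250_000 / sample_rate)   # utterance bounds: 2.0 to 15.625 seconds
```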
2 changes: 1 addition & 1 deletion examples/hubert/loss/hubert_loss.py
@@ -24,7 +24,7 @@ def hubert_loss(
feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).
"""
loss = feature_pen * feature_weight
loss = feature_pen * feature_weight * logit_m.shape[0]
if logit_m is not None:
target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
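The change scales the feature penalty by the number of masked frames, presumably so it stays on a comparable scale to the `reduction="sum"` cross-entropy term, which also grows with `logit_m.shape[0]`. A toy illustration of the two terms with made-up values:

```python
import torch
import torch.nn.functional as F

num_masked, num_classes = 8, 100
logit_m = torch.randn(num_masked, num_classes)   # logits over masked frames
feature_pen = torch.tensor(0.5)                  # illustrative feature penalty
feature_weight = 10.0

loss = feature_pen * feature_weight * logit_m.shape[0]
target_m = torch.zeros(logit_m.shape[0], dtype=torch.long)
loss = loss + F.cross_entropy(logit_m, target_m, reduction="sum")
print(loss.item())
```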
13 changes: 3 additions & 10 deletions examples/hubert/train.py
@@ -32,17 +32,9 @@ def run_train(args):
save_weights_only=True,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)

callbacks = [
checkpoint,
train_checkpoint,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
@@ -54,6 +46,7 @@
replace_sampler_ddp=False,
gradient_clip_val=args.clip_norm,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)

model = HuBERTPreTrainModule(
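With `replace_sampler_ddp=False` Lightning leaves the custom batch sampler untouched, and the new `reload_dataloaders_every_n_epochs=1` makes it call `train_dataloader()` at every epoch, so the `sampler.set_epoch(self.current_epoch)` call in `lightning.py` reshuffles with a fresh epoch-dependent seed. A minimal sketch of just those two Trainer arguments, assuming a Lightning version that still accepts `replace_sampler_ddp` (the recipe's other arguments are omitted):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    replace_sampler_ddp=False,            # keep the custom DistributedBatchSampler
    reload_dataloaders_every_n_epochs=1,  # rebuild dataloaders (and re-seed the sampler) each epoch
)
```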
@@ -142,7 +135,7 @@ def _parse_args():
)
parser.add_argument(
"--clip-norm",
default=1.0,
default=None,
type=float,
help="The gradient norm value to clip. (Default: 1.0)",
)
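Changing the default to `None` disables gradient clipping unless `--clip-norm` is passed explicitly, since Lightning skips clipping when `gradient_clip_val` is `None`. A small sketch of the flag's behaviour in isolation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--clip-norm", default=None, type=float,
                    help="The gradient norm value to clip. (Default: None)")

print(parser.parse_args([]).clip_norm)                        # None -> no clipping
print(parser.parse_args(["--clip-norm", "10.0"]).clip_norm)   # 10.0
```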
