Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using aidatatang_200zh optionally in aishell training #495

Merged
merged 3 commits into from
Jul 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 70 additions & 35 deletions egs/aishell/ASR/pruned_transducer_stateless3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from aidatatang_200zh import AIDatatang200zh
from aishell import AIShell
from asr_datamodule import AsrDataModule
Expand Down Expand Up @@ -344,8 +345,11 @@ def get_parser():
"--datatang-prob",
type=float,
default=0.2,
help="The probability to select a batch from the "
"aidatatang_200zh dataset",
help="""The probability to select a batch from the
aidatatang_200zh dataset.
If it is set to 0, you don't need to download the data
for aidatatang_200zh.
""",
)

add_model_arguments(parser)
Expand Down Expand Up @@ -457,8 +461,12 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

decoder_datatang = get_decoder_model(params)
joiner_datatang = get_joiner_model(params)
if params.datatang_prob > 0:
decoder_datatang = get_decoder_model(params)
joiner_datatang = get_joiner_model(params)
else:
decoder_datatang = None
joiner_datatang = None

model = Transducer(
encoder=encoder,
Expand Down Expand Up @@ -726,7 +734,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
datatang_train_dl: torch.utils.data.DataLoader,
datatang_train_dl: Optional[torch.utils.data.DataLoader],
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: GradScaler,
Expand Down Expand Up @@ -778,13 +786,17 @@ def train_one_epoch(
dl_weights = [1 - params.datatang_prob, params.datatang_prob]

iter_aishell = iter(train_dl)
iter_datatang = iter(datatang_train_dl)
if datatang_train_dl is not None:
iter_datatang = iter(datatang_train_dl)

batch_idx = 0

while True:
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
dl = iter_aishell if idx == 0 else iter_datatang
if datatang_train_dl is not None:
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
dl = iter_aishell if idx == 0 else iter_datatang
else:
dl = iter_aishell

try:
batch = next(dl)
Expand All @@ -808,7 +820,11 @@ def train_one_epoch(
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
if datatang_train_dl is not None:
tot_loss = (
tot_loss * (1 - 1 / params.reset_interval)
) + loss_info

if aishell:
aishell_tot_loss = (
aishell_tot_loss * (1 - 1 / params.reset_interval)
Expand Down Expand Up @@ -871,12 +887,21 @@ def train_one_epoch(

if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
if datatang_train_dl is not None:
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
tot_loss_str = (
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
)
else:
tot_loss_str = ""
datatang_str = ""

logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"{tot_loss_str}"
f"aishell_tot_loss[{aishell_tot_loss}], "
f"datatang_tot_loss[{datatang_tot_loss}], "
f"{datatang_str}"
f"batch size: {batch_size}, "
f"lr: {cur_lr:.2e}"
)
Expand All @@ -891,15 +916,18 @@ def train_one_epoch(
f"train/current_{prefix}_",
params.batch_idx_train,
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if datatang_train_dl is not None:
# If it is None, tot_loss is the same as aishell_tot_loss.
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train
)
datatang_tot_loss.write_summary(
tb_writer, "train/datatang_tot_", params.batch_idx_train
)
if datatang_train_dl is not None:
datatang_tot_loss.write_summary(
tb_writer, "train/datatang_tot_", params.batch_idx_train
)

if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
Expand Down Expand Up @@ -1032,11 +1060,6 @@ def run(rank, world_size, args):
train_cuts = aishell.train_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts)

datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
train_datatang_cuts = train_datatang_cuts.repeat(times=None)

if args.enable_musan:
cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
Expand All @@ -1052,11 +1075,21 @@ def run(rank, world_size, args):
cuts_musan=cuts_musan,
)

datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts,
on_the_fly_feats=False,
cuts_musan=cuts_musan,
)
if params.datatang_prob > 0:
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances(
train_datatang_cuts
)
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts,
on_the_fly_feats=False,
cuts_musan=cuts_musan,
)
else:
datatang_train_dl = None
logging.info("Not using aidatatang_200zh for training")

valid_cuts = aishell.valid_cuts()
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
Expand All @@ -1065,13 +1098,14 @@ def run(rank, world_size, args):
train_dl,
# datatang_train_dl
]:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
if dl is not None:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)

scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
Expand All @@ -1083,7 +1117,8 @@ def run(rank, world_size, args):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
datatang_train_dl.sampler.set_epoch(epoch)
if datatang_train_dl is not None:
datatang_train_dl.sampler.set_epoch(epoch)

if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
Expand Down