cleanup, more
albertz committed Dec 9, 2024
1 parent f3ad5a5 commit a5d353d
Showing 1 changed file with 5 additions and 48 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/claix2023.py: 53 changes (5 additions & 48 deletions)
@@ -210,10 +210,6 @@ def py():
 for cr_ctc in [None, {"cr_loss_scale": 0.2}, {"cr_loss_scale": 0.5}, {"cr_loss_scale": 1.0}]:
     # TODO also adapt specaug for CR...
     use_cr_ctc = cr_ctc is not None
-    if use_cr_ctc:
-        cr_ctc: Dict[str, Any]
-        cr_ctc = cr_ctc.copy()
-        cr_ctc["use_fixed_ctc_grad"] = "v2"
     name = f"crLoss{cr_ctc['cr_loss_scale']}-" if use_cr_ctc else ""
     if opts.get("time_downsampling"):
         name += f"time{opts['time_downsampling']}-"
@@ -266,7 +262,7 @@ def py():
 # purely used for training
 "aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6),
 **(cr_ctc if use_cr_ctc else {}),
-**({"aed_loss_bug_fix": True} if use_cr_ctc else {}),
+**({"use_fixed_ctc_grad": "v2", "aed_loss_bug_fix": True} if use_cr_ctc else {}),
 "max_seq_length_default_target": None,
 # Note on max seq len stats: Before, when we used max_seq_length_default_target=75 with bpe10k,
 # out of 281241 seqs in train, we removed only 71 seqs.
@@ -349,48 +345,6 @@ def py():
         "num_enc_layers": 12,
         "out_blank_separated": True,
     },
-    config_updates={
-        **_get_cfg_lrlin_oclr_by_bs_nep_v3(150_000, 100, batch_size_factor=_batch_size_factor),
-        "optimizer.weight_decay": 1e-2,
-        "max_seq_length_default_target": None,
-        # Note on max seq len stats: Before, when we used max_seq_length_default_target=75 with bpe10k,
-        # out of 281241 seqs in train, we removed only 71 seqs.
-        # With max seq len 19.5 secs on the audio, we also remove exactly 71 seqs.
-        "max_seq_length_default_input": 19.5 * _raw_sample_rate,
-        "__train_audio_preprocess": speed_pert_librosa_config,
-        "speed_pert_discrete_values": [0.7, 0.8, 0.9, 1.0, 1.1],
-        "aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6),  # purely used for training
-    },
-    post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
-    vocab="spm512",
-    train_vocab_opts={"other_opts": {"class": "SamplingBytePairEncoding", "breadth_prob": 0.01}},
-    dataset_train_opts={"train_epoch_split": 1, "train_epoch_wise_filter": None},
-    # avoid OOM
-    env_updates={"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
-)
-
-ctc_train_exp(
-    "time4-n12-spm512-blankSep-auxAED-b150k-ctcFixGrad",
-    config_96gb_bf16_accgrad1,
-    model_config={
-        "enc_input_layer": rf.build_dict(
-            ConformerConvSubsample,
-            out_dims=[32, 64, 64],
-            filter_sizes=[(3, 3), (3, 3), (3, 3)],
-            pool_sizes=[(1, 2)],
-            strides=[(1, 1), (2, 1), (2, 1)],
-        ),
-        "enc_conformer_layer": rf.build_dict(
-            ConformerEncoderLayer,
-            ff=rf.build_dict(
-                ConformerPositionwiseFeedForward, activation=rf.build_dict(rf.relu_square), with_bias=False
-            ),
-            num_heads=8,
-        ),
-        "feature_batch_norm": True,
-        "num_enc_layers": 12,
-        "out_blank_separated": True,
-    },
     config_updates={
         **_get_cfg_lrlin_oclr_by_bs_nep_v3(150_000, 100, batch_size_factor=_batch_size_factor),
         "optimizer.weight_decay": 1e-2,
@@ -506,6 +460,7 @@ def py():
     "ctc_am_scale": am_scale,
     "ctc_prior_scale": prior_scale,
     "ctc_prior_type": prior_type,
+    "use_fixed_ctc_grad": "v2",
 },
 post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
 vocab="spm512",
@@ -589,6 +544,7 @@ def py():
     "__train_audio_preprocess": speed_pert_librosa_config,
     "speed_pert_discrete_values": [0.7, 0.8, 0.9, 1.0, 1.1],
     "aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6),  # purely used for training
+    "use_fixed_ctc_grad": "v2",
 },
 post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
 vocab="spm10k",
@@ -640,6 +596,7 @@ def py():
     "ctc_am_scale": am_scale,
     "ctc_prior_scale": prior_scale,
     "ctc_prior_type": prior_type,
+    "use_fixed_ctc_grad": "v2",
 },
 post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
 vocab="spm512",
@@ -878,7 +835,7 @@ def py():
     "speed_pert_discrete_values": [0.7, 0.8, 0.9, 1.0, 1.1],
     "aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6),  # purely used for training
     **(cr_ctc if use_cr_ctc else {}),
-    **({"aed_loss_bug_fix": True} if use_cr_ctc else {}),
+    **({"use_fixed_ctc_grad": "v2", "aed_loss_bug_fix": True} if use_cr_ctc else {}),
 },
 config_deletes=["aux_loss_layers"],
 post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
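
For context on the pattern used in the changed lines: the conditional dict-spread idiom (**({...} if flag else {})) includes config keys only when a flag is set, which is what replaces the earlier in-place mutation of cr_ctc. Below is a minimal standalone sketch of that idiom; the variable values are illustrative only and not taken from the experiment configs.

    # Minimal sketch of the conditional dict-spread idiom (illustrative values only).
    use_cr_ctc = True
    cr_ctc = {"cr_loss_scale": 0.2}

    config_updates = {
        "optimizer.weight_decay": 1e-2,
        # Keys below are only included when CR-CTC is enabled; otherwise an empty dict is spread.
        **(cr_ctc if use_cr_ctc else {}),
        **({"use_fixed_ctc_grad": "v2", "aed_loss_bug_fix": True} if use_cr_ctc else {}),
    }

    assert config_updates["cr_loss_scale"] == 0.2
    assert config_updates["use_fixed_ctc_grad"] == "v2"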