[WIP] Pruned transducer stateless2 for WenetSpeech #314
Conversation
When using char as the modeling unit, there are some preliminary results. When I trained with the original pruned RNN-T, the pruned RNN-T loss became abnormal (too large in some iterations) after 23 epochs. The pruned RNN-T2 does not have this problem. BTW, I decoded some results as follows:
It shows that the reworked model is better than the original. Next, I will train on the full WenetSpeech data with the reworked model (char).
Cool... what WER should we be aiming for, i.e. what is the state of the art? Also please show the decoding configuration, e.g. the decoding algorithm and how many models you averaged.
I would recommend you use
OK, will add it.
From the WenetSpeech GitHub page (https://github.com/wenet-e2e/WenetSpeech):
OK, please be clear about which subsets you are referring to when you report results!
I will update the results in real time.
Please use the predefined training subsets. Otherwise, the results are not comparable.
```python
)
else:
    cuts_train = CutSet.from_file(
        self.args.manifest_dir
        / "cuts_L.jsonl.gz"
        # use cuts_L_50_pieces.jsonl.gz for original experiments
    )
```
What are "original experiments" in the comment?
It means that cuts_L_50_pieces.jsonl.gz includes 1400+ hours. It refers to the results shown above.
log "Stage 8: Compute fbank for musan" | ||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then | ||
log "Stage 12: Combine features for M" | ||
if [ ! -f data/fbank/cuts_M.jsonl.gz ]; then |
You don't need to combine them, I think.
For large datasets, the combining process is slow and the resulting file is large.
You can use
```python
def train_XL_cuts(self) -> CutSet:
    logging.info("About to get train-XL cuts")
    filenames = list(
        glob.glob(f"{self.manifest_dir}/XL_split_2000/cuts_XL.*.jsonl.gz")
    )
    pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
    idx_filenames = [
        (int(pattern.search(f).group(1)), f) for f in filenames
    ]
    idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
    sorted_filenames = [f[1] for f in idx_filenames]
    logging.info(f"Loading {len(sorted_filenames)} splits")
    return lhotse.combine(
        lhotse.load_manifest_lazy(p) for p in sorted_filenames
    )
```
as an example
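For context, here is a small usage sketch (the manifest path and sampler settings are assumptions, not taken from this PR) showing that the combined CutSet stays lazy, so no single large manifest is ever materialized on disk:

```python
import glob

import lhotse
from lhotse.dataset import DynamicBucketingSampler

# Assumed manifest location; adjust to the recipe's actual layout.
manifest_dir = "data/fbank"
# Lexicographic sort for brevity; the snippet above sorts the splits numerically.
filenames = sorted(glob.glob(f"{manifest_dir}/XL_split_2000/cuts_XL.*.jsonl.gz"))

# Combining lazily loaded manifests yields a lazily concatenated CutSet.
cuts_xl = lhotse.combine(lhotse.load_manifest_lazy(p) for p in filenames)

# The sampler simply walks through the splits one after another;
# max_duration is an arbitrary illustrative value.
sampler = DynamicBucketingSampler(cuts_xl, max_duration=200.0, shuffle=True)
```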
```
@@ -5,6 +5,7 @@ set -eou pipefail
nj=15
stage=0
stop_stage=100
use_whole_text=True
```
Where is it used and what is its meaning?
Em... actually, I want to provide a choice between using the transcripts of the L subset and using the text of the whole data to generate the text, token.txt, words.txt, and so on. But it is not finished yet. I will add more about it.
It seems the code takes a long time to do the following process.
If I change to loading the model
That is expected. See icefall/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py, lines 685 to 688 (commit 18a1e95).
It has to skip as many batches as needed to reach the point where your last checkpoint was saved.
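For illustration, a hypothetical sketch of this style of resuming (not the exact lines referenced above): every batch before the saved position is still pulled through the dataloader and then discarded, which is why it is slow.

```python
def resume_by_skipping(train_dl, saved_batch_idx, train_one_batch):
    """Fast-forward a dataloader by discarding already-seen batches.

    This still pays the full dataloading cost for every skipped batch.
    """
    for batch_idx, batch in enumerate(train_dl):
        if batch_idx < saved_batch_idx:
            continue  # skipped, but the batch was already loaded
        train_one_batch(batch)
```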
OK.
I figured out how to add support for quickly restoring a dynamic sampler's state, but forgot to actually add it. You can try saving the sampler's state dict. It should be much faster because it only iterates cuts rather than doing actual dataloading.
... note that it's supposed to "just work" and pick up from where it was saved when you start iterating it, so you won't need the
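A minimal sketch of what saving and restoring the sampler state could look like, assuming a Lhotse dynamic sampler that supports state_dict()/load_state_dict(); the checkpoint keys and path below are made up for illustration.

```python
import torch


def save_checkpoint(model, sampler, path="checkpoint-xxx.pt"):
    # Hypothetical checkpoint layout; the keys are illustrative only.
    torch.save(
        {
            "model": model.state_dict(),
            "sampler": sampler.state_dict(),  # records the sampler's position
        },
        path,
    )


def load_checkpoint(model, sampler, path="checkpoint-xxx.pt"):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    # Restoring only iterates cut metadata, not features, so it is fast.
    sampler.load_state_dict(ckpt["sampler"])
```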
I have tested PR #684. It runs successfully. With this version, loading the batch checkpoint takes very little time. Also, I can remove the check for batch_idx.
Thanks. Please double check that you are getting the expected number of steps in an epoch after resuming, to confirm that the implementation is correct. There might be num_workers * prefetching_factor - 1 fewer batches (typically fewer than 10), but that is expected.
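A quick way to sanity-check this (a sketch with assumed variable names, not code from the recipe) is to count the batches in the resumed epoch and compare against the expected remainder, allowing for the tolerance mentioned above.

```python
def check_resumed_epoch(train_dl, total_batches, start_batch,
                        num_workers=2, prefetch_factor=2):
    # Count how many batches the resumed epoch actually yields.
    num_resumed = sum(1 for _ in train_dl)

    expected = total_batches - start_batch
    # Up to num_workers * prefetch_factor - 1 batches may be missing.
    tolerance = num_workers * prefetch_factor - 1

    assert expected - tolerance <= num_resumed <= expected, (
        f"Got {num_resumed} batches, expected about {expected}"
    )
```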
I am testing restore-dynamic-sampler #684 again. As shown in the picture above, the left part is my training process starting from epoch 0 and batch idx 0; I also set it to save a checkpoint-xxx.pt every 200 batches. I use a small part of the data for testing. The batch idx for epoch 0 goes from 0 to 924. In the right part, I load checkpoint-800.pt to continue training. Its starting batch idx is 799, and for its epoch 0 the batch idx goes from 799 to 1723.
Please clarify: did you expect epoch 0 to run for 124 more batches after step 800 (total 924), but it ran for 924 batches instead (total 1723) when resumed?
Yes. So, is it a reasonable expectation?
Yes, it's an issue; I need to debug that.
Hey @luomingshuang, I was not able to reproduce your issue using mini LibriSpeech. Please see the following notebook on Colab and let me know if anything is done differently than in your recipe. If you can reproduce it there, it will be easy for me to fix. Or maybe you're using this code in a way I haven't anticipated. https://drive.google.com/file/d/1hZSSIN8K6VPCw12pIHwUtL5w0bOzyvct/view?usp=sharing
OK, thanks. I will look at the Colab and check my code again.
As you can see, the code has been running normally for 6 epochs. Someone on the internet says it may be because the server does not have enough CPUs. I'm not sure.
Hi @pzelasko, I can reproduce your results with the provided Colab using my current Lhotse. But when I use the code in this PR (the sampler is used in asr_datamodule.py and train.py) with the same Lhotse for WenetSpeech, the problem I described above still exists. When I load checkpoint-xxx.pt, the first training epoch still runs over the whole dataloader, whatever the start batch is. After the first epoch, it runs normally.
I use a small part of WenetSpeech for debugging.
I save a checkpoint every 20 iterations. The number of iterations for one epoch is 66. I use checkpoint-40 for debugging.
```python
if sampler_state_dict is not None:
    logging.info("Loading sampler state dict")
    print(sampler_state_dict.keys())
    train_sampler.load_state_dict(sampler_state_dict)
```
Can you try moving this state dict loading to after the dataloader is created? Then use it like:
train_dl.sampler.load_state_dict(sampler_state_dict)
I think DataLoader might be calling __iter__ more than once on the sampler, but after it's initialized, it would call __iter__ just once.
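A sketch of the suggested reordering (the DataLoader arguments below are placeholders, not the recipe's actual settings):

```python
from torch.utils.data import DataLoader


def make_train_dl(train_dataset, train_sampler, sampler_state_dict=None):
    train_dl = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=None,  # the Lhotse sampler yields whole batches of cuts
        num_workers=2,
    )
    if sampler_state_dict is not None:
        # Load the state only after the DataLoader exists, in case
        # constructing it iterates the sampler.
        train_dl.sampler.load_state_dict(sampler_state_dict)
    return train_dl
```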
I did it, but it seems to make no difference.
Interesting. Can you take the code from the notebook, replace mini LibriSpeech with your test data, and see if you can replicate the issue? I suspect there is a difference between the usage in the notebook and the usage in the training script, and we have to close in on that.
Of course. I will do it right now.
Hi @pzelasko, @csukuangfj suggested that I set num-workers=0 and re-run the code. I did, and here are some results. With num-workers=0, the batch idx for the first epoch after loading checkpoint-xxx.pt is normal. In my original experiments, num-workers was 2. The training log is as follows (kept_batches is 44, and it runs for 22 batches in the first epoch; 44 + 22 = 66, which is correct):
@luomingshuang
OK, I will check it.
Please update train.py in icefall to remove cur_batch_idx.
This PR adds pruned transducer stateless2 for WenetSpeech. It is still in progress. I am currently training with 1400+ hours of data and will post updates in real time.