
MGB2 #396

Merged: 36 commits merged into k2-fsa:master on Dec 2, 2022
Conversation

@AmirHussein96 (Contributor) commented Jun 4, 2022:

This is a pull request for the MGB2 recipe.
Kindly note that the model is still training and is currently at epoch 3; see the training curves here: https://tensorboard.dev/experiment/zy6FnumCQlmiO7BPsdCmEg/#scalars.
One issue is that with the current setup, one epoch on 2 V-100 32 GB GPUs with --max-duration 100 takes 2 days, which is very long compared to a similar architecture with Espnet (half a day per epoch). Any ideas what could cause this?
I tried increasing --max-duration to 200, but it gave me an OOM error.

On the other hand, the WER on test = 23.53 looks reasonable given that this is still the 3rd epoch. I expect to get something close to Espnet (Transformer 14.2, Conformer 13.7).

@csukuangfj (Collaborator):

Thanks!

How did you choose the following thresholds?

def remove_short_and_long_utt(c: Cut):
    # Keep only utterances with duration between 1 second and 20 seconds
    #
    # Caution: There is a reason to select 20.0 here. Please see
    # ../local/display_manifest_statistics.py
    #
    # You should use ../local/display_manifest_statistics.py to get
    # an utterance duration distribution for your dataset to select
    # the threshold
    return 0.5 <= c.duration <= 30.0

Could you update ./local/display_manifest_statistics.py?

Also, could you please try our pruned RNN-T recipe? It not only has a lower WER on LibriSpeech/GigaSpeech but also decodes faster with much lower memory consumption.

I would recommend using https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5 as a starting point.
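For illustration, here is a minimal sketch (the manifest path is hypothetical) of how one can inspect the cut-duration distribution with lhotse before picking such thresholds; the recipe's ../local/display_manifest_statistics.py relies on the same statistics:

```python
from lhotse import load_manifest

# Hypothetical path; substitute the MGB2 training cuts produced by the recipe.
cuts = load_manifest("data/fbank/cuts_train.jsonl.gz")

# describe() prints the number of cuts, total duration, and duration percentiles;
# the percentiles are what guide sensible min/max duration thresholds.
cuts.describe()
```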

@csukuangfj (Collaborator):

Also, I think you are converting Kaldi manifests to lhotse format. Please have a look at #391 (reply in thread).

If you used a version of lhotse older than lhotse-speech/lhotse#729 to extract the features, I would suggest re-extracting them using the latest lhotse, which uses lilcom_chunky instead of lilcom_hdf5.

@AmirHussein96 (Contributor Author):

lilcom_hdf5

Yes, my current version uses lilcom_hdf5. I will rerun it using lilcom_chunky and let you know. Thank you.

@AmirHussein96 (Contributor Author) commented Jun 5, 2022:

> How did you choose the following thresholds? [...] Could you update ./local/display_manifest_statistics.py? Also, could you please try our pruned RNN-T recipe [...]? I would recommend using https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5 as a starting point.

The min = 0.5 s and max = 30 s duration boundaries are similar to what I used with Espnet, based on my experience. Segments longer than 30 s cause memory issues and model underfitting (the model needs a lot of epochs to start fitting the training data).
I will check and update ./local/display_manifest_statistics.py.

Regarding the RNN-T, I was actually considering it as my next step, so yes, I will run it as well. Thank you for pointing me to the latest best RNN-T configuration.

@AmirHussein96 (Contributor Author):

So to use lilcom_chunky I should change storage_type from LilcomHdf5Writer to LilcomChunkyWriter in compute_fbank_mgb2.py, right? Should I do the same with musan?
cut_set = cut_set.compute_and_store_features(
    extractor=extractor,
    storage_path=f"{output_dir}/feats_{partition}",
    # when an executor is specified, make more partitions
    num_jobs=num_jobs if ex is None else 80,
    executor=ex,
    storage_type=LilcomChunkyWriter,
)

@csukuangfj (Collaborator) commented Jun 5, 2022:

So to use lilcom_chunky I should change storage_type from LilcomHdf5Writer to LilcomChunkyWriter in compute_fbank_mgb2.py, right?

Yes. Please see

storage_type=LilcomChunkyWriter,

Should I do the same with musan?

Yes, you can do that. Please see

storage_type=LilcomChunkyWriter,

Note that the filenames end with jsonl.gz, not json.gz, and you HAVE TO use to_file(), not to_json().

Also, please replace load_manifest with load_manifest_lazy, see

cuts_musan = load_manifest_lazy(
    self.args.manifest_dir / "musan_cuts.jsonl.gz"
)

return load_manifest_lazy(
    self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)

And use DynamicBucketingSampler to replace BucketingSampler, see

if self.args.bucketing_sampler:
    logging.info("Using DynamicBucketingSampler.")
    train_sampler = DynamicBucketingSampler(
        cuts_train,
        max_duration=self.args.max_duration,
        shuffle=self.args.shuffle,
        num_buckets=self.args.num_buckets,
        drop_last=self.args.drop_last,
    )
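Putting these suggestions together, a minimal sketch of the updated feature-extraction step for one partition (the function name, mel-bin count, and paths are illustrative, not the recipe's exact code):

```python
from pathlib import Path

from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter


def store_features(cut_set: CutSet, output_dir: Path, partition: str, num_jobs: int = 15) -> CutSet:
    extractor = Fbank(FbankConfig(num_mel_bins=80))
    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/feats_{partition}",
        num_jobs=num_jobs,
        # lilcom_chunky storage instead of the older lilcom_hdf5 backend.
        storage_type=LilcomChunkyWriter,
    )
    # Note the .jsonl.gz suffix and to_file() (not to_json()), so the manifest
    # can later be opened lazily with load_manifest_lazy().
    cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
    return cut_set
```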


with get_executor() as ex:  # Initialize the executor only once.
    for partition, m in manifests.items():
        if (output_dir / f"cuts_{partition}.json.gz").is_file():

Collaborator — Suggested change:
-        if (output_dir / f"cuts_{partition}.json.gz").is_file():
+        if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():

cut_set = cut_set.trim_to_supervisions(
    keep_overlapping=False, min_duration=None
)
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")

Collaborator — Suggested change:
-cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")

@AmirHussein96 (Contributor Author):

Also, I think I closed the PR by mistake. @csukuangfj, can you reopen it?

@pkufool pkufool reopened this Jun 5, 2022
@AmirHussein96 (Contributor Author):

After making the suggested modifications (storage_type=LilcomChunkyWriter for feature storage and load_manifest_lazy for loading), I also managed to increase --max-duration from 150 to 300, but 6k iterations now take twice as long as before. The new training curve can be found here: https://tensorboard.dev/experiment/N8X0P5pHQyiWwvdT7RTp8w/. Note that I am still using 2 V-100 32 GB GPUs.
Should I try --storage-type numpy_hdf5?

@csukuangfj (Collaborator):

I managed to increase the --max-duration from 150 to 300, the 6k iteration takes twice what it took previously

Since you are doubling the max duration, the time for 6k iterations should also increase. But I am not sure whether it is normal for it to double. @pzelasko, do you have any comments?

Should I try --storage-type numpy_hdf5 ?

I don't know whether switching to numpy_hdf5 will help decrease the training time.
I think you were using chunked_lilcom_hdf5, i.e., ChunkedLilcomHdf5Writer.

@danpovey (Collaborator) commented Jun 6, 2022:

The change in time per minibatch is pretty close to linear in --max-duration, so I think that is as expected.

@pzelasko (Collaborator) commented Jun 6, 2022:

LilcomChunky and LilcomHdf5 should have very close performance; I don't expect you'd win anything here.

Like Dan says, if you scaled up the batch size by 2x, that can explain it taking almost twice as long to run (unless you had a small model that underutilizes the GPU, which is likely not the case here).

@AmirHussein96 (Contributor Author) commented Jun 6, 2022:

Related to the slow training discussion, @danpovey suggested one of the following:

  • Use more workers.
  • Use WebDataset (it reads data sequentially, which is better suited for data on an HDD).
  • Debug the I/O read latency using py-spy (check whether threads are stuck reading data).

So first I double-checked the node's drives with lsblk -d -o name,rota, and they look like SSDs:

NAME    ROTA
nvme0n1    0
nvme1n1    0

As a first attempt I tried increasing the number of workers from 2 to 8, with --max-duration 150 as in https://tensorboard.dev/experiment/zy6FnumCQlmiO7BPsdCmEg/#scalars, because 300 and 200 gave OOM. The iterations indeed became 4 times faster; you can find the new setup with 8 workers here: https://tensorboard.dev/experiment/WvSg4yn8SYyJlKyQGkls0A/#scalars.

@danpovey (Collaborator) commented Jun 7, 2022:

Increasing num-workers increases RAM utilization but does not increase GPU memory utilization, so it should not affect the maximum --max-duration you can use.

My feeling is that the issue is that he is running from an HDD, not an SSD, so the latency of disk access is quite slow.
I think a solution would be to either use WebDataset for sequential access, or just use a much larger number of workers. If he uses jsonl in the recipe, that should prevent the large num-workers from causing excessive memory use. (Note: a recent PR in icefall from @csukuangfj, possibly just merged, makes some changes to use jsonl instead of json.)
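The num-workers knob here is the standard PyTorch DataLoader argument. A minimal sketch, assuming the DynamicBucketingSampler built in the earlier snippet (the actual asr_datamodule typically also wires in augmentation transforms and input strategies):

```python
from torch.utils.data import DataLoader
from lhotse.dataset import K2SpeechRecognitionDataset

# `train_sampler` is the DynamicBucketingSampler constructed as shown above.
train_dl = DataLoader(
    K2SpeechRecognitionDataset(),
    sampler=train_sampler,
    batch_size=None,   # batching is handled by the lhotse sampler
    num_workers=8,     # more workers hide disk-read latency; costs RAM, not GPU memory
)
```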

@csukuangfj (Collaborator):

@AmirHussein96 (Contributor Author) commented Jun 7, 2022:

> @AmirHussein96 I would suggest you use the following two files as a reference:

Yes, I followed these changes and increased the number of workers from 2 to 8 per GPU (I am using 2 GPUs). The utilization is shown below; it is much better now, 10h-12h per epoch compared to 2 days previously:
(screenshot: GPU utilization)

@csukuangfj (Collaborator):

@AmirHussein96
Please use load_manifest instead of load_manifest_lazy if the manifest can be read into memory all at once.
It can speed things up a lot.
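For context, a minimal sketch of the two loading modes (the paths are illustrative):

```python
from lhotse import load_manifest, load_manifest_lazy

# Eager: the whole CutSet is read into RAM up front; repeated passes are fast.
cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

# Lazy: cuts are streamed from the jsonl.gz file on demand, keeping memory low
# for very large manifests at the cost of slower repeated access.
cuts_train = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")
```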

@AmirHussein96 (Contributor Author) commented Jun 27, 2022:

@csukuangfj

Recent updates:

The conformer_ctc training for 45 epochs has finished; the tensorboard is here: https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars

I tried the following decoding methods (note that I had to reduce max_active_states from 10000 to 5000 to fit on a P100 16 GB GPU):

  1. whole-lattice-rescoring: ./conformer_ctc/decode.py --epoch 45 --avg 5 --exp-dir conformer_ctc/exp_5000_att0.8 --lang-dir data/lang_bpe_5000 --method whole-lattice-rescoring --nbest-scale 0.5 --lm-dir data/lm --max-duration 30 --num-paths 1000 --num-workers 20
    results=> dev: 15.62 , test: 15.01

  2. Attention-decoder: ./conformer_ctc/decode.py --epoch 45 --avg 5 --max-duration 30 --num-paths 1000 --exp-dir conformer_ctc/exp_5000_att0.8 --lang-dir data/lang_bpe_5000 --method attention-decoder --shuffle False --enable-musan False --enable-spec-aug False --nbest-scale 0.5 --num-workers 20
    results=> dev: 15.89 , test: 15.08

It looks like there is still a considerable gap compared to a similar Espnet setup (WER with decoding beam search = 20, no LM):
(Transformer => dev: 14.6, test: 14.2; Conformer => test: 13.7)
https://github.com/espnet/espnet/blob/master/egs/mgb2/asr1/RESULTS.md

@AmirHussein96 (Contributor Author) commented Jun 28, 2022:

> Also, could you please try our pruned RNN-T recipe [...]? I would recommend using https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5 as a starting point.

@csukuangfj

I tried the RNN-T on MGB2 with the following command:

./pruned_transducer_stateless5/train.py \
  --world-size 4 \
  --num-epochs 40 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --max-duration 30

and this is the tensorboard: https://tensorboard.dev/experiment/xuOlsEwGRay3qspf7HezLw/#scalars&_smoothingWeight=0.693
The validation loss looks good, but the training loss is weird; is this expected?

For some reason the RNN-T asks for a lot of memory that does not fit into a V100 16 GB. Any ideas why this is happening? [errors-7078120.txt](https://github.com/k2-fsa/icefall/files/9000144/errors-7078120.txt)

@danpovey (Collaborator) commented Jun 30, 2022:

It looks like, for one of your jobs, an inf has got into the pruned_loss at some point. But this may only affect the diagnostics.
Which version of the scripts were you using here? I don't see these scripts in your PR.
Edit: in the logs I see that it is pruned_transducer_stateless5. You can try @yaozengwei's recent PR where he simplifies the RandomCombine module, removing the linear layers. I have seen experiments diverge, in half precision, due to the Linear module in RandomCombine causing large outputs, and that module has been removed in that PR. That is possibly the issue, anyway.

@AmirHussein96 (Contributor Author) commented Jul 5, 2022:

> It looks like, for 1 of your jobs, an inf has got into the pruned_loss at some point. [...] You can try @yaozengwei's recent PR where he simplifies the RandomCombine module, removing the linear layers. [...] That is possibly the issue, anyway.

Hi @danpovey, apologies for the late reply. I have pushed the updated pruned transducer stateless config that I am using with MGB2; please check it and let me know what you think.

The details about the k2 version I am using are below:

Build type: Release
Git SHA1: 3c606c27045750bbbb7a289d8b2b09825dea521a
Git date: Mon Jun 27 03:06:58 2022
Cuda used to build k2: 10.2
cuDNN used to build k2: 8.0.5
Python version used to build k2: 3.8
OS used to build k2: Red Hat Enterprise Linux Server release 7.8 (Maipo)
CMake version: 3.18.0
GCC version: 8.4.0
CMAKE_CUDA_FLAGS:   -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall  --compiler-options -Wno-strict-overflow  --compiler-options -Wno-unknown-pragmas
CMAKE_CXX_FLAGS:  -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable  -Wno-strict-overflow
PyTorch version used to build k2: 1.7.1
PyTorch is using Cuda: 10.2


The lhotse version I am using is: `1.3.0.dev+git.a07121a.clean`
The icefall version I am using is:

>>> icefall.get_env_info()
{'k2-version': '1.16', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3c606c27045750bbbb7a289d8b2b09825dea521a', 'k2-git-date': 'Mon Jun 27 03:06:58 2022', 'lhotse-version': '1.3.0.dev+git.a07121a.clean', 'torch-version': '1.7.1', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'test', 'icefall-git-sha1': 'e24e6ac-dirty', 'icefall-git-date': 'Mon Jun 27 01:23:06 2022', 'icefall-path': '/alt-arabic/speech/amir/k2/tmp/icefall', 'k2-path': '/alt-arabic/speech/amir/k2/tmp/k2/k2/python/k2/__init__.py', 'lhotse-path': '/alt-arabic/speech/amir/k2/tmp/lhotse/lhotse/__init__.py', 'hostname': 'cribrighthead001', 'IP address': '10.141.255.254'}

@desh2608 (Collaborator) left a comment:

Could you please update your black settings (see https://icefall.readthedocs.io/en/latest/contributing/code-style.html) and reformat the directory using the new default line-length of 88? This can be done by running black mgb2/.

@@ -287,7 +287,8 @@ def get_lr(self):
        factor = (
            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
        ) ** -0.25 * (
            ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25

Collaborator:
@AmirHussein96 could you remove these formatting changes? Please see #692, where we updated the line-length.

Contributor Author:
I finished the formatting.

Collaborator:
Thanks.

parser.add_argument(
    "--use-averaged-model",
    type=str2bool,
    default=False,

Collaborator:
@csukuangfj told me that using this option provides significant WER improvements, and I noticed the same in my experiments. You can try it out if you have some time.

Contributor Author:
Yes, the best reported results with the stateless transducer are decoded with --use-averaged-model True.
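For context, a rough sketch of what checkpoint averaging amounts to (illustration only, not icefall's implementation; the actual logic, including the running average kept during training for --use-averaged-model, lives in icefall/checkpoint.py):

```python
import torch


def average_checkpoints(filenames):
    """Average model parameters across several saved checkpoints.

    Assumes each checkpoint stores its parameters under the "model" key,
    as icefall's save_checkpoint does.
    """
    avg = None
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}
```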

@AmirHussein96 (Contributor Author):

The style_check failed on the symbolic links which @csukuangfj asked me to add.

@csukuangfj (Collaborator):

The style_check failed on the symbolic links which @csukuangfj asked me to add.

Here are the error logs:

./egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py:0:1: E902 FileNotFoundError: [Errno 2] No such file or directory: './egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py'
./egs/mgb2/ASR/conformer_ctc/download_lm.py:0:1: E902 FileNotFoundError: [Errno 2] No such file or directory: './egs/mgb2/ASR/conformer_ctc/download_lm.py'
./egs/mgb2/ASR/conformer_ctc/compile_hlg.py:0:1: E902 FileNotFoundError: [Errno 2] No such file or directory: './egs/mgb2/ASR/conformer_ctc/compile_hlg.py'
./egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py:0:1: E902 FileNotFoundError: [Errno 2] No such file or directory: './egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py'
./egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py:0:1: E902 FileNotFoundError: [Errno 2] No such file or directory: './egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py'

Please recheck your symlink.

@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
Collaborator:
For instance, this symlink is not correct.

Please fix the other symlinks reported by the CI.

@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/generate_unique_lexicon.py
Collaborator:
Also this one. It should be placed in local, not in conformer_ctc.

Contributor Author:
Ok, I fixed them.

Contributor Author:
@csukuangfj please let me know if there is anything else I need to do to merge the PR.

@csukuangfj csukuangfj added ready and removed ready labels Nov 25, 2022
@@ -0,0 +1,901 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
Collaborator:
If you don't make changes to this file, could you please replace it with a symlink to the one from the librispeech recipe?

Contributor Author:
Done

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
Collaborator:
Could you replace it with a symlink to the file from librispeech?

Contributor Author:
Done


./pruned_transducer_stateless5/train.py \
  --world-size 4 \
  --num-epochs 30 \

Collaborator:
30 epochs were trained. Does the combination --epoch 18 --avg 5 produce the best WER among the combinations you tried?

Contributor Author:
Yes

#### 2022-06-04

You can find a pretrained model, training logs, decoding logs, and decoding results at:
https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06
Collaborator:

Could you also upload pretrained.pt?
cpu_jit.pt is useful during inference time, while pretrained.pt is useful for resuming the training.

For the decoding results, could you also upload the following files:

  • errs-xxx
  • recogs-xxx

Currently, only the decoding logs log-xxx are uploaded, which do not contain the recognition results.

Also, have you tried other decoding methods, e.g., ctc decoding and 1best decoding?

Contributor Author:

I have tried the whole lattice rescoring and the attention decoding. The attention gave me the best results.
I uploaded the errs-xxx and recogs-xxx.

#### 2022-06-04

You can find a pretrained model, training logs, decoding logs, and decoding results at:
https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06
Collaborator:

Could you also provide some test waves and the corresponding transcripts in the above hugging face repo so that we can use them to test your model in sherpa?


You can use
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14/tree/main/test_wavs
as a reference.

Contributor Author:

Done

# Results


### MGB2 all data BPE training results (Stateless Pruned Transducer)
Collaborator:

Could you upload the pretrained model, checkpoint, and decoding results to a hugging face repo?

You can use
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14/tree/main
as a reference and see which files should be uploaded to the huggingface repo.

Contributor Author:

I cannot share the stateless transducer at the current stage, as it is being used in another project and the matter is somewhat sensitive for the side that supports me with the computation resources. However, I am planning to upload it in the near future.

@AmirHussein96 (Contributor Author):

@csukuangfj I addressed all your comments. Please let me know if you have any other comments before merging the PR.

@csukuangfj (Collaborator):

@AmirHussein96

Thanks for your contribution.

@csukuangfj csukuangfj merged commit 6f71981 into k2-fsa:master Dec 2, 2022
| Decoding method | dev WER | test WER |
|---------------------------|------------|---------|
| attention-decoder | 15.62 | 15.01 |
| whole-lattice-rescoring | 15.89 | 15.08 |
Collaborator:

By the way, could you also add the results for 1best decoding and ctc_decoding?

Contributor Author:

Sure, will add them.

@csukuangfj (Collaborator):

By the way, you can try the pre-trained models from this PR by visiting
https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition

You don't need to install anything for that. All you need is a browser.

(screenshot of the Hugging Face recognition demo page)
