Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ScaledLSTM as streaming encoder #479

Merged
merged 44 commits into from
Aug 19, 2022
Merged

Conversation

yaozengwei
Copy link
Collaborator

@yaozengwei yaozengwei commented Jul 17, 2022

This PR implements ScaledLSTM as a light streaming encoder with the scaling mechanism in reworked model. We do not need to use time masking during training as in streaming Conformer, since the LSTM computation is essentially streaming.

The ScaledLSTM class is written in pruned_transducer_stateless2/scaling.py, which is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py, using the highly optimized operator _VF.lstm in pytorch.
@danpovey Could you help me check for correctness of the value of fan_in in _reset_parameters function if you are free?

In lstm_transducer_stateless/lstm.py, each encoder layer consists of a lstm module and a feed_forward module, which adopts similar residual connection as in Conformer.

nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in**-0.5 # 1/sqrt(fan_in)
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reset_parameters in torch rnn.py looks like:

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

so I suggest to set:

     scale = self.hidden_size ** -0.5

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Thanks.

@@ -164,6 +170,38 @@ def scaled_embedding_to_embedding(
return embedding


def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add type hint for the return value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Thanks.

initial_speed: float = 1.0,
**kwargs
):
# Hardcode bidirectional=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest using

if 'bidirectional' in kwargs:
  assert kwargs['bidirectional'] is False

bidirectional is False by default, so you don't need to pass it below.

To run this file, do:

cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless3/test_scaling_converter.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
python ./pruned_transducer_stateless3/test_scaling_converter.py
python ./lstm_transducer_stateless/test_scaling_converter.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Thanks.

word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used

supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)

# tail padding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a note about the reason why you need to do tail padding here.

Also, I suggest using

num_padded_frames = 35

and using num_padded_frames below to replace the magic number 35

word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used

@@ -0,0 +1,388 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change the author.

Copy link
Collaborator

@csukuangfj csukuangfj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Great work! Just left some minor comments.

--avg 10 \
--jit-trace 1

It will generates 3 files: `encoder_jit_trace.pt`,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
It will generates 3 files: `encoder_jit_trace.pt`,
It will generate 3 files: `encoder_jit_trace.pt`,

) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Stack list of lstm states corresponding to separate utterances into a single
lstm state so that it can be used as an input for lsit when those utterances
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lstm state so that it can be used as an input for lsit when those utterances
lstm state so that it can be used as an input for lstm when those utterances

sequence lengths.
- lengths: a tensor of shape (batch_size,) containing the number of
frames in `embeddings` before padding.
- updated states, whose shape is same as the input states.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- updated states, whose shape is same as the input states.
- updated states, whose shape is the same as the input states.

layer_dropout (float):
Dropout value for model-level warmup (default=0.075).
aux_layer_period (int):
Peroid of auxiliary layers used for randomly combined during training.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Peroid of auxiliary layers used for randomly combined during training.
Period of auxiliary layers used for random combiner during training.


if states is None:
x = self.encoder(x, warmup=warmup)[0]
# torch.jit.trace requires returned types be the same as annotated
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# torch.jit.trace requires returned types be the same as annotated
# torch.jit.trace requires returned types to be the same as annotated

parser.add_argument(
"--max-contexts",
type=int,
default=4,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the same default value as decode.py is using.

"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
help="The number of streams that can be decoded parallel",
help="The number of streams that can be decoded in parallel",

params:
It is returned by :func:`get_params`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used



def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# TODO: We can add an option to switch between Conformer and Transformer

# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the changes from the master to filter nan/inf losses.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Thanks for the detailed review. I will modify them.

@yaozengwei yaozengwei added ready and removed ready labels Aug 19, 2022
@RuABraun
Copy link

What is the recommended recipe if one wants to do online ASR? Not clear to me whether this is recommended over the streaming conformer stuff or not? @danpovey

@RuABraun
Copy link

Trained a model on 1k hours and got twice the WER compared to a conformer model (greedy search) unfortunately.

@csukuangfj
Copy link
Collaborator

Trained a model on 1k hours and got twice the WER compared to a conformer model (greedy search) unfortunately.

Do you mean the streaming conformer?


#558 shows that when you use a larger dataset, e.g., 30k hours of data, ScaledLSTM
has competitive results with the streaming conformer in icefall.

@RuABraun
Copy link

RuABraun commented Aug 29, 2022

Interesting. I just realized I wasn't using musan augmentation which I presume is suboptimal. (edit: also was 50% higher WER not 100%)

I'm comparing to the normal conformer, the impression I got from the librispeech results file was that the lstm model will keep up with pruned_transducer_stateless2.

It does make sense to me that a 12 layer LSTM model needs a lot of data to work well.

csukuangfj added a commit to csukuangfj/icefall that referenced this pull request Nov 14, 2022
)

* Support running icefall outside of a git tracked directory. (k2-fsa#470)

* Support running icefall outside of a git tracked directory.

* Minor fixes.

* Rand combine update result (k2-fsa#467)

* update RESULTS.md

* fix test code in pruned_transducer_stateless5/conformer.py

* minor fix

* delete doc

* fix style

* Simplified memory bank for Emformer (k2-fsa#440)

* init files

* use average value as memory vector for each chunk

* change tail padding length from right_context_length to chunk_length

* correct the files, ln -> cp

* fix bug in conv_emformer_transducer_stateless2/emformer.py

* fix doc in conv_emformer_transducer_stateless/emformer.py

* refactor init states for stream

* modify .flake8

* fix bug about memory mask when memory_size==0

* add @torch.jit.export for init_states function

* update RESULTS.md

* minor change

* update README.md

* modify doc

* replace torch.div() with <<

* fix bug, >> -> <<

* use i&i-1 to judge if it is a power of 2

* minor fix

* fix error in RESULTS.md

* update multi_quantization installation (k2-fsa#469)

* update multi_quantization installation

* Update egs/librispeech/ASR/pruned_transducer_stateless6/train.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* [Ready] [Recipes] add aishell2 (k2-fsa#465)

* add aishell2

* fix aishell2

* add manifest stats

* update prepare char dict

* fix lint

* setting max duration

* lint

* change context size to 1

* update result

* update hf link

* fix decoding comment

* add more decoding methods

* update result

* change context-size 2 default

* [WIP] Rnn-T LM nbest rescoring (k2-fsa#471)

* add compile_lg.py for aishell2 recipe (k2-fsa#481)

* Add RNN-LM rescoring in fast beam search (k2-fsa#475)

* fix for case of None stats

* Update conformer.py for aishell4 (k2-fsa#484)

* update conformer.py for aishell4

* update conformer.py

* add strict=False when model.load_state_dict

* CTC attention model with reworked Conformer encoder and reworked Transformer decoder (k2-fsa#462)

* ctc attention model with reworked conformer encoder and reworked transformer decoder

* remove unnecessary func

* resolve flake8 conflicts

* fix typos and modify the expr of ScaledEmbedding

* use original beam size

* minor changes to the scripts

* add rnn lm decoding

* minor changes

* check whether q k v weight is None

* check whether q k v weight is None

* check whether q k v weight is None

* style correction

* update results

* update results

* upload the decoding results of rnn-lm to the RESULTS

* upload the decoding results of rnn-lm to the RESULTS

* Update egs/librispeech/ASR/RESULTS.md

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Update egs/librispeech/ASR/RESULTS.md

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Update egs/librispeech/ASR/RESULTS.md

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Update doc to add a link to Nadira Povey's YouTube channel. (k2-fsa#492)

* Update doc to add a link to Nadira Povey's YouTube channel.

* fix a typo

* Add stats about duration and padding proportion (k2-fsa#485)

* add stats about duration and padding proportion

* add  for utt_duration

* add stats for other recipes

* add stats for other 2 recipes

* modify doc

* minor change

* Add modified_beam_search for streaming decode (k2-fsa#489)

* Add modified_beam_search for pruned_transducer_stateless/streaming_decode.py

* refactor

* modified beam search for stateless3,4

* Fix comments

* Add real streamng ci

* Fix using G before assignment in pruned_transducer_stateless/decode.py (k2-fsa#494)

* Support using aidatatang_200zh optionally in aishell training (k2-fsa#495)

* Use aidatatang_200zh optionally in aishell training.

* Fix get_transducer_model() for aishell. (k2-fsa#497)

PR k2-fsa#495 introduces an error. This commit fixes it.

* [WIP] Pruned-transducer-stateless5-for-WenetSpeech (offline and streaming) (k2-fsa#447)

* pruned-rnnt5-for-wenetspeech

* style check

* style check

* add streaming conformer

* add streaming decode

* changes codes for fast_beam_search and export cpu jit

* add modified-beam-search for streaming decoding

* add modified-beam-search for streaming decoding

* change for streaming_beam_search.py

* add README.md and RESULTS.md

* change for style_check.yml

* do some changes

* do some changes for export.py

* add some decode commands for usage

* add streaming results on README.md

* [debug] raise remind when git-lfs not available (k2-fsa#504)

* [debug] raise remind when git-lfs not available

* modify comment

* correction for prepare.sh (k2-fsa#506)

* Set overwrite=True when extracting features in batches. (k2-fsa#487)

* correction for get rank id. (k2-fsa#507)

* Fix no attribute 'data' error.

* minor fixes

* correction for get rank id.

* Add other decoding methods (nbest, nbest oracle, nbest LG) for wenetspeech pruned rnnt2 (k2-fsa#482)

* add other decoding methods for wenetspeech

* changes for RESULTS.md

* add ngram-lm-scale=0.35 results

* set ngram-lm-scale=0.35 as default

* Update README.md

* add nbest-scale for flie name

* Support dynamic chunk streaming training in pruned_transcuder_stateless5 (k2-fsa#454)

* support dynamic chunk streaming training

* Add simulate streaming decoding

* Support streaming decoding

* fix causal

* Minor fixes

* fix streaming decode; add results

* liear_fst_with_self_loops (k2-fsa#512)

* Support exporting to ONNX format (k2-fsa#501)

* WIP: Support exporting to ONNX format

* Minor fixes.

* Combine encoder/decoder/joiner into a single file.

* Revert merging three onnx models into a single one.

It's quite time consuming to extract a sub-graph from the combined
model. For instance, it takes more than one hour to extract
the encoder model.

* Update CI to test ONNX models.

* Decode with exported models.

* Fix typos.

* Add more doc.

* Remove ncnn as it is not fully tested yet.

* Fix as_strided for streaming conformer.

* Convert ScaledEmbedding to nn.Embedding for inference. (k2-fsa#517)

* Convert ScaledEmbedding to nn.Embedding for inference.

* Fix CI style issues.

* Fix preparing char based lang and add multiprocessing for wenetspeech text segmentation (k2-fsa#513)

* add multiprocessing for wenetspeech text segmentation

* Fix preparing char based lang for wenetspeech

* fix style

Co-authored-by: WeijiZhuang <zhuangweiji@xiaomi.com>

* change for pruned rnnt5 train.py (k2-fsa#519)

* fix about tensorboard (k2-fsa#516)

* fix metricstracker

* fix style

* Merging onnx models (k2-fsa#518)

* add export function of onnx-all-in-one to export.py

* add onnx_check script for all-in-one onnx model

* minor fix

* remove unused arguments

* add onnx-all-in-one test

* fix style

* fix style

* fix requirements

* fix input/output names

* fix installing onnx_graphsurgeon

* fix instaliing onnx_graphsurgeon

* revert to previous requirements.txt

* fix minor

* Fix loading sampler state dict. (k2-fsa#421)

* Fix loading sampler state dict.

* skip scan_pessimistic_batches_for_oom if params.start_batch > 0

* fix torchaudio version (k2-fsa#524)

* fix torchaudio version

* fix torchaudio version

* Fix computing averaged loss in the aishell recipe. (k2-fsa#523)

* Fix computing averaged loss in the aishell recipe.

* Set find_unused_parameters optionally.

* Sort results to make it more convenient to compare decoding results (k2-fsa#522)

* Sort result to make it more convenient to compare decoding results

* Add cut_id to recognition results

* add cut_id to results for all recipes

* Fix torch.jit.script

* Fix comments

* Minor fixes

* Fix torch.jit.tracing for Pytorch version before v1.9.0

* Add function display_and_save_batch in wenetspeech/pruned_transducer_stateless2/train.py (k2-fsa#528)

* Add function display_and_save_batch in egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py

* Modify function: display_and_save_batch

* Delete empty line in pruned_transducer_stateless2/train.py

* Modify code format

* Filter non-finite losses (k2-fsa#525)

* Filter non-finite losses

* Fixes after review

* propagate changes from k2-fsa#525 to other librispeech recipes (k2-fsa#531)

* propaga changes from k2-fsa#525 to other librispeech recipes

* refactor display_and_save_batch to utils

* fixed typo

* reformat code style

* Fix not enough values to unpack error . (k2-fsa#533)

* Use ScaledLSTM as streaming encoder (k2-fsa#479)

* add ScaledLSTM

* add RNNEncoderLayer and RNNEncoder classes in lstm.py

* add RNN and Conv2dSubsampling classes in lstm.py

* hardcode bidirectional=False

* link from pruned_transducer_stateless2

* link scaling.py pruned_transducer_stateless2

* copy from pruned_transducer_stateless2

* modify decode.py pretrained.py test_model.py train.py

* copy streaming decoding files from pruned_transducer_stateless2

* modify streaming decoding files

* simplified code in ScaledLSTM

* flat weights after scaling

* pruned2 -> pruned4

* link __init__.py

* fix style

* remove add_model_arguments

* modify .flake8

* fix style

* fix scale value in scaling.py

* add random combiner for training deeper model

* add using proj_size

* add scaling converter for ScaledLSTM

* support jit trace

* add using averaged model in export.py

* modify test_model.py, test if the model can be successfully exported by jit.trace

* modify pretrained.py

* support streaming decoding

* fix model.py

* Add cut_id to recognition results

* Add cut_id to recognition results

* do not pad in Conv subsampling module; add tail padding during decoding.

* update RESULTS.md

* minor fix

* fix doc

* update README.md

* minor change, filter infinite loss

* remove the condition of raise error

* modify type hint for the return value in model.py

* minor change

* modify RESULTS.md

Co-authored-by: pkufool <wkang.pku@gmail.com>

* Update asr_datamodule.py (k2-fsa#538)

minor file names correction

* minor fixes to LSTM streaming model (k2-fsa#537)

* Pruned transducer stateless2 for AISHELL-1 (k2-fsa#536)

* Fix not enough values to unpack error .

* [WIP] Pruned transducer stateless2 for AISHELL-1

* fix the style issue

* code format for black

* add pruned-transducer-stateless2 results for AISHELL-1

* simplify result

* consider case of empty tensor (k2-fsa#540)

* fixed import quantization is none (k2-fsa#541)

Signed-off-by: shanguanma <nanr9544@gmail.com>

Signed-off-by: shanguanma <nanr9544@gmail.com>
Co-authored-by: shanguanma <nanr9544@gmail.com>

* fix typo for export jit script (k2-fsa#544)

* some small changes for aidatatang_200zh (k2-fsa#542)

* Update prepare.sh

* Update compute_fbank_aidatatang_200zh.py

* fixed no cut_id error in decode_dataset (k2-fsa#549)

* fixed import quantization is none

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed no cut_id error in decode_dataset

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed more than one "#"

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed code style

Signed-off-by: shanguanma <nanr9544@gmail.com>

Signed-off-by: shanguanma <nanr9544@gmail.com>
Co-authored-by: shanguanma <nanr9544@gmail.com>

* Add clamping operation in Eve optimizer for all scalar weights to avoid (k2-fsa#550)

non stable training in some scenarios. The clamping range is set to (-10,2).
 Note that this change may cause unexpected effect if you resume
training from a model that is trained without clamping.

* minor changes for correct path names && import module text2segments.py (k2-fsa#552)

* Update asr_datamodule.py

minor file names correction

* minor changes for correct path names && import module text2segments.py

* fix scaling converter test for decoder(predictor). (k2-fsa#553)

* Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes. (k2-fsa#554)

* Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes.

* minor fixes

* Check that read_manifests_if_cached returns a non-empty dict. (k2-fsa#555)

* Modified prepare_transcripts.py and preprare_lexicon.py of tedlium3 recipe (k2-fsa#567)

* Use modified ctc topo when vocab size is > 500 (k2-fsa#568)

* Add LSTM for the multi-dataset setup. (k2-fsa#558)

* Add LSTM for the multi-dataset setup.

* Add results

* fix style issues

* add missing file

* Adding Dockerfile for Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8 (k2-fsa#572)

* Changed Dockerfile

* Update Dockerfile

* Dockerfile

* Update README.md

* Add Dockerfiles

* Update README.md

Removed misleading CUDA version, as the Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8 Dockerfile can only support CUDA versions >11.0.

* support exporting to ncnn format via PNNX (k2-fsa#571)

* Small fixes to the transducer training doc (k2-fsa#575)

* Update kaldifeat in CI tests (k2-fsa#583)

* padding zeros (k2-fsa#591)

* Gradient filter for training lstm model (k2-fsa#564)

* init files

* add gradient filter module

* refact getting median value

* add cutoff for grad filter

* delete comments

* apply gradient filter in LSTM module, to filter both input and params

* fix typing and refactor

* filter with soft mask

* rename lstm_transducer_stateless2 to lstm_transducer_stateless3

* fix typos, and update RESULTS.md

* minor fix

* fix return typing

* fix typo

* Modified train.py of tedlium3 models (k2-fsa#597)

* Add dill to requirements.txt (k2-fsa#613)

* Add dill to requirements.txt

* Disable style check for python 3.7

* update docs (k2-fsa#611)

* update docs

Co-authored-by: unknown <mazhihao@jshcbd.cn>
Co-authored-by: KajiMaCN <moonlightshadowmzh@gmail.com>

* exporting projection layers of joiner separately for onnx (k2-fsa#584)

* exporting projection layers of joiner separately for onnx

* Remove all-in-one for onnx export (k2-fsa#614)

* Remove all-in-one for onnx export

* Exit on error for CI

* Modify ActivationBalancer for speed (k2-fsa#612)

* add a probability to apply ActivationBalancer

* minor fix

* minor fix

* Support exporting to ONNX for the wenetspeech recipe (k2-fsa#615)

* Support exporting to ONNX for the wenetspeech recipe

* Add doc about model export (k2-fsa#618)

* Add doc about model export

* fix typos

* Fix links in the doc (k2-fsa#619)

* fix type hints for decode.py (k2-fsa#623)

* Support exporting LSTM with projection to ONNX (k2-fsa#621)

* Support exporting LSTM with projection to ONNX

* Add missing files

* small fixes

* CSJ Data Preparation (k2-fsa#617)

* workspace setup

* csj prepare done

* Change compute_fbank_musan.py t soft link

* add description

* change lhotse prepare csj command

* split train-dev here

* Add header

* remove debug

* save manifest_statistics

* generate transcript in Lhotse

* update comments in config file

* fix number of parameters in RESULTS.md (k2-fsa#627)

* Add Shallow fusion in modified_beam_search (k2-fsa#630)

* Add utility for shallow fusion

* test batch size == 1 without shallow fusion

* Use shallow fusion for modified-beam-search

* Modified beam search with ngram rescoring

* Fix code according to review

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Add kaldifst to requirements.txt (k2-fsa#631)

* Install kaldifst for GitHub actions (k2-fsa#632)

* Install kaldifst for GitHub actions

* Update train.py (k2-fsa#635)

Add the missing step to add the arguments to the parser.

* Fix type hints for decode.py (k2-fsa#638)

* Fix type hints for decode.py

* Fix flake8

* fix typos (k2-fsa#639)

* Remove onnx and onnxruntime from requirements.txt (k2-fsa#640)

* Remove onnx and onnxruntime from requirements.txt

* Checkout the LM for aishell explicitly (k2-fsa#642)

* Get timestamps during decoding (k2-fsa#598)

* print out timestamps during decoding

* add word-level alignments

* support to compute mean symbol delay with word-level alignments

* print variance of symbol delay

* update doc

* support to compute delay for pruned_transducer_stateless4

* fix bug

* add doc

* remove tail padding for non-streaming models (k2-fsa#625)

* support RNNLM shallow fusion for LSTM transducer

* support RNNLM shallow fusion in stateless5

* update results

* update decoding commands

* update author info

* update

* include previous added decoding method

* minor fixes

* remove redundant test lines

* Update egs/librispeech/ASR/lstm_transducer_stateless2/decode.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Update tdnn_lstm_ctc.rst (k2-fsa#647)

* Update README.md (k2-fsa#649)

* Update tdnn_lstm_ctc.rst (k2-fsa#648)

* fix torchaudio version in dockerfile (k2-fsa#653)

* fix torchaudio version in dockerfile

* remove kaldiio

* update docs

* Add fast_beam_search_LG (k2-fsa#622)

* Add fast_beam_search_LG

* add fast_beam_search_LG to commonly used recipes

* fix ci

* fix ci

* Fix error

* Fix LG log file name (k2-fsa#657)

* resolve conflict with timestamp feature

* resolve conflicts

* minor fixes

* remove testing file

* Apply delay penalty on transducer (k2-fsa#654)

* add delay penalty

* fix CI

* fix CI

* Refactor getting timestamps in fsa-based decoding (k2-fsa#660)

* refactor getting timestamps for fsa-based decoding

* fix doc

* fix bug

* add ctc_decode.py

* fix doc

Signed-off-by: shanguanma <nanr9544@gmail.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
Co-authored-by: LIyong.Guo <839019390@qq.com>
Co-authored-by: Yuekai Zhang <zhangyuekai@foxmail.com>
Co-authored-by: ezerhouni <61225408+ezerhouni@users.noreply.github.com>
Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Co-authored-by: Daniel Povey <dpovey@gmail.com>
Co-authored-by: Quandwang <quandwang@hotmail.com>
Co-authored-by: Wei Kang <wkang.pku@gmail.com>
Co-authored-by: boji123 <boji123@aliyun.com>
Co-authored-by: Lucky Wong <lekai.huang@gmail.com>
Co-authored-by: LIyong.Guo <guonwpu@qq.com>
Co-authored-by: Weiji Zhuang <zhuangweiji@foxmail.com>
Co-authored-by: WeijiZhuang <zhuangweiji@xiaomi.com>
Co-authored-by: Yunusemre <yunusemreozkose@gmail.com>
Co-authored-by: FNLPprojects <linxinzhulxz@gmail.com>
Co-authored-by: yangsuxia <34536059+yangsuxia@users.noreply.github.com>
Co-authored-by: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Co-authored-by: rickychanhoyin <ricky.hoyin.chan@gmail.com>
Co-authored-by: Duo Ma <39255927+shanguanma@users.noreply.github.com>
Co-authored-by: shanguanma <nanr9544@gmail.com>
Co-authored-by: rxhmdia <41623136+rxhmdia@users.noreply.github.com>
Co-authored-by: kobenaxie <572745565@qq.com>
Co-authored-by: shcxlee <113081290+shcxlee@users.noreply.github.com>
Co-authored-by: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com>
Co-authored-by: KajiMaCN <827272056@qq.com>
Co-authored-by: unknown <mazhihao@jshcbd.cn>
Co-authored-by: KajiMaCN <moonlightshadowmzh@gmail.com>
Co-authored-by: Yunusemre <yunusemre.ozkose@sestek.com>
Co-authored-by: Nagendra Goel <nagendra.goel@gmail.com>
Co-authored-by: marcoyang <marcoyang1998@gmail.com>
Co-authored-by: zr_jin <60612200+JinZr@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants