Add Streaming Zipformer-Transducer recipe for KsponSpeech (#1651)
Showing 32 changed files with 4,294 additions and 0 deletions.
@@ -0,0 +1,32 @@

# Introduction

KsponSpeech is a large-scale spontaneous speech corpus of Korean.
This corpus contains 969 hours of open-domain dialog utterances,
spoken by about 2,000 native Korean speakers in a clean environment.

All data were constructed by recording the dialogue of two people
freely conversing on a variety of topics and manually transcribing the utterances.

The transcription provides a dual annotation consisting of orthography and pronunciation,
together with disfluency tags for spontaneous speech phenomena such as filler words, repeated words, and word fragments.

The original audio data is distributed as pcm files.
During preprocessing, each file is converted to flac and saved anew.
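As a rough sketch of that conversion step (this is not the recipe's actual preprocessing code; the 16 kHz, 16-bit little-endian mono PCM layout and the file names are assumptions for illustration), the conversion could look like this in Python:

```python
# Hedged sketch: convert one headerless KsponSpeech pcm file to flac.
# Assumes 16 kHz, 16-bit little-endian mono PCM; verify against the corpus documentation.
import numpy as np
import soundfile as sf


def pcm_to_flac(pcm_path: str, flac_path: str, sample_rate: int = 16000) -> None:
    samples = np.fromfile(pcm_path, dtype="<i2")  # raw samples, no header
    sf.write(flac_path, samples, sample_rate, format="FLAC")


pcm_to_flac("KsponSpeech_000001.pcm", "KsponSpeech_000001.flac")
```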
KsponSpeech is publicly available on an open data hub site of the Korean government.
The dataset must be downloaded manually.

For more details, please visit:

- Dataset: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=123
- Paper: https://www.mdpi.com/2076-3417/10/19/6936

[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers

There are various folders containing the name `transducer` in this folder. The following table lists the differences among them.

|                                           | Encoder             | Decoder            | Comment                                            |
| ----------------------------------------- | ------------------- | ------------------ | -------------------------------------------------- |
| `pruned_transducer_stateless7_streaming`  | Streaming Zipformer | Embedding + Conv1d | Streaming version of pruned_transducer_stateless7  |

The decoder is modified from the paper [RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer.
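As a loose illustration of what "Embedding + Conv1d" means here, the sketch below shows a stateless prediction network in PyTorch; the embedding dimension, context size, padding, and activation are illustrative assumptions rather than the recipe's exact layer configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    """Stateless RNN-T prediction network: an embedding of the previous labels
    followed by a Conv1d that mixes a small, fixed context of them."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.context_size = context_size
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size, bias=False)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, num_labels) of previously emitted token IDs.
        emb = self.embedding(y).permute(0, 2, 1)       # (batch, embed_dim, num_labels)
        emb = F.pad(emb, (self.context_size - 1, 0))   # left-pad so the length is kept
        out = F.relu(self.conv(emb)).permute(0, 2, 1)  # (batch, num_labels, embed_dim)
        return out
```

Because the prediction network only ever sees a fixed number of previous labels, there is no recurrent state to carry during decoding, which is what makes the transducer "stateless".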
@@ -0,0 +1,70 @@
## Results

### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)

#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)

Number of model parameters: 79,022,891, i.e., 79.02 M

##### Training on KsponSpeech (with MUSAN)

Model: [johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12](https://huggingface.co/johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12)

The CERs are:

| decoding method      | chunk size | eval_clean | eval_other | comment            | decoding mode       |
|----------------------|------------|------------|------------|--------------------|---------------------|
| greedy search        | 320ms      | 10.21      | 11.07      | --epoch 30 --avg 9 | simulated streaming |
| greedy search        | 320ms      | 10.22      | 11.07      | --epoch 30 --avg 9 | chunk-wise          |
| fast beam search     | 320ms      | 10.21      | 11.04      | --epoch 30 --avg 9 | simulated streaming |
| fast beam search     | 320ms      | 10.25      | 11.08      | --epoch 30 --avg 9 | chunk-wise          |
| modified beam search | 320ms      | 10.13      | 10.88      | --epoch 30 --avg 9 | simulated streaming |
| modified beam search | 320ms      | 10.10      | 10.93      | --epoch 30 --avg 9 | chunk-wise          |
| greedy search        | 640ms      | 9.94       | 10.82      | --epoch 30 --avg 9 | simulated streaming |
| greedy search        | 640ms      | 10.04      | 10.85      | --epoch 30 --avg 9 | chunk-wise          |
| fast beam search     | 640ms      | 10.01      | 10.81      | --epoch 30 --avg 9 | simulated streaming |
| fast beam search     | 640ms      | 10.04      | 10.70      | --epoch 30 --avg 9 | chunk-wise          |
| modified beam search | 640ms      | 9.91       | 10.72      | --epoch 30 --avg 9 | simulated streaming |
| modified beam search | 640ms      | 9.92       | 10.72      | --epoch 30 --avg 9 | chunk-wise          |

Note: `simulated streaming` indicates feeding the full utterance during decoding using `decode.py`,
while `chunk-wise` indicates feeding a fixed number of frames at a time using `streaming_decode.py`.

The training command is:

```bash
./pruned_transducer_stateless7_streaming/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir pruned_transducer_stateless7_streaming/exp \
  --max-duration 750 \
  --enable-musan True
```

The simulated streaming decoding command (e.g., chunk-size=320ms; `--decode-chunk-len 32` corresponds to 320 ms at the 10 ms frame shift) is:

```bash
for m in greedy_search fast_beam_search modified_beam_search; do
  ./pruned_transducer_stateless7_streaming/decode.py \
    --epoch 30 \
    --avg 9 \
    --exp-dir ./pruned_transducer_stateless7_streaming/exp \
    --max-duration 600 \
    --decode-chunk-len 32 \
    --decoding-method $m
done
```

The streaming chunk-wise decoding command (e.g., chunk-size=320ms) is:

```bash
for m in greedy_search modified_beam_search fast_beam_search; do
  ./pruned_transducer_stateless7_streaming/streaming_decode.py \
    --epoch 30 \
    --avg 9 \
    --exp-dir ./pruned_transducer_stateless7_streaming/exp \
    --decoding-method $m \
    --decode-chunk-len 32 \
    --num-decode-streams 2000
done
```
Empty file.
@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to the bpe.model. If not None, we will remove short and
        long utterances before extracting features""",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        help="""Dataset parts to compute fbank for. If None, we will use all parts""",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=True,
        help="""Perturb speed with factors 0.9 and 1.1 on the train subset.""",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="data",
        help="""Path of the data directory""",
    )

    return parser.parse_args()


def compute_fbank_speechtools(
    bpe_model: Optional[str] = None,
    dataset: Optional[str] = None,
    perturb_speed: Optional[bool] = False,
    data_dir: Optional[str] = "data",
):
    """Compute fbank features for the KsponSpeech manifests and store them on disk."""
    src_dir = Path(data_dir) / "manifests"
    output_dir = Path(data_dir) / "fbank"
    num_jobs = min(4, os.cpu_count())
    num_mel_bins = 80

    if bpe_model:
        logging.info(f"Loading {bpe_model}")
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    if dataset is None:
        dataset_parts = (
            "train",
            "dev",
            "eval_clean",
            "eval_other",
        )
    else:
        dataset_parts = dataset.split(" ", -1)

    prefix = "ksponspeech"
    suffix = "jsonl.gz"
    logging.info("Reading manifests...")
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    if torch.cuda.is_available():
        # Use CUDA for fbank computation
        device = "cuda"
    else:
        device = "cpu"
    logging.info(f"Device: {device}")

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device))

    with get_executor() as ex:  # Initialize the executor only once.
        logging.info(f"Executor: {ex}")
        for partition, m in manifests.items():
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )

            # Keep only cuts longer than 1 second with a 16 kHz sampling rate.
            cut_set = cut_set.filter(
                lambda x: x.duration > 1 and x.sampling_rate == 16000
            )

            if "train" in partition:
                if bpe_model:
                    cut_set = filter_cuts(cut_set, sp)
                if perturb_speed:
                    logging.info("Doing speed perturbation")
                    cut_set = (
                        cut_set
                        + cut_set.perturb_speed(0.9)
                        + cut_set.perturb_speed(1.1)
                    )
            logging.info("Computing and storing features...")
            if device == "cuda":
                cut_set = cut_set.compute_and_store_features_batch(
                    extractor=extractor,
                    storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                    num_workers=4,
                    storage_type=LilcomChunkyWriter,
                )
            else:
                cut_set = cut_set.compute_and_store_features(
                    extractor=extractor,
                    storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                    # When an executor is specified, use more partitions.
                    num_jobs=num_jobs if ex is None else 80,
                    executor=ex,
                    storage_type=LilcomChunkyWriter,
                )
            cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))
    compute_fbank_speechtools(
        bpe_model=args.bpe_model,
        dataset=args.dataset,
        perturb_speed=args.perturb_speed,
        data_dir=args.data_dir,
    )
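For reference, a hypothetical invocation of the script above might look like the following; the script path and the BPE model location are illustrative assumptions (adjust them to the actual file names in the recipe), while the flags themselves come from the argument parser above:

```bash
# Hypothetical paths; the script name and BPE model path are assumptions.
python local/compute_fbank_ksponspeech.py \
  --bpe-model data/lang_bpe_5000/bpe.model \
  --dataset "train dev" \
  --perturb-speed True \
  --data-dir data
```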