Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update speechio whisper ft results #1605

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions egs/multi_zh-hans/ASR/RESULTS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,48 @@
## Results

### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)

Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |

Command for training is:
```bash
pip install -r whisper/requirements.txt

# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54

torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
```

Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```

Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>


### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
Expand Down
4 changes: 2 additions & 2 deletions egs/multi_zh-hans/ASR/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare WenetSpeech"
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .

Expand Down
50 changes: 49 additions & 1 deletion egs/multi_zh-hans/ASR/whisper/decode.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

Expand Down Expand Up @@ -214,7 +215,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
Expand All @@ -226,6 +227,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

return parser


Expand Down Expand Up @@ -307,6 +315,43 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""

def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re

text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("A", "A")
text = text.replace("a", "A")
text = text.replace("b", "B")
text = text.replace("c", "C")
text = text.replace("k", "K")
text = text.replace("t", "T")
text = text.replace(",", "")
text = text.replace("丶", "")
text = text.replace("。", "")
text = text.replace("、", "")
text = text.replace("?", "")
return text

results = []

num_cuts = 0
Expand All @@ -331,6 +376,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))

Expand Down Expand Up @@ -430,6 +476,8 @@ def main():

if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:
Expand Down
91 changes: 21 additions & 70 deletions egs/multi_zh-hans/ASR/whisper/multi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, fbank_dir: str):
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L.jsonl.gz
- wenetspeech/cuts_L_fixed.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)

Expand Down Expand Up @@ -105,7 +105,7 @@ def train_cuts(self) -> CutSet:
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)

# KeSpeech
Expand All @@ -124,10 +124,10 @@ def train_cuts(self) -> CutSet:
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
alimeeting_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
Expand All @@ -138,10 +138,10 @@ def train_cuts(self) -> CutSet:
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
Expand All @@ -151,55 +151,13 @@ def train_cuts(self) -> CutSet:
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")

# AISHELL
logging.info("Loading Aishell DEV set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)

# AISHELL-2
logging.info("Loading Aishell-2 DEV set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)

# Ali-Meeting
logging.info("Loading Ali-Meeting DEV set in lazy mode")
alimeeting_dev_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)

# MagicData
logging.info("Loading MagicData DEV set in lazy mode")
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)

# KeSpeech
logging.info("Loading KeSpeech DEV set in lazy mode")
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)

# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)

return wenetspeech_dev_cuts
# return [
# aishell_dev_cuts,
# aishell2_dev_cuts,
# alimeeting_dev_cuts,
# magicdata_dev_cuts,
# kespeech_dev_phase1_cuts,
# kespeech_dev_phase2_cuts,
# wenetspeech_dev_cuts,
# ]

def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
Expand Down Expand Up @@ -267,30 +225,23 @@ def test_cuts(self) -> Dict[str, CutSet]:
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)

return {
"aishell-2_test": aishell2_test_cuts,
"aishell-4": aishell4_test_cuts,
"magicdata_test": magicdata_test_cuts,
"kespeech-asr_test": kespeech_test_cuts,
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "ali-meeting_test": alimeeting_test_cuts,
# "ali-meeting_eval": alimeeting_eval_cuts,
# "aishell-4_test": aishell4_test_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}

# return {
# "alimeeting_test": alimeeting_test_cuts,
# "alimeeting_eval": alimeeting_eval_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "aishell-4": aishell4_test_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
# }
50 changes: 48 additions & 2 deletions egs/multi_zh-hans/ASR/whisper/train.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
Expand Down Expand Up @@ -146,7 +147,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
yuekaizhang marked this conversation as resolved.
Show resolved Hide resolved
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
Expand Down Expand Up @@ -232,6 +233,13 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

parser = deepspeed.add_config_arguments(parser)

return parser
Expand Down Expand Up @@ -441,6 +449,42 @@ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
return torch.stack([tensor for tensor in padded_tensors], dim=0)

def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re

text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("A", "A")
text = text.replace("a", "A")
text = text.replace("b", "B")
text = text.replace("c", "C")
text = text.replace("k", "K")
text = text.replace("t", "T")
text = text.replace(",", "")
text = text.replace("丶", "")
text = text.replace("。", "")
text = text.replace("、", "")
text = text.replace("?", "")
return text

max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
Expand All @@ -459,7 +503,7 @@ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:

texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [text.replace(" ", "") for text in texts]
texts = [normalize_text_alimeeting(text) for text in texts]

text_tokens_list = [
list(tokenizer.sot_sequence_including_notimestamps)
Expand Down Expand Up @@ -759,6 +803,8 @@ def run(rank, world_size, args):
logging.info("About to create model")

replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads

Expand Down
Loading
Loading