Showing 4 changed files with 322 additions and 0 deletions.
@@ -0,0 +1,25 @@
dataset_path: google/fleurs
dataset_kwargs:
  token: True
test_split: test
output_type: generate_until
doc_to_visual: !function utils.fleurs_doc_to_audio
doc_to_text: !function utils.fleurs_doc_to_text
doc_to_target: "transcription"
generation_kwargs:
  max_new_tokens: 256
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
process_results: !function utils.fleurs_process_result
metric_list:
  - metric: wer
    aggregation: !function utils.fleurs_wer
    higher_is_better: false
metadata:
  - version: 0.0
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""
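
To make the wiring above concrete, the sketch below (plain Python, not lmms-eval internals) traces a single document through the hooks this template declares; the doc dict and the model output are made-up placeholders, and the functions are the utils helpers referenced by the !function entries and added later in this commit:

    # Hypothetical document; real ones come from the google/fleurs test split.
    doc = {"audio": None, "transcription": "hello world", "path": "sample.wav", "language": "English"}
    prompt = fleurs_doc_to_text(doc, {"pre_prompt": "", "post_prompt": ""})
    audio = fleurs_doc_to_audio(doc)                    # handed to the model together with the prompt
    model_output = ["hello world"]                      # whatever the model generates
    per_doc = fleurs_process_result(doc, model_output)  # -> {"wer": {"gt": ..., "pred": ..., ...}}
    score = fleurs_wer([per_doc["wer"]], args=None)     # corpus-level WER, scaled to a percentage
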
@@ -0,0 +1,3 @@
group: fleurs
task:
  - fleurs_en
@@ -0,0 +1,3 @@
dataset_name: en_us
include: _default_template_yaml
task: fleurs_en
@@ -0,0 +1,291 @@
import os
import re
import unicodedata
from collections import OrderedDict

import editdistance as ed
import zhconv

from lmms_eval.tasks.librispeech.cn_tn import TextNorm
from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer
_FLEURS_LANG_TO_ID = OrderedDict(
    [
        ("Afrikaans", "af"),
        ("Amharic", "am"),
        ("Arabic", "ar"),
        ("Armenian", "hy"),
        ("Assamese", "as"),
        ("Asturian", "ast"),
        ("Azerbaijani", "az"),
        ("Belarusian", "be"),
        ("Bengali", "bn"),
        ("Bosnian", "bs"),
        ("Bulgarian", "bg"),
        ("Burmese", "my"),
        ("Catalan", "ca"),
        ("Cebuano", "ceb"),
        ("Mandarin Chinese", "cmn_hans"),
        ("Cantonese Chinese", "yue_hant"),
        ("Croatian", "hr"),
        ("Czech", "cs"),
        ("Danish", "da"),
        ("Dutch", "nl"),
        ("English", "en"),
        ("Estonian", "et"),
        ("Filipino", "fil"),
        ("Finnish", "fi"),
        ("French", "fr"),
        ("Fula", "ff"),
        ("Galician", "gl"),
        ("Ganda", "lg"),
        ("Georgian", "ka"),
        ("German", "de"),
        ("Greek", "el"),
        ("Gujarati", "gu"),
        ("Hausa", "ha"),
        ("Hebrew", "he"),
        ("Hindi", "hi"),
        ("Hungarian", "hu"),
        ("Icelandic", "is"),
        ("Igbo", "ig"),
        ("Indonesian", "id"),
        ("Irish", "ga"),
        ("Italian", "it"),
        ("Japanese", "ja"),
        ("Javanese", "jv"),
        ("Kabuverdianu", "kea"),
        ("Kamba", "kam"),
        ("Kannada", "kn"),
        ("Kazakh", "kk"),
        ("Khmer", "km"),
        ("Korean", "ko"),
        ("Kyrgyz", "ky"),
        ("Lao", "lo"),
        ("Latvian", "lv"),
        ("Lingala", "ln"),
        ("Lithuanian", "lt"),
        ("Luo", "luo"),
        ("Luxembourgish", "lb"),
        ("Macedonian", "mk"),
        ("Malay", "ms"),
        ("Malayalam", "ml"),
        ("Maltese", "mt"),
        ("Maori", "mi"),
        ("Marathi", "mr"),
        ("Mongolian", "mn"),
        ("Nepali", "ne"),
        ("Northern-Sotho", "nso"),
        ("Norwegian", "nb"),
        ("Nyanja", "ny"),
        ("Occitan", "oc"),
        ("Oriya", "or"),
        ("Oromo", "om"),
        ("Pashto", "ps"),
        ("Persian", "fa"),
        ("Polish", "pl"),
        ("Portuguese", "pt"),
        ("Punjabi", "pa"),
        ("Romanian", "ro"),
        ("Russian", "ru"),
        ("Serbian", "sr"),
        ("Shona", "sn"),
        ("Sindhi", "sd"),
        ("Slovak", "sk"),
        ("Slovenian", "sl"),
        ("Somali", "so"),
        ("Sorani-Kurdish", "ckb"),
        ("Spanish", "es"),
        ("Swahili", "sw"),
        ("Swedish", "sv"),
        ("Tajik", "tg"),
        ("Tamil", "ta"),
        ("Telugu", "te"),
        ("Thai", "th"),
        ("Turkish", "tr"),
        ("Ukrainian", "uk"),
        ("Umbundu", "umb"),
        ("Urdu", "ur"),
        ("Uzbek", "uz"),
        ("Vietnamese", "vi"),
        ("Welsh", "cy"),
        ("Wolof", "wo"),
        ("Xhosa", "xh"),
        ("Yoruba", "yo"),
        ("Zulu", "zu"),
    ]
)
_FLEURS_LANG_SHORT_TO_LONG = {v: k for k, v in _FLEURS_LANG_TO_ID.items()}

# Decoding the FLEURS audio requires 'librosa' and 'soundfile' to be installed.
english_normalizer = EnglishTextNormalizer()
chinese_normalizer = TextNorm(
    to_banjiao=False,
    to_upper=False,
    to_lower=False,
    remove_fillers=False,
    remove_erhua=False,
    check_chars=False,
    remove_space=False,
    cc_mode="",
)
basic_normalizer = BasicTextNormalizer()

dir_name = os.path.dirname(os.path.abspath(__file__))


def fleurs_doc_to_audio(doc):
    return [doc["audio"]]


def fleurs_doc_to_text(doc, lmms_eval_specific_kwargs):
    pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
    post_prompt = lmms_eval_specific_kwargs["post_prompt"]
    return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"


def fleurs_process_result(doc, result):
    pred = result[0] if len(result) > 0 else ""

    gt = doc["transcription"]
    source = doc["path"]
    language = doc["language"]

    data_dict = {"gt": gt, "pred": pred, "source": source, "language": language}

    return {"wer": data_dict}


PUNCS = "!,.?;:"


def remove_sp(text, language):
    gt = re.sub(r"<\|.*?\|>", " ", text)  # Strip special tokens such as <|en|>.
    gt = re.sub(r"\s+", " ", gt)  # Collapse consecutive whitespace into a single space.
    gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)  # Remove spaces before punctuation.
    gt = gt.lstrip(" ")
    if language == "cmn_hans":
        gt = re.sub(r"\s+", "", gt)  # Mandarin is scored without spaces.
    return gt


class EvaluationTokenizer(object):
    """A generic evaluation-time tokenizer, which leverages built-in tokenizers
    in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
    lowercasing, punctuation removal and character tokenization, which are
    applied after sacreBLEU tokenization.

    Args:
        tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
        lowercase (bool): lowercase the text.
        punctuation_removal (bool): remove punctuation (based on unicode
            category) from text.
        character_tokenization (bool): tokenize the text to characters.
    """

    SPACE = chr(32)
    SPACE_ESCAPE = chr(9601)
    # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])

    def __init__(
        self,
        tokenizer_type: str = "13a",
        lowercase: bool = False,
        punctuation_removal: bool = False,
        character_tokenization: bool = False,
    ):
        from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
        from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
        from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
        from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
        from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
        from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh

        TOKENIZERS = {
            "none": NoneTokenizer,
            "13a": Tokenizer13a,
            "intl": TokenizerV14International,
            "zh": TokenizerZh,
            "ja-mecab": TokenizerJaMecab,
            "char": TokenizerChar,
        }

        assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
        self.lowercase = lowercase
        self.punctuation_removal = punctuation_removal
        self.character_tokenization = character_tokenization
        self.tokenizer = TOKENIZERS[tokenizer_type]
        # self.tokenizer = tokenizer_none

    @classmethod
    def remove_punctuation(cls, sent: str):
        """Remove punctuation based on Unicode category."""
        return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))

    def tokenize(self, sent: str):
        tokenized = self.tokenizer()(sent)

        if self.punctuation_removal:
            tokenized = self.remove_punctuation(tokenized)

        if self.character_tokenization:
            tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))

        if self.lowercase:
            tokenized = tokenized.lower()

        return tokenized
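

# Illustrative usage sketch (not part of the original commit): with the "13a"
# tokenizer plus lowercasing and punctuation removal, "Hello, World!" should
# come out as "hello world".  Requires sacrebleu, which __init__ imports lazily.
#
#     tok = EvaluationTokenizer(tokenizer_type="13a", lowercase=True, punctuation_removal=True)
#     assert tok.tokenize("Hello, World!") == "hello world"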


def compute_wer(refs, hyps, language):
    distance = 0
    ref_length = 0
    tokenizer = EvaluationTokenizer(
        tokenizer_type="none",
        lowercase=True,
        punctuation_removal=True,
        character_tokenization=False,
    )
    for i in range(len(refs)):
        ref = refs[i]
        pred = hyps[i]
        if language in ["yue_hant"]:
            # Convert traditional-script Cantonese to simplified characters before scoring.
            ref = zhconv.convert(ref, "zh-cn")
            pred = zhconv.convert(pred, "zh-cn")
        if language in ["en"]:
            ref = english_normalizer(ref)
            pred = english_normalizer(pred)
        elif language in ["cmn_hans"]:
            ref = chinese_normalizer(ref)
            pred = chinese_normalizer(pred)
        else:
            ref = basic_normalizer(ref)
            pred = basic_normalizer(pred)
        ref_items = tokenizer.tokenize(ref).split()
        pred_items = tokenizer.tokenize(pred).split()
        if language in ["cmn_hans", "yue_hant"]:
            # Chinese is scored at the character level (CER).
            ref_items = [x for x in "".join(ref_items)]
            pred_items = [x for x in "".join(pred_items)]
        if i == 0:
            print(f"ref: {ref}")
            print(f"pred: {pred}")
            print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
        distance += ed.eval(ref_items, pred_items)
        ref_length += len(ref_items)
    return distance / ref_length
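
# Worked example (hypothetical strings) for the score returned above:
#   ref = "the cat sat"       -> 3 reference tokens
#   hyp = "the cat sat down"  -> one insertion, so edit distance 1
#   WER = 1 / 3 ≈ 0.333, which fleurs_wer reports as ~33.3 after scaling by 100.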


def fleurs_wer(results, args):
    refs, hyps = [], []
    for result in results:
        # Each task config covers a single language, so the code resolved here is
        # the same for every result and is reused for corpus-level scoring below.
        lan = _FLEURS_LANG_TO_ID[result["language"]]
        gt = result["gt"]
        response = result["pred"]
        gt = remove_sp(gt, lan)
        response = remove_sp(response, lan)
        refs.append(gt)
        hyps.append(response)
    wer = compute_wer(refs, hyps, lan)
    return wer * 100