Commit

Init Fleurs
kcz358 committed Jan 21, 2025
1 parent b1fbf55 commit fddbda4
Showing 4 changed files with 322 additions and 0 deletions.
25 changes: 25 additions & 0 deletions lmms_eval/tasks/fleurs/_default_template_yaml
@@ -0,0 +1,25 @@
dataset_path: google/fleurs
dataset_kwargs:
  token: True
test_split: test
output_type: generate_until
doc_to_visual: !function utils.fleurs_doc_to_audio
doc_to_text: !function utils.fleurs_doc_to_text
doc_to_target: "transcription"
generation_kwargs:
  max_new_tokens: 256
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
process_results: !function utils.fleurs_process_result
metric_list:
  - metric: wer
    aggregation: !function utils.fleurs_wer
    higher_is_better: false
metadata:
  - version: 0.0
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""
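
Not part of the commit, but for orientation: with the default empty pre_prompt/post_prompt above, the text half of every FLEURS request reduces to one fixed instruction. A minimal sketch (the doc dict is a placeholder; fleurs_doc_to_text is defined in utils.py further down and ignores the doc when building the prompt):

    # Illustrative only: what the model sees as text with the default kwargs above.
    doc = {}  # fleurs_doc_to_text does not read any doc fields for the prompt
    prompt = fleurs_doc_to_text(doc, {"pre_prompt": "", "post_prompt": ""})
    # prompt == "Please recognize the speech and only output the recognized content:"

The audio side of the request comes separately from fleurs_doc_to_audio via doc_to_visual.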
3 changes: 3 additions & 0 deletions lmms_eval/tasks/fleurs/fleurs.yaml
@@ -0,0 +1,3 @@
group: fleurs
task:
  - fleurs_en
3 changes: 3 additions & 0 deletions lmms_eval/tasks/fleurs/fleurs_en.yaml
@@ -0,0 +1,3 @@
dataset_name: en_us
include: _default_template_yaml
task: fleurs_en
291 changes: 291 additions & 0 deletions lmms_eval/tasks/fleurs/utils.py
@@ -0,0 +1,291 @@
import os
import re
import unicodedata
from collections import OrderedDict

import editdistance as ed
import zhconv

from lmms_eval.tasks.librispeech.cn_tn import TextNorm
from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer

_FLEURS_LANG_TO_ID = OrderedDict(
    [
        ("Afrikaans", "af"),
        ("Amharic", "am"),
        ("Arabic", "ar"),
        ("Armenian", "hy"),
        ("Assamese", "as"),
        ("Asturian", "ast"),
        ("Azerbaijani", "az"),
        ("Belarusian", "be"),
        ("Bengali", "bn"),
        ("Bosnian", "bs"),
        ("Bulgarian", "bg"),
        ("Burmese", "my"),
        ("Catalan", "ca"),
        ("Cebuano", "ceb"),
        ("Mandarin Chinese", "cmn_hans"),
        ("Cantonese Chinese", "yue_hant"),
        ("Croatian", "hr"),
        ("Czech", "cs"),
        ("Danish", "da"),
        ("Dutch", "nl"),
        ("English", "en"),
        ("Estonian", "et"),
        ("Filipino", "fil"),
        ("Finnish", "fi"),
        ("French", "fr"),
        ("Fula", "ff"),
        ("Galician", "gl"),
        ("Ganda", "lg"),
        ("Georgian", "ka"),
        ("German", "de"),
        ("Greek", "el"),
        ("Gujarati", "gu"),
        ("Hausa", "ha"),
        ("Hebrew", "he"),
        ("Hindi", "hi"),
        ("Hungarian", "hu"),
        ("Icelandic", "is"),
        ("Igbo", "ig"),
        ("Indonesian", "id"),
        ("Irish", "ga"),
        ("Italian", "it"),
        ("Japanese", "ja"),
        ("Javanese", "jv"),
        ("Kabuverdianu", "kea"),
        ("Kamba", "kam"),
        ("Kannada", "kn"),
        ("Kazakh", "kk"),
        ("Khmer", "km"),
        ("Korean", "ko"),
        ("Kyrgyz", "ky"),
        ("Lao", "lo"),
        ("Latvian", "lv"),
        ("Lingala", "ln"),
        ("Lithuanian", "lt"),
        ("Luo", "luo"),
        ("Luxembourgish", "lb"),
        ("Macedonian", "mk"),
        ("Malay", "ms"),
        ("Malayalam", "ml"),
        ("Maltese", "mt"),
        ("Maori", "mi"),
        ("Marathi", "mr"),
        ("Mongolian", "mn"),
        ("Nepali", "ne"),
        ("Northern-Sotho", "nso"),
        ("Norwegian", "nb"),
        ("Nyanja", "ny"),
        ("Occitan", "oc"),
        ("Oriya", "or"),
        ("Oromo", "om"),
        ("Pashto", "ps"),
        ("Persian", "fa"),
        ("Polish", "pl"),
        ("Portuguese", "pt"),
        ("Punjabi", "pa"),
        ("Romanian", "ro"),
        ("Russian", "ru"),
        ("Serbian", "sr"),
        ("Shona", "sn"),
        ("Sindhi", "sd"),
        ("Slovak", "sk"),
        ("Slovenian", "sl"),
        ("Somali", "so"),
        ("Sorani-Kurdish", "ckb"),
        ("Spanish", "es"),
        ("Swahili", "sw"),
        ("Swedish", "sv"),
        ("Tajik", "tg"),
        ("Tamil", "ta"),
        ("Telugu", "te"),
        ("Thai", "th"),
        ("Turkish", "tr"),
        ("Ukrainian", "uk"),
        ("Umbundu", "umb"),
        ("Urdu", "ur"),
        ("Uzbek", "uz"),
        ("Vietnamese", "vi"),
        ("Welsh", "cy"),
        ("Wolof", "wo"),
        ("Xhosa", "xh"),
        ("Yoruba", "yo"),
        ("Zulu", "zu"),
    ]
)
_FLEURS_LANG_SHORT_TO_LONG = {v: k for k, v in _FLEURS_LANG_TO_ID.items()}

# ImportError: To support decoding audio files, please install 'librosa' and 'soundfile'.
english_normalizer = EnglishTextNormalizer()
chinese_normalizer = TextNorm(
    to_banjiao=False,
    to_upper=False,
    to_lower=False,
    remove_fillers=False,
    remove_erhua=False,
    check_chars=False,
    remove_space=False,
    cc_mode="",
)
basic_normalizer = BasicTextNormalizer()

dir_name = os.path.dirname(os.path.abspath(__file__))


def fleurs_doc_to_audio(doc):
    return [doc["audio"]]


def fleurs_doc_to_text(doc, lmms_eval_specific_kwargs):
    pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
    post_prompt = lmms_eval_specific_kwargs["post_prompt"]
    return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"


def fleurs_process_result(doc, result):
    pred = result[0] if len(result) > 0 else ""

    gt = doc["transcription"]
    source = doc["path"]
    language = doc["language"]

    data_dict = {"gt": gt, "pred": pred, "source": source, "language": language}

    return {"wer": data_dict}


PUNCS = "!,.?;:"


def remove_sp(text, language):
    gt = re.sub(r"<\|.*?\|>", " ", text)
    gt = re.sub(r"\s+", r" ", gt)  # Replace consecutive spaces in the text with a single space.
    gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
    gt = gt.lstrip(" ")
    if language == "cmn_hans":
        gt = re.sub(r"\s+", r"", gt)
    return gt


class EvaluationTokenizer(object):
    """A generic evaluation-time tokenizer, which leverages built-in tokenizers
    in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
    lowercasing, punctuation removal and character tokenization, which are
    applied after sacreBLEU tokenization.
    Args:
        tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
        lowercase (bool): lowercase the text.
        punctuation_removal (bool): remove punctuation (based on unicode
            category) from text.
        character_tokenization (bool): tokenize the text to characters.
    """

    SPACE = chr(32)
    SPACE_ESCAPE = chr(9601)
    # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])

    def __init__(
        self,
        tokenizer_type: str = "13a",
        lowercase: bool = False,
        punctuation_removal: bool = False,
        character_tokenization: bool = False,
    ):
        from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
        from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
        from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
        from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
        from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
        from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh

        TOKENIZERS = {
            "none": NoneTokenizer,
            "13a": Tokenizer13a,
            "intl": TokenizerV14International,
            "zh": TokenizerZh,
            "ja-mecab": TokenizerJaMecab,
            "char": TokenizerChar,
        }

        assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
        self.lowercase = lowercase
        self.punctuation_removal = punctuation_removal
        self.character_tokenization = character_tokenization
        self.tokenizer = TOKENIZERS[tokenizer_type]
        # self.tokenizer = tokenizer_none

    @classmethod
    def remove_punctuation(cls, sent: str):
        """Remove punctuation based on Unicode category."""
        return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))

    def tokenize(self, sent: str):
        tokenized = self.tokenizer()(sent)

        if self.punctuation_removal:
            tokenized = self.remove_punctuation(tokenized)

        if self.character_tokenization:
            tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))

        if self.lowercase:
            tokenized = tokenized.lower()

        return tokenized


def compute_wer(refs, hyps, language):
    distance = 0
    ref_length = 0
    tokenizer = EvaluationTokenizer(
        tokenizer_type="none",
        lowercase=True,
        punctuation_removal=True,
        character_tokenization=False,
    )
    for i in range(len(refs)):
        ref = refs[i]
        pred = hyps[i]
        if language in ["yue_hant"]:
            ref = zhconv.convert(ref, "zh-cn")
            pred = zhconv.convert(pred, "zh-cn")
        if language in ["en"]:
            ref = english_normalizer(ref)
            pred = english_normalizer(pred)
        if language in ["cmn_hans"]:
            ref = chinese_normalizer(ref)
            pred = chinese_normalizer(pred)
        else:
            ref = basic_normalizer(ref)
            pred = basic_normalizer(pred)
        ref_items = tokenizer.tokenize(ref).split()
        pred_items = tokenizer.tokenize(pred).split()
        if language in ["zh", "yue"]:
            ref_items = [x for x in "".join(ref_items)]
            pred_items = [x for x in "".join(pred_items)]
        if i == 0:
            print(f"ref: {ref}")
            print(f"pred: {pred}")
            print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
        distance += ed.eval(ref_items, pred_items)
        ref_length += len(ref_items)
    return distance / ref_length


def fleurs_wer(results, args):
    refs, hyps = [], []
    for result in results:
        lan = _FLEURS_LANG_TO_ID[result["language"]]
        gt = result["gt"]
        response = result["pred"]
        gt = remove_sp(gt, lan)
        response = remove_sp(response, lan)
        refs.append(gt)
        hyps.append(response)
    wer = compute_wer(refs, hyps, lan)
    return wer * 100
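
Not part of the commit, but a minimal end-to-end sketch of how the pieces above fit together at evaluation time. The sample dict is fabricated for illustration; the field names follow the google/fleurs schema accessed in fleurs_process_result:

    # Illustrative only: run one fabricated English sample through the metric path.
    doc = {"transcription": "hello world", "path": "sample.wav", "language": "English"}
    per_sample = fleurs_process_result(doc, ["Hello, world!"])  # model output arrives as a list
    score = fleurs_wer([per_sample["wer"]], args=None)          # aggregation over all collected samples
    print(f"WER: {score:.2f}%")                                 # expected 0.00 after normalization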
