Skip to content

Commit

Permalink
Minor fixes for the commonvoice recipe (#1534)
Browse files Browse the repository at this point in the history
* init commit

* fix for issue #1531

* minor fixes
  • Loading branch information
JinZr committed Mar 8, 2024
1 parent 5df24c1 commit ae61bd4
Show file tree
Hide file tree
Showing 6 changed files with 344 additions and 14 deletions.
1 change: 0 additions & 1 deletion egs/commonvoice/ASR/local/compile_hlg.py

This file was deleted.

168 changes: 168 additions & 0 deletions egs/commonvoice/ASR/local/compile_hlg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)

return parser.parse_args()


def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The language stem base name.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]

L = k2.arc_sort(L)
G = k2.arc_sort(G)

logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")

logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")

logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")

LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))

logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)

logging.info("Removing disambiguation symbols on LG")

# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels

assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)

logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)

logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")

logging.info("Connecting LG")
HLG = k2.connect(HLG)

logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")

return HLG


def main():
args = get_args()
lang_dir = Path(args.lang_dir)

if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return

logging.info(f"Processing {lang_dir}")

HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

main()
1 change: 0 additions & 1 deletion egs/commonvoice/ASR/local/compile_lg.py

This file was deleted.

149 changes: 149 additions & 0 deletions egs/commonvoice/ASR/local/compile_lg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Kang Wei,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)

return parser.parse_args()


def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]

L = k2.arc_sort(L)
G = k2.arc_sort(G)

logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")

logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")

logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")

LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
logging.info(type(LG.aux_labels))

logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)

logging.info("Removing disambiguation symbols on LG")

# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels

assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)

logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)

return LG


def main():
args = get_args()
lang_dir = Path(args.lang_dir)

if (lang_dir / "LG.pt").is_file():
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
return

logging.info(f"Processing {lang_dir}")

LG = compile_LG(lang_dir, args.lm)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

main()
9 changes: 9 additions & 0 deletions egs/commonvoice/ASR/local/preprocess_commonvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ def normalize_text(utt: str, language: str) -> str:
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
elif language == "pl":
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
elif language == "yue":
return (
utt.replace(" ", "")
.replace(",", "")
.replace("。", " ")
.replace("?", "")
.replace("!", "")
.replace("?", "")
)
else:
raise NotImplementedError(
f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,11 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
Expand Down Expand Up @@ -315,8 +315,8 @@ def train_dataloaders(
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
Expand Down Expand Up @@ -383,9 +383,11 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
Expand Down
Loading

0 comments on commit ae61bd4

Please sign in to comment.