Recipes for open vocabulary keyword spotting (#1428)
* English recipe on gigaspeech; Chinese recipe on wenetspeech
pkufool committed Feb 22, 2024
1 parent 13daf73 commit aac7df0
Showing 57 changed files with 10,203 additions and 120 deletions.

@@ -30,15 +30,15 @@
 torch.set_num_interop_threads(1)


-def compute_fbank_gigaspeech_dev_test():
+def compute_fbank_gigaspeech():
     in_out_dir = Path("data/fbank")
     # number of workers in dataloader
     num_workers = 20

     # number of seconds in a batch
-    batch_duration = 600
+    batch_duration = 1000

-    subsets = ("DEV", "TEST")
+    subsets = ("L", "M", "S", "XS", "DEV", "TEST")

     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
     logging.info(f"device: {device}")

     for partition in subsets:
-        cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
+        cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue

-        raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"

         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -62,7 +62,7 @@

         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{in_out_dir}/feats_{partition}",
+            storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
             overwrite=True,
@@ -80,7 +80,7 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

-    compute_fbank_gigaspeech_dev_test()
+    compute_fbank_gigaspeech()


 if __name__ == "__main__":

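The hunks above only touch the renamed entry point, the subset list, and the manifest/feature paths; the Fbank extractor itself is constructed outside the shown context. Below is a minimal sketch of how such a batched extraction is typically wired up with lhotse. The KaldifeatFbank extractor and the final to_file() call are assumptions based on the rest of the recipe, not part of this diff; the DEV paths follow the naming used above.

    # Sketch only: the extractor construction and the final save are not shown in
    # the hunks above; KaldifeatFbank/KaldifeatFbankConfig are assumptions based on
    # the surrounding recipe and require the kaldifeat package.
    import torch
    from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

    # Paths follow the naming introduced in this commit (in_out_dir = data/fbank).
    cut_set = CutSet.from_file("data/fbank/gigaspeech_cuts_DEV_raw.jsonl.gz")
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path="data/fbank/gigaspeech_feats_DEV",
        num_workers=20,
        batch_duration=1000,
        overwrite=True,
    )
    cut_set.to_file("data/fbank/gigaspeech_cuts_DEV.jsonl.gz")  # assumed final step
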
16 changes: 4 additions & 12 deletions egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -51,14 +51,6 @@ def get_parser():
         "Determines batch size dynamically.",
     )

-    parser.add_argument(
-        "--subset",
-        type=str,
-        default="XL",
-        choices=["XL", "L", "M", "S", "XS"],
-        help="Which subset to work with",
-    )
-
     parser.add_argument(
         "--num-splits",
         type=int,
@@ -84,7 +76,7 @@ def get_parser():

 def compute_fbank_gigaspeech_splits(args):
     num_splits = args.num_splits
-    output_dir = f"data/fbank/{args.subset}_split"
+    output_dir = f"data/fbank/XL_split"
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"

@@ -107,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args):
         idx = f"{i}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")

-        cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue

-        raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
+        raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"

         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -121,7 +113,7 @@

         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
+            storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
             overwrite=True,

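With the --subset flag removed, every per-split path in this script is hard-wired to the XL subset. For orientation, a small illustration of the resulting split-manifest names follows; the number of splits and the zero-padding width are placeholders, since they are defined outside the hunks shown.

    # Illustrative only: num_splits and the padding width are placeholders here;
    # the real values come from --num-splits and code not shown in this diff.
    num_splits = 2000
    num_digits = 8  # assumed padding width
    for i in (0, 1, num_splits - 1):
        idx = f"{i}".zfill(num_digits)
        print(f"data/fbank/XL_split/gigaspeech_cuts_XL_raw.{idx}.jsonl.gz")  # input
        print(f"data/fbank/XL_split/gigaspeech_cuts_XL.{idx}.jsonl.gz")      # output
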
38 changes: 30 additions & 8 deletions egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -16,17 +16,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import argparse
 import logging
 import re
 from pathlib import Path

 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
+from icefall.utils import str2bool

 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--perturb-speed",
+        type=str2bool,
+        default=False,
+        help="Whether to use speed perturbation.",
+    )
+
+    return parser.parse_args()
+
+
 def normalize_text(
     utt: str,
     punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@@ -42,7 +56,7 @@ def has_no_oov(
     return oov_pattern.search(sup.text) is None


-def preprocess_giga_speech():
+def preprocess_giga_speech(args):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
@@ -51,6 +65,10 @@
         "DEV",
         "TEST",
         "XL",
+        "L",
+        "M",
+        "S",
+        "XS",
     )

     logging.info("Loading manifest (may take 4 minutes)")
@@ -71,7 +89,7 @@

     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
-        raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
         if raw_cuts_path.is_file():
             logging.info(f"{partition} already exists - skipping")
             continue
@@ -94,11 +112,14 @@
         # Run data augmentation that needs to be done in the
         # time domain.
         if partition not in ["DEV", "TEST"]:
-            logging.info(
-                f"Speed perturb for {partition} with factors 0.9 and 1.1 "
-                "(Perturbing may take 8 minutes and saving may take 20 minutes)"
-            )
-            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            if args.perturb_speed:
+                logging.info(
+                    f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+                    "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+                )
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
@@ -107,7 +128,8 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

-    preprocess_giga_speech()
+    args = get_args()
+    preprocess_giga_speech(args)


 if __name__ == "__main__":

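The new --perturb-speed flag (parsed with icefall's str2bool) makes time-domain speed perturbation opt-in instead of always applied to the training subsets; something like python3 ./local/preprocess_gigaspeech.py --perturb-speed True would restore the previous behaviour. A hypothetical illustration of what enabling it does to a cut manifest is sketched below; the S-subset path is only an example taken from this recipe's naming scheme.

    # Hypothetical illustration, not part of the commit: perturbation with factors
    # 0.9 and 1.1 triples the number of cuts, which is why it is now opt-in for
    # the large subsets. The manifest path is an example, not a requirement.
    from lhotse import CutSet

    cuts = CutSet.from_file("data/fbank/gigaspeech_cuts_S_raw.jsonl.gz")
    cuts_sp = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
    assert len(cuts_sp) == 3 * len(cuts)
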
61 changes: 40 additions & 21 deletions egs/gigaspeech/ASR/prepare.sh
@@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     exit 1;
   fi
   # Download XL, DEV and TEST sets by default.
-  lhotse download gigaspeech --subset auto --host tsinghua \
+  lhotse download gigaspeech --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    --host tsinghua \
     $dl_dir/password $dl_dir/GigaSpeech
 fi

@@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # We assume that you have downloaded the GigaSpeech corpus
   # to $dl_dir/GigaSpeech
   mkdir -p data/manifests
-  lhotse prepare gigaspeech --subset auto -j $nj \
+  lhotse prepare gigaspeech --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    -j $nj \
     $dl_dir/GigaSpeech data/manifests
 fi

@@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi

 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
-  python3 ./local/compute_fbank_gigaspeech_dev_test.py
+  log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
+  python3 ./local/compute_fbank_gigaspeech.py
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
 fi

 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-  log "Stage 9: Prepare phone based lang"
+  log "Stage 9: Prepare transcript_words.txt and words.txt"
   lang_dir=data/lang_phone
   mkdir -p $lang_dir
-
-  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - $dl_dir/lm/lexicon.txt |
-    sort | uniq > $lang_dir/lexicon.txt
-
-  if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py --lang-dir $lang_dir
-  fi
-
   if [ ! -f $lang_dir/transcript_words.txt ]; then
     gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
       | jq '.text' \
@@ -238,7 +243,21 @@
 fi

 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Prepare BPE based lang"
+  log "Stage 10: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Prepare BPE based lang"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -260,8 +279,8 @@
   done
 fi

-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "Stage 11: Prepare bigram P"
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Prepare bigram P"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
   done
 fi

-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-  log "Stage 12: Prepare G"
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Prepare G"
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm

@@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
   fi
 fi

-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
-  log "Stage 13: Compile HLG"
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone

   for vocab_size in ${vocab_sizes[@]}; do

5 changes: 4 additions & 1 deletion egs/gigaspeech/ASR/zipformer/asr_datamodule.py
@@ -105,7 +105,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=30,
+            default=100,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -368,6 +368,8 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
+            num_buckets=self.args.num_buckets,
+            buffer_size=self.args.num_buckets * 2000,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
@@ -417,6 +419,7 @@ def train_cuts(self) -> CutSet:
         logging.info(
             f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
         )
+
         cuts_train = lhotse.combine(
             lhotse.load_manifest_lazy(p) for p in sorted_filenames
         )

21 changes: 20 additions & 1 deletion egs/gigaspeech/ASR/zipformer/train.py
@@ -416,6 +416,17 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )

+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for oom batches before training, this is helpful for
+        finding the suitable max_duration, you only need to run it once.
+        Caution: a little time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--inf-check",
         type=str2bool,
@@ -1171,9 +1182,16 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

+    def remove_short_utt(c: Cut):
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        return T > 0
+
     gigaspeech = GigaSpeechAsrDataModule(args)

     train_cuts = gigaspeech.train_cuts()
+    train_cuts = train_cuts.filter(remove_short_utt)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1187,9 +1205,10 @@ def run(rank, world_size, args):
     )

     valid_cuts = gigaspeech.dev_cuts()
+    valid_cuts = valid_cuts.filter(remove_short_utt)
     valid_dl = gigaspeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if not params.print_diagnostics and params.scan_for_oom_batches:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,

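Two training-time behaviours change here: cuts that would collapse to zero frames after the Zipformer's convolutional subsampling are filtered out of both the training and validation sets, and the pessimistic OOM scan now runs only when --scan-for-oom-batches is enabled. A quick, illustrative check of the threshold implied by the expression in remove_short_utt follows; the 10 ms frame shift is the usual fbank setting and is an assumption here.

    # Illustrative check of the remove_short_utt threshold; not part of the diff.
    def subsampled_frames(num_frames: int) -> int:
        # Same expression as remove_short_utt above.
        return ((num_frames - 7) // 2 + 1) // 2

    for n in (7, 8, 9, 10, 100):
        print(n, subsampled_frames(n))
    # 7 and 8 frames give T == 0, so such cuts are dropped; 9 or more frames
    # (roughly 90 ms or longer at an assumed 10 ms frame shift) pass the filter.
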