Fine-tune recipe for Zipformer (#1484)
1. Support fine-tuning Zipformer.
2. Update the usage; set a very large batch count.
marcoyang1998 authored Feb 6, 2024
1 parent a813186 commit 7770740
Showing 4 changed files with 2,664 additions and 6 deletions.
18 changes: 13 additions & 5 deletions egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -51,6 +51,14 @@ def get_parser():
         "Determines batch size dynamically.",
     )
 
+    parser.add_argument(
+        "--subset",
+        type=str,
+        default="XL",
+        choices=["XL", "L", "M", "S", "XS"],
+        help="Which subset to work with",
+    )
+
     parser.add_argument(
         "--num-splits",
         type=int,
@@ -76,7 +84,7 @@ def get_parser():
 
 
 def compute_fbank_gigaspeech_splits(args):
     num_splits = args.num_splits
-    output_dir = "data/fbank/XL_split"
+    output_dir = f"data/fbank/{args.subset}_split"
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
 
@@ -96,15 +104,15 @@ def compute_fbank_gigaspeech_splits(args):
     logging.info(f"device: {device}")
 
     for i in range(start, stop):
-        idx = f"{i + 1}".zfill(num_digits)
+        idx = f"{i}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")
 
-        cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue
 
-        raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+        raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
 
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -113,7 +121,7 @@ def compute_fbank_gigaspeech_splits(args):
 
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_XL_{idx}",
+            storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
            overwrite=True,
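
The net effect of these edits is that the split manifests and feature archives are named after the selected subset, and the split index is used as-is (zero-based) instead of being shifted by one. The sketch below illustrates the file layout the updated script reads and writes; the subset and split-count values are examples, and the zero-padding width is an assumption, since it is defined elsewhere in the script and not shown in this diff.

# Illustrative sketch only; not part of the commit.
from pathlib import Path

subset = "S"       # chosen via the new --subset flag (one of XL/L/M/S/XS)
num_splits = 100   # chosen via --num-splits; must match how the manifests were split
num_digits = 8     # zero-padding width (assumed; defined elsewhere in the script)

output_dir = Path(f"data/fbank/{subset}_split")
for i in range(num_splits):
    idx = f"{i}".zfill(num_digits)
    raw_cuts = output_dir / f"cuts_{subset}_raw.{idx}.jsonl.gz"  # read per split
    cuts = output_dir / f"cuts_{subset}.{idx}.jsonl.gz"          # written per split
    feats = f"{output_dir}/feats_{subset}_{idx}"                 # feature storage path
    if i == 0:
        print(raw_cuts, cuts, feats, sep="\n")
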
17 changes: 16 additions & 1 deletion egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Piotr Żelasko
+# Copyright 2021 Piotr Żelasko
 # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
@@ -475,3 +475,18 @@ def test_other_cuts(self) -> CutSet:
         return load_manifest_lazy(
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )
+
+    @lru_cache()
+    def gigaspeech_subset_small_cuts(self) -> CutSet:
+        logging.info("About to get Gigaspeech subset-S cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
+
+    @lru_cache()
+    def gigaspeech_dev_cuts(self) -> CutSet:
+        logging.info("About to get Gigaspeech dev cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+
+    @lru_cache()
+    def gigaspeech_test_cuts(self) -> CutSet:
+        logging.info("About to get Gigaspeech test cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
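
The three new helpers expose GigaSpeech subset-S, DEV, and TEST cuts through the existing LibriSpeech data module. The finetune.py added by this commit is not shown on this page, so the following is only an assumed usage sketch; it relies on the data module's existing train_dataloaders/valid_dataloaders/test_dataloaders builders and on the manifests being present in args.manifest_dir.

# Assumed usage sketch -- not the recipe's actual finetune.py code.
from asr_datamodule import LibriSpeechAsrDataModule


def make_gigaspeech_loaders(args):
    # The new helpers lazily load cuts_S.jsonl.gz, cuts_DEV.jsonl.gz, and
    # cuts_TEST.jsonl.gz from args.manifest_dir (added in this commit).
    datamodule = LibriSpeechAsrDataModule(args)
    train_cuts = datamodule.gigaspeech_subset_small_cuts()
    dev_cuts = datamodule.gigaspeech_dev_cuts()
    test_cuts = datamodule.gigaspeech_test_cuts()

    # Reuse the data module's existing dataloader builders to wrap the cuts.
    return (
        datamodule.train_dataloaders(train_cuts),
        datamodule.valid_dataloaders(dev_cuts),
        datamodule.test_dataloaders(test_cuts),
    )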