From 60e74a8393d12f440ac8ae75c50321d55b6699d6 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 13:57:10 +0700 Subject: [PATCH 01/10] :rocket: add support for save model to cloud --- README.md | 41 ++---- .../train_keras_subword_conformer.py | 31 ++++- .../train_tpu_keras_subword_conformer.py | 6 +- .../conformer/train_tpu_subword_conformer.py | 123 ------------------ .../train_keras_subword_contextnet.py | 31 ++++- .../train_tpu_keras_subword_contextnet.py | 7 +- examples/deepspeech2/train_keras_ds2.py | 30 ++++- examples/jasper/train_keras_jasper.py | 26 +++- ...rain_keras_subword_streaming_transducer.py | 31 ++++- setup.py | 2 +- tensorflow_asr/datasets/asr_dataset.py | 9 +- tensorflow_asr/datasets/base_dataset.py | 3 + tensorflow_asr/datasets/keras/asr_dataset.py | 17 ++- .../featurizers/speech_featurizers.py | 3 + .../featurizers/text_featurizers.py | 3 + tensorflow_asr/models/__init__.py | 38 ++++++ tensorflow_asr/runners/base_runners.py | 11 +- tensorflow_asr/utils/utils.py | 30 ++++- 18 files changed, 240 insertions(+), 202 deletions(-) delete mode 100644 examples/conformer/train_tpu_subword_conformer.py diff --git a/README.md b/README.md index d73e6a0f21..1ff59a9809 100755 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as - [TFLite Convertion](#tflite-convertion) - [Features Extraction](#features-extraction) - [Augmentations](#augmentations) -- [Training & Testing](#training--testing) +- [Training & Testing Tutorial](#training--testing-tutorial) - [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models) - [English](#english) - [Vietnamese](#vietnamese) @@ -164,34 +164,17 @@ See [features_extraction](./tensorflow_asr/featurizers/README.md) See [augmentations](./tensorflow_asr/augmentations/README.md) -## Training & Testing - -**Example YAML Config Structure** - -```yaml -speech_config: ... -model_config: ... -decoder_config: ... -learning_config: - train_dataset_config: - augmentation_config: ... - data_paths: ... - tfrecords_dir: ... - eval_dataset_config: - augmentation_config: ... - data_paths: ... - tfrecords_dir: ... - test_dataset_config: - augmentation_config: ... - data_paths: ... - tfrecords_dir: ... - optimizer_config: ... - running_config: - batch_size: 8 - num_epochs: 20 - outdir: ... - log_interval_steps: 500 -``` +## Training & Testing Tutorial + +1. Define config YAML file, see the `config.yml` files in the [example folder](./examples) for reference (you can copy and modify values such as parameters, paths, etc.. to match your local machine configuration) +2. Download your corpus (a.k.a datasets) and create a script to generate `transcripts.tsv` files from your corpus (this is general format used in this project because each dataset has different format). For more detail, see [datasets](./tensorflow_asr/datasets/README.md). **Note:** Make sure your data contain only characters in your language, for example, english has `a` to `z` and `'`. **Do not use `cache` if your dataset size is not fit in the RAM**. +3. [Optional] Generate TFRecords to use `tf.data.TFRecordDataset` for better performance by using the script [create_tfrecords.py](./scripts/create_tfrecords.py) +4. Create vocabulary file (characters or subwords/wordpieces) by defining `language.characters`, using the scripts [generate_vocab_subwords.py](./scripts/generate_vocab_subwords.py) or [generate_vocab_sentencepiece.py](./scripts/generate_vocab_sentencepiece.py). 
There're predefined ones in [vocabularies](./vocabularies) +5. [Optional] Generate metadata file for your dataset by using script [generate_metadata.py](./scripts/generate_metadata.py). This metadata file contains maximum lengths calculated with your `config.yml` and total number of elements in each dataset, for static shape training and precalculated steps per epoch. +6. For training, see `train_*.py` files in the [example folder](./examples) to see the options +7. For testing, see `test_.*.py` files in the [example folder](./examples) to see the options. **Note:** Testing is currently not supported for TPUs. It will print nothing other than the progress bar in the console, but it will store the predicted transcripts to the file `output_name.tsv` in the `outdir` defined in the config yaml file. After testing is done, the metrics (WER and CER) are calculated from `output_name.tsv`. **If you define the same `output_name`, it will resume the testing from the previous tested batch, which means if the testing is done then it will only calculate the metrics, if you want to run a new test, define a new `output_name` that the file `output.tsv` is not exists or only contains the header** + +**Recommendation**: For better performance, please use **keras builtin training functions** as in `train_keras_*.py` files and/or tfrecords. Keras builtin training uses **infinite dataset**, which avoids the potential last partial batch. See [examples](./examples/) for some predefined ASR models and results diff --git a/examples/conformer/train_keras_subword_conformer.py b/examples/conformer/train_keras_subword_conformer.py index bbba1cadd4..986396c2a6 100644 --- a/examples/conformer/train_keras_subword_conformer.py +++ b/examples/conformer/train_keras_subword_conformer.py @@ -38,6 +38,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") @@ -79,20 +83,30 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + 
**vars(config.learning_config.train_dataset_config), + indefinite=True ) with strategy.scope(): @@ -114,7 +128,12 @@ epsilon=config.learning_config.optimizer_config["epsilon"] ) - conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + conformer.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank + ) train_data_loader = train_dataset.create(global_batch_size) eval_data_loader = eval_dataset.create(global_batch_size) @@ -128,5 +147,5 @@ conformer.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps ) diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index 8590f269ed..10a0fc12d6 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -78,11 +78,13 @@ train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) if args.compute_lengths: diff --git a/examples/conformer/train_tpu_subword_conformer.py b/examples/conformer/train_tpu_subword_conformer.py deleted file mode 100644 index d9a55be207..0000000000 --- a/examples/conformer/train_tpu_subword_conformer.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2021 M. Yusuf Sarıgöz (@monatis) and Huy Le Nguyen (@usimarit) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import math -import argparse -from tensorflow_asr.utils import setup_environment, setup_tpu - -setup_environment() -import tensorflow as tf - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") - -tf.keras.backend.clear_session() - -parser = argparse.ArgumentParser(prog="Conformer Training") - -parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") - -parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") - -parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") - -parser.add_argument("--bs", type=int, default=None, help="Common training and evaluation batch size per TPU core") - -parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. 
Leave None on Colab") - -parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") - -parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths") - -parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") - -parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") - -parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") - -args = parser.parse_args() - -tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset -from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer -from tensorflow_asr.runners.transducer_runners import TransducerTrainer -from tensorflow_asr.models.conformer import Conformer -from tensorflow_asr.optimizers.schedules import TransformerSchedule - -config = Config(args.config) -speech_featurizer = TFSpeechFeaturizer(config.speech_config) - -if args.sentence_piece: - print("Loading SentencePiece model ...") - text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) -elif args.subwords and os.path.exists(args.subwords): - print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) -else: - print("Generating subwords ...") - text_featurizer = SubwordFeaturizer.build_from_corpus( - config.decoder_config, - corpus_files=args.subwords_corpus - ) - text_featurizer.save_to_file(args.subwords) - -train_dataset = ASRTFRecordDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) -) - -eval_dataset = ASRTFRecordDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) -) - -if args.compute_lengths: - train_dataset.update_lengths(args.metadata_prefix) - eval_dataset.update_lengths(args.metadata_prefix) - -# Update metadata calculated from both train and eval datasets -train_dataset.load_metadata(args.metadata_prefix) -eval_dataset.load_metadata(args.metadata_prefix) - -strategy = setup_tpu(args.tpu_address) - -conformer_trainer = TransducerTrainer( - config=config.learning_config.running_config, - text_featurizer=text_featurizer, strategy=strategy -) - -with conformer_trainer.strategy.scope(): - # build model - conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) - conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, - batch_size=args.bs if args.bs is not None else config.learning_config.running_config.batch_size) - conformer.summary(line_length=120) - - optimizer = tf.keras.optimizers.Adam( - TransformerSchedule( - d_model=conformer.dmodel, - warmup_steps=config.learning_config.optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(conformer.dmodel)) - ), - beta_1=config.learning_config.optimizer_config["beta1"], - beta_2=config.learning_config.optimizer_config["beta2"], - epsilon=config.learning_config.optimizer_config["epsilon"] - ) - -conformer_trainer.compile(model=conformer, optimizer=optimizer, max_to_keep=args.max_ckpts) - 
-conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.bs, eval_bs=args.bs) diff --git a/examples/contextnet/train_keras_subword_contextnet.py b/examples/contextnet/train_keras_subword_contextnet.py index 8ff26f8fb0..b641dc6b32 100644 --- a/examples/contextnet/train_keras_subword_contextnet.py +++ b/examples/contextnet/train_keras_subword_contextnet.py @@ -36,6 +36,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") @@ -74,20 +78,30 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) with strategy.scope(): @@ -109,7 +123,12 @@ epsilon=config.learning_config.optimizer_config["epsilon"] ) - contextnet.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + contextnet.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank + ) train_data_loader = train_dataset.create(global_batch_size) eval_data_loader = eval_dataset.create(global_batch_size) @@ -123,5 +142,5 @@ contextnet.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps ) diff --git a/examples/contextnet/train_tpu_keras_subword_contextnet.py b/examples/contextnet/train_tpu_keras_subword_contextnet.py index 92d46a54d2..5293982638 100644 --- a/examples/contextnet/train_tpu_keras_subword_contextnet.py +++ b/examples/contextnet/train_tpu_keras_subword_contextnet.py @@ -78,11 +78,13 @@ train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( 
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) if args.compute_lengths: @@ -131,4 +133,5 @@ contextnet.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps ) diff --git a/examples/deepspeech2/train_keras_ds2.py b/examples/deepspeech2/train_keras_ds2.py index b9501040de..470e2297d1 100644 --- a/examples/deepspeech2/train_keras_ds2.py +++ b/examples/deepspeech2/train_keras_ds2.py @@ -33,6 +33,8 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replicas") +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords dataset") parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -58,20 +60,30 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) # Build DS2 model @@ -83,8 +95,12 @@ ds2_model._build(speech_featurizer.shape) ds2_model.summary(line_length=120) - ds2_model.compile(optimizer=config.learning_config.optimizer_config, - global_batch_size=global_batch_size, blank=text_featurizer.blank) + ds2_model.compile( + optimizer=config.learning_config.optimizer_config, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank + ) train_data_loader = train_dataset.create(global_batch_size) eval_data_loader = eval_dataset.create(global_batch_size) @@ -98,5 +114,5 @@ ds2_model.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps ) diff --git a/examples/jasper/train_keras_jasper.py b/examples/jasper/train_keras_jasper.py index 9cfa48669c..bc0acab373 100644 --- a/examples/jasper/train_keras_jasper.py +++ b/examples/jasper/train_keras_jasper.py @@ -33,6 +33,10 @@ parser.add_argument("--ebs", type=int, default=None, 
help="Evaluation batch size per replicas") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords dataset") parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -58,12 +62,20 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, @@ -82,8 +94,12 @@ jasper._build(speech_featurizer.shape) jasper.summary(line_length=120) - jasper.compile(optimizer=config.learning_config.optimizer_config, - global_batch_size=global_batch_size, blank=text_featurizer.blank) + jasper.compile( + optimizer=config.learning_config.optimizer_config, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank + ) train_data_loader = train_dataset.create(global_batch_size) eval_data_loader = eval_dataset.create(global_batch_size) @@ -97,5 +113,5 @@ jasper.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps ) diff --git a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py index 573edb9f13..e0d025f609 100644 --- a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py @@ -35,6 +35,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") @@ -72,20 +76,30 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + 
**vars(config.learning_config.eval_dataset_config), + indefinite=True ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) with strategy.scope(): @@ -101,7 +115,12 @@ optimizer = tf.keras.optimizers.get(config.learning_config.optimizer_config) - streaming_transducer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + streaming_transducer.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank + ) train_data_loader = train_dataset.create(global_batch_size) eval_data_loader = eval_dataset.create(global_batch_size) @@ -115,5 +134,5 @@ streaming_transducer.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps ) diff --git a/setup.py b/setup.py index 2d66d12b42..7d6161e567 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.7.6", + version="0.7.7", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index e7b0139e9f..46d5f3e777 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -36,6 +36,7 @@ def __init__(self, augmentations: Augmentation = Augmentation(None), cache: bool = False, shuffle: bool = False, + indefinite: bool = False, drop_remainder: bool = True, use_tf: bool = False, buffer_size: int = BUFFER_SIZE, @@ -43,7 +44,7 @@ def __init__(self, super(ASRDataset, self).__init__( data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, stage=stage, buffer_size=buffer_size, - drop_remainder=drop_remainder, use_tf=use_tf + drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite ) self.speech_featurizer = speech_featurizer self.text_featurizer = text_featurizer @@ -184,6 +185,9 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): if self.shuffle: dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + if self.indefinite: + dataset = dataset.repeat() + # PADDED BATCH the dataset dataset = dataset.padded_batch( batch_size=batch_size, @@ -230,13 +234,14 @@ def __init__(self, cache: bool = False, shuffle: bool = False, use_tf: bool = False, + indefinite: bool = False, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, **kwargs): super(ASRTFRecordDataset, self).__init__( stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, 
shuffle=shuffle, buffer_size=buffer_size, - drop_remainder=drop_remainder, use_tf=use_tf + drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite ) if not self.stage: raise ValueError("stage must be defined, either 'train', 'eval' or 'test'") self.tfrecords_dir = tfrecords_dir diff --git a/tensorflow_asr/datasets/base_dataset.py b/tensorflow_asr/datasets/base_dataset.py index 760778983b..fd23f4cf98 100644 --- a/tensorflow_asr/datasets/base_dataset.py +++ b/tensorflow_asr/datasets/base_dataset.py @@ -31,6 +31,7 @@ def __init__(self, cache: bool = False, shuffle: bool = False, buffer_size: int = BUFFER_SIZE, + indefinite: bool = False, drop_remainder: bool = True, use_tf: bool = False, stage: str = "train", @@ -44,6 +45,8 @@ def __init__(self, self.stage = stage # for defining tfrecords files self.use_tf = use_tf self.drop_remainder = drop_remainder # whether to drop remainder for multi gpu training + self.indefinite = indefinite # Whether to make dataset repeat indefinitely -> avoid the potential last partial batch + if self.indefinite: self.drop_remainder = False # No dropping remainder in indefinite dataset self.total_steps = None # for better training visualization @abc.abstractmethod diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py index 6f29bf95ee..be6647eac6 100644 --- a/tensorflow_asr/datasets/keras/asr_dataset.py +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -59,6 +59,9 @@ def process(self, dataset, batch_size): if self.shuffle: dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + if self.indefinite: + dataset = dataset.repeat() + # PADDED BATCH the dataset dataset = dataset.padded_batch( batch_size=batch_size, @@ -111,18 +114,21 @@ def __init__(self, cache: bool = False, shuffle: bool = False, use_tf: bool = False, + indefinite: bool = False, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, **kwargs): ASRTFRecordDataset.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, tfrecords_dir=tfrecords_dir, augmentations=augmentations, cache=cache, shuffle=shuffle, - tfrecords_shards=tfrecords_shards, drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + tfrecords_shards=tfrecords_shards, drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) ASRDatasetKeras.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, - drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) @tf.function @@ -151,18 +157,21 @@ def __init__(self, cache: bool = False, shuffle: bool = False, use_tf: bool = False, + indefinite: bool = False, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, **kwargs): ASRSliceDataset.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, - drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) ASRDatasetKeras.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, 
cache=cache, shuffle=shuffle, - drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) @tf.function diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 554b00063d..7421a61eaa 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -235,6 +235,9 @@ def shape(self) -> list: def update_length(self, length: int): self.max_length = max(self.max_length, length) + def reset_length(self): + self.max_length = 0 + @abc.abstractclassmethod def stft(self, signal): raise NotImplementedError() diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index f9e633087c..4d00ba7965 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -51,6 +51,9 @@ def prepand_shape(self) -> list: def update_length(self, length: int): self.max_length = max(self.max_length, length) + def reset_length(self): + self.max_length = 0 + def preprocess_text(self, text): text = unicodedata.normalize("NFC", text.lower()) return text.strip("\n") # remove trailing newline diff --git a/tensorflow_asr/models/__init__.py b/tensorflow_asr/models/__init__.py index 84955496d0..7f37b4ffb1 100644 --- a/tensorflow_asr/models/__init__.py +++ b/tensorflow_asr/models/__init__.py @@ -12,14 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import abc +import tempfile import tensorflow as tf +from ..utils.utils import is_cloud_path, is_hdf5_filepath + class Model(tf.keras.Model): def __init__(self, name, **kwargs): super(Model, self).__init__(name=name, **kwargs) + def save(self, filepath, overwrite=True, include_optimizer=True, save_format=None, + signatures=None, options=None, save_traces=True): + if is_cloud_path(filepath) and is_hdf5_filepath(filepath): + _, ext = os.path.splitext(filepath) + with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + super(Model, self).save( + tmp.name, overwrite=overwrite, include_optimizer=include_optimizer, + save_format=save_format, signatures=signatures, options=options, save_traces=save_traces + ) + tf.io.gfile.copy(tmp.name, filepath, overwrite=True) + else: + super(Model, self).save( + filepath, overwrite=overwrite, include_optimizer=include_optimizer, + save_format=save_format, signatures=signatures, options=options, save_traces=save_traces + ) + + def save_weights(self, filepath, overwrite=True, save_format=None, options=None): + if is_cloud_path(filepath) and is_hdf5_filepath(filepath): + _, ext = os.path.splitext(filepath) + with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + super(Model, self).save_weights(tmp.name, overwrite=overwrite, save_format=save_format, options=options) + tf.io.gfile.copy(tmp.name, filepath, overwrite=True) + else: + super(Model, self).save_weights(filepath, overwrite=overwrite, save_format=save_format, options=options) + + def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None): + if is_cloud_path(filepath) and is_hdf5_filepath(filepath): + _, ext = os.path.splitext(filepath) + with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + tf.io.gfile.copy(filepath, tmp.name, overwrite=True) + super(Model, self).load_weights(tmp.name, by_name=by_name, skip_mismatch=skip_mismatch, options=options) + else: + super(Model, 
self).load_weights(filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options) + @abc.abstractmethod def _build(self, *args, **kwargs): raise NotImplementedError() diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index 14562a0b1c..82029da23d 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -33,15 +33,12 @@ def __init__(self, config: RunningConfig): self.config = config # Writers self.writers = { - "train": tf.summary.create_file_writer( - os.path.join(self.config.outdir, "tensorboard", "train")), - "eval": tf.summary.create_file_writer( - os.path.join(self.config.outdir, "tensorboard", "eval")) + "train": tf.summary.create_file_writer(os.path.join(self.config.outdir, "tensorboard", "train")), + "eval": tf.summary.create_file_writer(os.path.join(self.config.outdir, "tensorboard", "eval")) } def add_writer(self, stage: str): - self.writers[stage] = tf.summary.create_file_writer( - os.path.join(self.config.outdir, "tensorboard", stage)) + self.writers[stage] = tf.summary.create_file_writer(os.path.join(self.config.outdir, "tensorboard", stage)) def _write_to_tensorboard(self, list_metrics: dict, @@ -149,7 +146,7 @@ def create_checkpoint_manager(self, max_to_keep=10, **kwargs): with self.strategy.scope(): self.ckpt = tf.train.Checkpoint(steps=self.steps, **kwargs) checkpoint_dir = os.path.join(self.config.outdir, "checkpoints") - if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) + if not tf.io.gfile.exists(checkpoint_dir): tf.io.gfile.makedirs(checkpoint_dir) self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, checkpoint_dir, max_to_keep=max_to_keep) def save_checkpoint(self): diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 1ade899227..df509089f6 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
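[Editor's note] The `save`, `save_weights`, and `load_weights` overrides added to `tensorflow_asr/models/__init__.py` all rely on the same trick: an HDF5 file cannot be written straight to a `gs://`-style URL, so the file is first written to a local temporary path and then copied across with `tf.io.gfile` (checkpoint directories in `base_runners.py` are likewise created through `tf.io.gfile` so they can live in a bucket). A minimal standalone sketch of the pattern — the wrapper function below is illustrative, not part of the patch:

```python
import os
import re
import tempfile
import tensorflow as tf

def save_weights_maybe_cloud(model: tf.keras.Model, filepath: str):
    """Illustrative helper mirroring the patch's Model.save_weights override."""
    is_cloud = bool(re.match(r"^[a-z]+://", filepath))            # e.g. gs://bucket/conformer.h5
    is_hdf5 = filepath.endswith((".h5", ".hdf5", ".keras"))
    if is_cloud and is_hdf5:
        _, ext = os.path.splitext(filepath)
        with tempfile.NamedTemporaryFile(suffix=ext) as tmp:
            model.save_weights(tmp.name)                          # write the HDF5 file locally first
            tf.io.gfile.copy(tmp.name, filepath, overwrite=True)  # then copy it to the bucket
    else:
        model.save_weights(filepath)  # local paths / TF checkpoint format need no special handling
```

Loading reverses the copy direction: the remote file is pulled into the temporary path with `tf.io.gfile.copy` and then read with the normal Keras call.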
+ +import re import os import sys import math @@ -47,11 +49,35 @@ def check_key_in_dict(dictionary, keys): raise ValueError("{} must be defined".format(key)) +def is_hdf5_filepath(filepath): + return (filepath.endswith('.h5') or filepath.endswith('.keras') or filepath.endswith('.hdf5')) + + +def is_cloud_path(path): + """ Check if the path is on cloud (which requires tf.io.gfile) + + Args: + path (str): Path to directory or file + + Returns: + bool: True if path is on cloud, False otherwise + """ + return bool(re.match(r"^[a-z]+://", path)) + + def preprocess_paths(paths: Union[List, str]): + """Expand the path to the root "/" + + Args: + paths (Union[List, str]): A path or list of paths + + Returns: + Union[List, str]: A processed path or list of paths, return None if it's not path + """ if isinstance(paths, list): - return [path if path.startswith('gs://') else os.path.abspath(os.path.expanduser(path)) for path in paths] + return [path if is_cloud_path(path) else os.path.abspath(os.path.expanduser(path)) for path in paths] elif isinstance(paths, str): - return paths if paths.startswith('gs://') else os.path.abspath(os.path.expanduser(paths)) + return paths if is_cloud_path(paths) else os.path.abspath(os.path.expanduser(paths)) else: return None From 3f908b1a1280837e7b7582245bfdb02c7fe6d40f Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 18:13:55 +0700 Subject: [PATCH 02/10] :writing_hand: update tpu scripts --- examples/conformer/train_tpu_keras_subword_conformer.py | 2 +- examples/contextnet/train_tpu_keras_subword_contextnet.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index 10a0fc12d6..06e9a976e6 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -32,7 +32,7 @@ parser.add_argument("--bs", type=int, default=None, help="Batch size per replica") -parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance") +parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance") parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab") diff --git a/examples/contextnet/train_tpu_keras_subword_contextnet.py b/examples/contextnet/train_tpu_keras_subword_contextnet.py index 5293982638..99fb69a838 100644 --- a/examples/contextnet/train_tpu_keras_subword_contextnet.py +++ b/examples/contextnet/train_tpu_keras_subword_contextnet.py @@ -32,7 +32,7 @@ parser.add_argument("--bs", type=int, default=None, help="Batch size per replica") -parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance") +parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance") parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. 
Leave None on Colab") @@ -117,6 +117,7 @@ contextnet.compile( optimizer=optimizer, + experimental_steps_per_execution=args.spx, global_batch_size=global_batch_size, blank=text_featurizer.blank ) From 91cd08a49cdc79a6bfb227813fd3911da7fb84f5 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 18:25:30 +0700 Subject: [PATCH 03/10] :writing_hand: update training keras scripts --- examples/conformer/train_keras_subword_conformer.py | 6 ++---- examples/conformer/train_tpu_keras_subword_conformer.py | 3 +-- examples/contextnet/train_keras_subword_contextnet.py | 6 ++---- examples/contextnet/train_tpu_keras_subword_contextnet.py | 3 +-- examples/deepspeech2/train_keras_ds2.py | 6 ++---- examples/jasper/train_keras_jasper.py | 6 +++--- .../train_keras_subword_streaming_transducer.py | 6 ++---- 7 files changed, 13 insertions(+), 23 deletions(-) diff --git a/examples/conformer/train_keras_subword_conformer.py b/examples/conformer/train_keras_subword_conformer.py index 986396c2a6..5f0652782f 100644 --- a/examples/conformer/train_keras_subword_conformer.py +++ b/examples/conformer/train_keras_subword_conformer.py @@ -88,8 +88,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) # Update metadata calculated from both train and eval datasets train_dataset.load_metadata(args.metadata_prefix) @@ -105,8 +104,7 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config), - indefinite=True + **vars(config.learning_config.train_dataset_config) ) with strategy.scope(): diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index 06e9a976e6..38b47a00b7 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -83,8 +83,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) if args.compute_lengths: diff --git a/examples/contextnet/train_keras_subword_contextnet.py b/examples/contextnet/train_keras_subword_contextnet.py index b641dc6b32..404b989a58 100644 --- a/examples/contextnet/train_keras_subword_contextnet.py +++ b/examples/contextnet/train_keras_subword_contextnet.py @@ -83,8 +83,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) # Update metadata calculated from both train and eval datasets train_dataset.load_metadata(args.metadata_prefix) @@ -100,8 +99,7 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) with strategy.scope(): diff --git a/examples/contextnet/train_tpu_keras_subword_contextnet.py b/examples/contextnet/train_tpu_keras_subword_contextnet.py index 99fb69a838..335ff031b1 100644 --- a/examples/contextnet/train_tpu_keras_subword_contextnet.py +++ 
b/examples/contextnet/train_tpu_keras_subword_contextnet.py @@ -83,8 +83,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) if args.compute_lengths: diff --git a/examples/deepspeech2/train_keras_ds2.py b/examples/deepspeech2/train_keras_ds2.py index 470e2297d1..7c1b4afa57 100644 --- a/examples/deepspeech2/train_keras_ds2.py +++ b/examples/deepspeech2/train_keras_ds2.py @@ -65,8 +65,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) # Update metadata calculated from both train and eval datasets train_dataset.load_metadata(args.metadata_prefix) @@ -82,8 +81,7 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) # Build DS2 model diff --git a/examples/jasper/train_keras_jasper.py b/examples/jasper/train_keras_jasper.py index bc0acab373..6b0558fc7b 100644 --- a/examples/jasper/train_keras_jasper.py +++ b/examples/jasper/train_keras_jasper.py @@ -67,8 +67,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) # Update metadata calculated from both train and eval datasets train_dataset.load_metadata(args.metadata_prefix) @@ -79,7 +78,8 @@ else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, diff --git a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py index e0d025f609..0b7de8fdaf 100644 --- a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py @@ -81,8 +81,7 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) # Update metadata calculated from both train and eval datasets train_dataset.load_metadata(args.metadata_prefix) @@ -98,8 +97,7 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config), - indefinite=True + **vars(config.learning_config.eval_dataset_config) ) with strategy.scope(): From 678db67903e758aa2d90f273c4700fe48fe505a3 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 18:46:36 +0700 Subject: [PATCH 04/10] :writing_hand: update dataset creation --- tensorflow_asr/datasets/asr_dataset.py | 5 ++--- tensorflow_asr/datasets/keras/asr_dataset.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git 
a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 46d5f3e777..d02479fba7 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -165,7 +165,6 @@ def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): return path, features, input_length, label, label_length, prediction, prediction_length - @tf.function def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): """ Returns: @@ -183,7 +182,8 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): dataset = dataset.cache() if self.shuffle: - dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + reshuffle = not self.indefinite + dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=reshuffle) if self.indefinite: dataset = dataset.repeat() @@ -296,7 +296,6 @@ def get_shard_path(shard_id): return True - @tf.function def parse(self, record: tf.Tensor): feature_description = { "path": tf.io.FixedLenFeature([], tf.string), diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py index be6647eac6..d0f7321cd1 100644 --- a/tensorflow_asr/datasets/keras/asr_dataset.py +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -25,7 +25,6 @@ class ASRDatasetKeras(ASRDataset): """ Keras Dataset for ASR using Generator """ - @tf.function def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): """ Returns: @@ -57,7 +56,8 @@ def process(self, dataset, batch_size): dataset = dataset.cache() if self.shuffle: - dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + reshuffle = not self.indefinite + dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=reshuffle) if self.indefinite: dataset = dataset.repeat() @@ -131,7 +131,6 @@ def __init__(self, indefinite=indefinite ) - @tf.function def parse(self, record: tf.Tensor): feature_description = { "path": tf.io.FixedLenFeature([], tf.string), @@ -174,7 +173,6 @@ def __init__(self, indefinite=indefinite ) - @tf.function def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): return ASRDatasetKeras.parse(self, path, audio, indices) From 2dda1e138473b945c8f868d514087d0cda4857fc Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 19:20:55 +0700 Subject: [PATCH 05/10] :writing_hand: update scripts --- examples/contextnet/train_tpu_keras_subword_contextnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/contextnet/train_tpu_keras_subword_contextnet.py b/examples/contextnet/train_tpu_keras_subword_contextnet.py index 335ff031b1..039c0e280c 100644 --- a/examples/contextnet/train_tpu_keras_subword_contextnet.py +++ b/examples/contextnet/train_tpu_keras_subword_contextnet.py @@ -78,8 +78,7 @@ train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config), - indefinite=True + **vars(config.learning_config.train_dataset_config) ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, From 93dcc3559088fabd5bf0796b5eca764489a9dbe3 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 19:31:05 +0700 Subject: [PATCH 06/10] :writing_hand: update scripts --- .../train_tpu_keras_subword_contextnet.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/contextnet/train_tpu_keras_subword_contextnet.py 
b/examples/contextnet/train_tpu_keras_subword_contextnet.py index 039c0e280c..b75e541d34 100644 --- a/examples/contextnet/train_tpu_keras_subword_contextnet.py +++ b/examples/contextnet/train_tpu_keras_subword_contextnet.py @@ -78,11 +78,13 @@ train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) if args.compute_lengths: @@ -93,10 +95,14 @@ train_dataset.load_metadata(args.metadata_prefix) eval_dataset.load_metadata(args.metadata_prefix) +batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size +global_batch_size = batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size - global_batch_size = batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size) @@ -120,17 +126,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] - contextnet.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +contextnet.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) From 3ae67291d0b99029b84dcf3258fffc55b06034bb Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 19:51:51 +0700 Subject: [PATCH 07/10] :writing_hand: update dataset --- tensorflow_asr/datasets/asr_dataset.py | 5 ++--- tensorflow_asr/datasets/keras/asr_dataset.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index d02479fba7..0bfb28377c 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -182,8 +182,7 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): dataset = dataset.cache() if self.shuffle: - 
reshuffle = not self.indefinite - dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=reshuffle) + dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) if self.indefinite: dataset = dataset.repeat() @@ -200,7 +199,7 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): tf.TensorShape(self.text_featurizer.prepand_shape), tf.TensorShape([]), ), - padding_values=("", 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0), + padding_values=(None, 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0), drop_remainder=self.drop_remainder ) diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py index d0f7321cd1..2a2b8702da 100644 --- a/tensorflow_asr/datasets/keras/asr_dataset.py +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -56,8 +56,7 @@ def process(self, dataset, batch_size): dataset = dataset.cache() if self.shuffle: - reshuffle = not self.indefinite - dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=reshuffle) + dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) if self.indefinite: dataset = dataset.repeat() @@ -80,7 +79,7 @@ def process(self, dataset, batch_size): ), padding_values=( { - "path": "", + "path": None, "input": 0., "input_length": 0, "prediction": self.text_featurizer.blank, From de2cb10e8109d1c4d20ec8d8ed4eff41e6046237 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 20:00:10 +0700 Subject: [PATCH 08/10] :writing_hand: remove path from dataset --- tensorflow_asr/datasets/keras/asr_dataset.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py index 2a2b8702da..448ad4011e 100644 --- a/tensorflow_asr/datasets/keras/asr_dataset.py +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -33,11 +33,10 @@ def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): if self.use_tf: data = self.tf_preprocess(path, audio, indices) else: data = self.preprocess(path, audio, indices) - path, features, input_length, label, label_length, prediction, prediction_length = data + _, features, input_length, label, label_length, prediction, prediction_length = data return ( { - "path": path, "input": features, "input_length": input_length, "prediction": prediction, @@ -66,7 +65,6 @@ def process(self, dataset, batch_size): batch_size=batch_size, padded_shapes=( { - "path": tf.TensorShape([]), "input": tf.TensorShape(self.speech_featurizer.shape), "input_length": tf.TensorShape([]), "prediction": tf.TensorShape(self.text_featurizer.prepand_shape), @@ -79,7 +77,6 @@ def process(self, dataset, batch_size): ), padding_values=( { - "path": None, "input": 0., "input_length": 0, "prediction": self.text_featurizer.blank, From bedbd4c40c2c5e46b8c6b89b19091eabd89e5297 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 20:38:51 +0700 Subject: [PATCH 09/10] :writing_hand: update dataset --- tensorflow_asr/datasets/base_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow_asr/datasets/base_dataset.py b/tensorflow_asr/datasets/base_dataset.py index fd23f4cf98..d62cc2de31 100644 --- a/tensorflow_asr/datasets/base_dataset.py +++ b/tensorflow_asr/datasets/base_dataset.py @@ -46,7 +46,6 @@ def __init__(self, self.use_tf = use_tf self.drop_remainder = drop_remainder # whether to drop remainder for multi gpu training self.indefinite = indefinite # Whether to make dataset 
repeat indefinitely -> avoid the potential last partial batch - if self.indefinite: self.drop_remainder = False # No dropping remainder in indefinite dataset self.total_steps = None # for better training visualization @abc.abstractmethod From bddcf81397a7d2a14bc44495e5166cd638acacec Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 21 Feb 2021 21:24:41 +0700 Subject: [PATCH 10/10] :writing_hand: update dataset and scripts --- .../train_keras_subword_conformer.py | 36 +++++++++-------- .../train_tpu_keras_subword_conformer.py | 36 +++++++++-------- .../train_keras_subword_contextnet.py | 39 ++++++++++--------- examples/deepspeech2/train_keras_ds2.py | 37 +++++++++--------- examples/jasper/train_keras_jasper.py | 37 +++++++++--------- ...rain_keras_subword_streaming_transducer.py | 36 +++++++++-------- 6 files changed, 116 insertions(+), 105 deletions(-) diff --git a/examples/conformer/train_keras_subword_conformer.py b/examples/conformer/train_keras_subword_conformer.py index 5f0652782f..7f2219cff2 100644 --- a/examples/conformer/train_keras_subword_conformer.py +++ b/examples/conformer/train_keras_subword_conformer.py @@ -104,12 +104,17 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) @@ -133,17 +138,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - conformer.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +conformer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index 38b47a00b7..a6126c6933 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -83,7 +83,8 @@ ) eval_dataset = ASRTFRecordDatasetKeras( 
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) if args.compute_lengths: @@ -94,10 +95,14 @@ train_dataset.load_metadata(args.metadata_prefix) eval_dataset.load_metadata(args.metadata_prefix) +batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size +global_batch_size = batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size - global_batch_size = batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size) @@ -121,17 +126,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - conformer.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +conformer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/contextnet/train_keras_subword_contextnet.py b/examples/contextnet/train_keras_subword_contextnet.py index 404b989a58..4046cfb858 100644 --- a/examples/contextnet/train_keras_subword_contextnet.py +++ b/examples/contextnet/train_keras_subword_contextnet.py @@ -83,7 +83,8 @@ ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) # Update metadata calculated from both train and eval datasets train_dataset.load_metadata(args.metadata_prefix) @@ -99,12 +100,17 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - global_batch_size = 
config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape) @@ -128,17 +134,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - contextnet.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +contextnet.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/deepspeech2/train_keras_ds2.py b/examples/deepspeech2/train_keras_ds2.py index 7c1b4afa57..49e0b83d95 100644 --- a/examples/deepspeech2/train_keras_ds2.py +++ b/examples/deepspeech2/train_keras_ds2.py @@ -81,14 +81,18 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + # Build DS2 model with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync - ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes) ds2_model._build(speech_featurizer.shape) ds2_model.summary(line_length=120) @@ -100,17 +104,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - ds2_model.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + 
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +ds2_model.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/jasper/train_keras_jasper.py b/examples/jasper/train_keras_jasper.py index 6b0558fc7b..444ca1314a 100644 --- a/examples/jasper/train_keras_jasper.py +++ b/examples/jasper/train_keras_jasper.py @@ -83,13 +83,17 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) -with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) +with strategy.scope(): jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) jasper._build(speech_featurizer.shape) jasper.summary(line_length=120) @@ -101,17 +105,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - jasper.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +jasper.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py index 0b7de8fdaf..c9254a4fd0 100644 --- a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py @@ -97,12 +97,17 @@ ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - global_batch_size = 
config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model streaming_transducer = StreamingTransducer( **config.model_config, @@ -120,17 +125,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - streaming_transducer.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +streaming_transducer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +)
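
The dataset changes in patches 07–09 all serve one pattern: shuffle with `reshuffle_each_iteration=True`, `repeat()` the dataset so it becomes indefinite, pad each batch, and let Keras bound the epoch with `steps_per_epoch` instead of relying on the dataset ending (which is what produces the last partial batch). Below is a minimal, self-contained sketch of that pattern in plain `tf.data`; the toy generator, buffer size, and step count are made-up placeholders, not values or APIs from this repository.

```python
import tensorflow as tf

# Placeholder values standing in for featurizer/metadata-derived numbers.
BUFFER_SIZE = 100
BATCH_SIZE = 4
STEPS_PER_EPOCH = 25  # e.g. num_examples // global_batch_size, precomputed from metadata


def toy_examples():
    # Variable-length "features" with scalar "labels", standing in for parsed ASR examples.
    for i in range(100):
        length = (i % 7) + 1
        yield tf.ones([length], dtype=tf.float32), tf.constant(i % 10, dtype=tf.int32)


dataset = tf.data.Dataset.from_generator(
    toy_examples,
    output_signature=(
        tf.TensorSpec(shape=[None], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int32),
    ),
)

# Reshuffle on every pass, then repeat() so the dataset never signals end-of-sequence;
# this is what avoids the potential last partial batch during training.
dataset = dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
dataset = dataset.repeat()

# Pad the variable-length component inside each batch; the scalar label needs no padding.
dataset = dataset.padded_batch(
    BATCH_SIZE,
    padded_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
    padding_values=(0.0, 0),
)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# With an indefinite dataset, the epoch length must be given explicitly:
# model.fit(dataset, epochs=20, steps_per_epoch=STEPS_PER_EPOCH)
```

Because the dataset repeats forever, `drop_remainder` becomes irrelevant for the training split, which is why the base dataset forces `drop_remainder = False` whenever `indefinite` is set.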
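
Patch 10 moves the global batch-size computation, data-loader creation, callbacks, and the `fit()` call out of `strategy.scope()`, keeping only model construction and compilation inside it. A rough sketch of that layout with a stand-in Keras model follows; the model, optimizer, data, and log directory here are illustrative, not the Conformer/ContextNet setup from the example scripts.

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # the TPU scripts use a TPUStrategy instead

# Scale the per-replica batch size before building the input pipeline;
# neither step needs to happen under the strategy scope.
global_batch_size = 8 * strategy.num_replicas_in_sync

features = np.random.rand(256, 20).astype("float32")
labels = np.random.randint(0, 10, size=(256,))
train_loader = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .shuffle(256, reshuffle_each_iteration=True)
    .repeat()
    .batch(global_batch_size)
)

with strategy.scope():
    # Only variable creation and compilation must run inside the scope.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(20,)),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )

# Callbacks and fit() can stay outside the scope; steps_per_epoch bounds each
# epoch because the repeated dataset is otherwise infinite.
callbacks = [tf.keras.callbacks.TensorBoard(log_dir="/tmp/tb_example")]
model.fit(
    train_loader,
    epochs=2,
    steps_per_epoch=256 // global_batch_size,
    callbacks=callbacks,
)
```

This mirrors the structure the training scripts converge on after this patch: batch size and loaders first, model build and compile under the strategy, then checkpoint/backup/TensorBoard callbacks and `fit()` at module level with `steps_per_epoch`/`validation_steps` taken from the datasets' precomputed `total_steps`.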