diff --git a/README.md b/README.md index d73e6a0f21..1ff59a9809 100755 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as - [TFLite Convertion](#tflite-convertion) - [Features Extraction](#features-extraction) - [Augmentations](#augmentations) -- [Training & Testing](#training--testing) +- [Training & Testing Tutorial](#training--testing-tutorial) - [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models) - [English](#english) - [Vietnamese](#vietnamese) @@ -164,34 +164,17 @@ See [features_extraction](./tensorflow_asr/featurizers/README.md) See [augmentations](./tensorflow_asr/augmentations/README.md) -## Training & Testing - -**Example YAML Config Structure** - -```yaml -speech_config: ... -model_config: ... -decoder_config: ... -learning_config: - train_dataset_config: - augmentation_config: ... - data_paths: ... - tfrecords_dir: ... - eval_dataset_config: - augmentation_config: ... - data_paths: ... - tfrecords_dir: ... - test_dataset_config: - augmentation_config: ... - data_paths: ... - tfrecords_dir: ... - optimizer_config: ... - running_config: - batch_size: 8 - num_epochs: 20 - outdir: ... - log_interval_steps: 500 -``` +## Training & Testing Tutorial + +1. Define a config YAML file; see the `config.yml` files in the [example folder](./examples) for reference (you can copy and modify values such as parameters, paths, etc. to match your local machine configuration) +2. Download your corpus (a.k.a. datasets) and create a script that generates `transcripts.tsv` files from it (this is the common format used in this project, since every dataset ships in a different layout); a minimal sketch of such a script is shown below. For more details, see [datasets](./tensorflow_asr/datasets/README.md). **Note:** Make sure your data contains only characters of your language, for example, English has `a` to `z` and `'`. **Do not use `cache` if your dataset does not fit in RAM**. +3. [Optional] Generate TFRecords to use `tf.data.TFRecordDataset` for better performance with the script [create_tfrecords.py](./scripts/create_tfrecords.py) +4. Create a vocabulary file (characters or subwords/wordpieces), either by defining `language.characters` or by using the scripts [generate_vocab_subwords.py](./scripts/generate_vocab_subwords.py) or [generate_vocab_sentencepiece.py](./scripts/generate_vocab_sentencepiece.py). There are predefined ones in [vocabularies](./vocabularies) +5. [Optional] Generate a metadata file for your dataset with the script [generate_metadata.py](./scripts/generate_metadata.py). This file contains the maximum lengths calculated with your `config.yml` and the total number of elements in each dataset, used for static-shape training and precalculated steps per epoch. +6. For training, see the `train_*.py` files in the [example folder](./examples) for the available options +7. For testing, see the `test_*.py` files in the [example folder](./examples) for the available options. **Note:** Testing is currently not supported on TPUs. It prints nothing but the progress bar to the console; the predicted transcripts are stored in the file `output_name.tsv` under the `outdir` defined in the config YAML file. After testing is done, the metrics (WER and CER) are calculated from `output_name.tsv`. **If you reuse the same `output_name`, testing resumes from the last tested batch, so if testing has already finished only the metrics are recalculated; to run a fresh test, define a new `output_name` whose `.tsv` file does not exist yet or contains only the header.** + +**Recommendation**: For better performance, use the **Keras built-in training functions** as in the `train_keras_*.py` files and/or TFRecords. Keras built-in training uses an **infinite dataset**, which avoids a potential partial last batch. See [examples](./examples/) for some predefined ASR models and results
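The `transcripts.tsv` generator mentioned in step 2 could look roughly like the sketch below. It assumes the tab-separated `PATH`, `DURATION`, `TRANSCRIPT` layout described in [datasets](./tensorflow_asr/datasets/README.md), a corpus laid out as `.wav` files with matching `.txt` transcript files, and that `librosa` is available for measuring durations; none of these names or paths come from this repository, so adapt them to your own corpus.

```python
import glob
import os

import librosa  # assumed available; only used here to measure audio durations


def generate_transcripts_tsv(corpus_dir: str, output_path: str) -> None:
    """Write one tab-separated PATH / DURATION / TRANSCRIPT row per utterance."""
    with open(output_path, "w", encoding="utf-8") as out:
        out.write("PATH\tDURATION\tTRANSCRIPT\n")
        for wav_path in sorted(glob.glob(os.path.join(corpus_dir, "**", "*.wav"), recursive=True)):
            txt_path = os.path.splitext(wav_path)[0] + ".txt"
            if not os.path.exists(txt_path):
                continue  # skip utterances that have no transcript
            with open(txt_path, "r", encoding="utf-8") as f:
                # normalize to lowercase; keep only characters of your language
                transcript = f.read().strip().lower()
            duration = librosa.get_duration(filename=wav_path)
            out.write(f"{wav_path}\t{duration:.2f}\t{transcript}\n")


if __name__ == "__main__":
    generate_transcripts_tsv("/path/to/my_corpus", "/path/to/transcripts.tsv")
```

Run it once per split (train/eval/test) and point the `data_paths` entries in your `config.yml` at the generated files.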
diff --git a/examples/conformer/train_keras_subword_conformer.py b/examples/conformer/train_keras_subword_conformer.py index bbba1cadd4..7f2219cff2 100644 --- a/examples/conformer/train_keras_subword_conformer.py +++ b/examples/conformer/train_keras_subword_conformer.py @@ -38,6 +38,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") @@ -79,25 +83,38 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) @@ -114,19 +131,21 @@ epsilon=config.learning_config.optimizer_config["epsilon"] ) - conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) - - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - conformer.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + conformer.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank ) + +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +conformer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +)
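These Keras example scripts now build their datasets with `indefinite=True`, which (per `ASRDataset.process` below) simply inserts a `repeat()` into the pipeline, so `fit(...)` must be told how long an epoch is via `steps_per_epoch` and `validation_steps`. Here is a standalone toy sketch of that interaction, using plain `tf.data` and a dummy model rather than this project's classes; the shapes and numbers are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for an ASR dataset: 10 (feature, label) pairs, batch size 4.
features = np.random.rand(10, 3).astype("float32")
labels = np.random.rand(10, 1).astype("float32")
batch_size = 4
steps_per_epoch = len(features) // batch_size  # 2 full batches, like dataset.total_steps

# indefinite=True boils down to inserting repeat() before batching, so the
# pipeline never ends and no partial last batch is ever produced.
indefinite = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .repeat()
    .batch(batch_size, drop_remainder=True)
)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer="sgd", loss="mse")

# Because the dataset is infinite, Keras must be told where each epoch ends;
# without steps_per_epoch the first epoch would run forever.
model.fit(indefinite, epochs=2, steps_per_epoch=steps_per_epoch)
```

The same reasoning applies to the evaluation dataset, which is why `validation_steps=eval_dataset.total_steps` is added to the `fit` calls in this PR.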
diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index 8590f269ed..a6126c6933 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -32,7 +32,7 @@ parser.add_argument("--bs", type=int, default=None, help="Batch size per replica") -parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance") +parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance") parser.add_argument("--tpu_address", type=str, default=None, help="TPU address.
Leave None on Colab") @@ -78,11 +78,13 @@ train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) if args.compute_lengths: @@ -93,10 +95,14 @@ train_dataset.load_metadata(args.metadata_prefix) eval_dataset.load_metadata(args.metadata_prefix) +batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size +global_batch_size = batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size - global_batch_size = batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size) @@ -120,17 +126,14 @@ blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] - conformer.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps - ) +conformer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/conformer/train_tpu_subword_conformer.py b/examples/conformer/train_tpu_subword_conformer.py deleted file mode 100644 index d9a55be207..0000000000 --- a/examples/conformer/train_tpu_subword_conformer.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2021 M. Yusuf Sarıgöz (@monatis) and Huy Le Nguyen (@usimarit) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import math -import argparse -from tensorflow_asr.utils import setup_environment, setup_tpu - -setup_environment() -import tensorflow as tf - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") - -tf.keras.backend.clear_session() - -parser = argparse.ArgumentParser(prog="Conformer Training") - -parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") - -parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") - -parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") - -parser.add_argument("--bs", type=int, default=None, help="Common training and evaluation batch size per TPU core") - -parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab") - -parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") - -parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths") - -parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") - -parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") - -parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") - -args = parser.parse_args() - -tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset -from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer -from tensorflow_asr.runners.transducer_runners import TransducerTrainer -from tensorflow_asr.models.conformer import Conformer -from tensorflow_asr.optimizers.schedules import TransformerSchedule - -config = Config(args.config) -speech_featurizer = TFSpeechFeaturizer(config.speech_config) - -if args.sentence_piece: - print("Loading SentencePiece model ...") - text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) -elif args.subwords and os.path.exists(args.subwords): - print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) -else: - print("Generating subwords ...") - text_featurizer = SubwordFeaturizer.build_from_corpus( - config.decoder_config, - corpus_files=args.subwords_corpus - ) - text_featurizer.save_to_file(args.subwords) - -train_dataset = ASRTFRecordDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) -) - -eval_dataset = ASRTFRecordDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) -) - -if args.compute_lengths: - train_dataset.update_lengths(args.metadata_prefix) - eval_dataset.update_lengths(args.metadata_prefix) - -# Update metadata calculated from both train and eval datasets -train_dataset.load_metadata(args.metadata_prefix) -eval_dataset.load_metadata(args.metadata_prefix) - -strategy = setup_tpu(args.tpu_address) - -conformer_trainer = TransducerTrainer( - config=config.learning_config.running_config, - text_featurizer=text_featurizer, strategy=strategy -) - 
-with conformer_trainer.strategy.scope(): - # build model - conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) - conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, - batch_size=args.bs if args.bs is not None else config.learning_config.running_config.batch_size) - conformer.summary(line_length=120) - - optimizer = tf.keras.optimizers.Adam( - TransformerSchedule( - d_model=conformer.dmodel, - warmup_steps=config.learning_config.optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(conformer.dmodel)) - ), - beta_1=config.learning_config.optimizer_config["beta1"], - beta_2=config.learning_config.optimizer_config["beta2"], - epsilon=config.learning_config.optimizer_config["epsilon"] - ) - -conformer_trainer.compile(model=conformer, optimizer=optimizer, max_to_keep=args.max_ckpts) - -conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.bs, eval_bs=args.bs) diff --git a/examples/contextnet/train_keras_subword_contextnet.py b/examples/contextnet/train_keras_subword_contextnet.py index 8ff26f8fb0..4046cfb858 100644 --- a/examples/contextnet/train_keras_subword_contextnet.py +++ b/examples/contextnet/train_keras_subword_contextnet.py @@ -36,6 +36,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") @@ -74,25 +78,39 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) 
contextnet._build(speech_featurizer.shape) @@ -109,19 +127,21 @@ epsilon=config.learning_config.optimizer_config["epsilon"] ) - contextnet.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) - - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - contextnet.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + contextnet.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank ) + +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +contextnet.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/contextnet/train_tpu_keras_subword_contextnet.py b/examples/contextnet/train_tpu_keras_subword_contextnet.py index 92d46a54d2..b75e541d34 100644 --- a/examples/contextnet/train_tpu_keras_subword_contextnet.py +++ b/examples/contextnet/train_tpu_keras_subword_contextnet.py @@ -32,7 +32,7 @@ parser.add_argument("--bs", type=int, default=None, help="Batch size per replica") -parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance") +parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance") parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. 
Leave None on Colab") @@ -78,11 +78,13 @@ train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) if args.compute_lengths: @@ -93,10 +95,14 @@ train_dataset.load_metadata(args.metadata_prefix) eval_dataset.load_metadata(args.metadata_prefix) +batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size +global_batch_size = batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size - global_batch_size = batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size) @@ -115,20 +121,19 @@ contextnet.compile( optimizer=optimizer, + experimental_steps_per_execution=args.spx, global_batch_size=global_batch_size, blank=text_featurizer.blank ) - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] - contextnet.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - ) +contextnet.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/deepspeech2/train_keras_ds2.py b/examples/deepspeech2/train_keras_ds2.py index b9501040de..49e0b83d95 100644 --- a/examples/deepspeech2/train_keras_ds2.py +++ b/examples/deepspeech2/train_keras_ds2.py @@ -33,6 +33,8 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replicas") +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords dataset") parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -58,45 +60,58 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - 
**vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + # Build DS2 model with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync - ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes) ds2_model._build(speech_featurizer.shape) ds2_model.summary(line_length=120) - ds2_model.compile(optimizer=config.learning_config.optimizer_config, - global_batch_size=global_batch_size, blank=text_featurizer.blank) - - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - ds2_model.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + ds2_model.compile( + optimizer=config.learning_config.optimizer_config, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank ) + +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +ds2_model.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/examples/jasper/train_keras_jasper.py b/examples/jasper/train_keras_jasper.py index 9cfa48669c..444ca1314a 100644 --- a/examples/jasper/train_keras_jasper.py +++ b/examples/jasper/train_keras_jasper.py @@ -33,6 +33,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replicas") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + 
+parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords dataset") parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -58,44 +62,57 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) -with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) +with strategy.scope(): jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) jasper._build(speech_featurizer.shape) jasper.summary(line_length=120) - jasper.compile(optimizer=config.learning_config.optimizer_config, - global_batch_size=global_batch_size, blank=text_featurizer.blank) - - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - jasper.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + jasper.compile( + optimizer=config.learning_config.optimizer_config, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank ) + +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +jasper.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, 
validation_steps=eval_dataset.total_steps +) diff --git a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py index 573edb9f13..c9254a4fd0 100644 --- a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py @@ -35,6 +35,10 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") @@ -72,25 +76,38 @@ if args.tfrecords: train_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRTFRecordDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) + # Update metadata calculated from both train and eval datasets + train_dataset.load_metadata(args.metadata_prefix) + eval_dataset.load_metadata(args.metadata_prefix) + # Use dynamic length + speech_featurizer.reset_length() + text_featurizer.reset_length() else: train_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.train_dataset_config) + **vars(config.learning_config.train_dataset_config), + indefinite=True ) eval_dataset = ASRSliceDatasetKeras( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - **vars(config.learning_config.eval_dataset_config) + **vars(config.learning_config.eval_dataset_config), + indefinite=True ) +global_batch_size = config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_batch_size) + with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync # build model streaming_transducer = StreamingTransducer( **config.model_config, @@ -101,19 +118,21 @@ optimizer = tf.keras.optimizers.get(config.learning_config.optimizer_config) - streaming_transducer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) - - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) - ] - - streaming_transducer.fit( - train_data_loader, epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps + streaming_transducer.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + 
global_batch_size=global_batch_size, + blank=text_featurizer.blank ) + +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +streaming_transducer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps +) diff --git a/setup.py b/setup.py index 2d66d12b42..7d6161e567 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.7.6", + version="0.7.7", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index e7b0139e9f..0bfb28377c 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -36,6 +36,7 @@ def __init__(self, augmentations: Augmentation = Augmentation(None), cache: bool = False, shuffle: bool = False, + indefinite: bool = False, drop_remainder: bool = True, use_tf: bool = False, buffer_size: int = BUFFER_SIZE, @@ -43,7 +44,7 @@ def __init__(self, super(ASRDataset, self).__init__( data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, stage=stage, buffer_size=buffer_size, - drop_remainder=drop_remainder, use_tf=use_tf + drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite ) self.speech_featurizer = speech_featurizer self.text_featurizer = text_featurizer @@ -164,7 +165,6 @@ def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): return path, features, input_length, label, label_length, prediction, prediction_length - @tf.function def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): """ Returns: @@ -184,6 +184,9 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): if self.shuffle: dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + if self.indefinite: + dataset = dataset.repeat() + # PADDED BATCH the dataset dataset = dataset.padded_batch( batch_size=batch_size, @@ -196,7 +199,7 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): tf.TensorShape(self.text_featurizer.prepand_shape), tf.TensorShape([]), ), - padding_values=("", 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0), + padding_values=(None, 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0), drop_remainder=self.drop_remainder ) @@ -230,13 +233,14 @@ def __init__(self, cache: bool = False, shuffle: bool = False, use_tf: bool = False, + indefinite: bool = False, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, **kwargs): super(ASRTFRecordDataset, self).__init__( stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, buffer_size=buffer_size, - drop_remainder=drop_remainder, use_tf=use_tf + drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite ) if not self.stage: raise ValueError("stage must be defined, either 'train', 'eval' or 'test'") self.tfrecords_dir = tfrecords_dir @@ -291,7 +295,6 @@ def get_shard_path(shard_id): return True - 
@tf.function def parse(self, record: tf.Tensor): feature_description = { "path": tf.io.FixedLenFeature([], tf.string), diff --git a/tensorflow_asr/datasets/base_dataset.py b/tensorflow_asr/datasets/base_dataset.py index 760778983b..d62cc2de31 100644 --- a/tensorflow_asr/datasets/base_dataset.py +++ b/tensorflow_asr/datasets/base_dataset.py @@ -31,6 +31,7 @@ def __init__(self, cache: bool = False, shuffle: bool = False, buffer_size: int = BUFFER_SIZE, + indefinite: bool = False, drop_remainder: bool = True, use_tf: bool = False, stage: str = "train", @@ -44,6 +45,7 @@ def __init__(self, self.stage = stage # for defining tfrecords files self.use_tf = use_tf self.drop_remainder = drop_remainder # whether to drop remainder for multi gpu training + self.indefinite = indefinite # Whether to make dataset repeat indefinitely -> avoid the potential last partial batch self.total_steps = None # for better training visualization @abc.abstractmethod diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py index 6f29bf95ee..448ad4011e 100644 --- a/tensorflow_asr/datasets/keras/asr_dataset.py +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -25,7 +25,6 @@ class ASRDatasetKeras(ASRDataset): """ Keras Dataset for ASR using Generator """ - @tf.function def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): """ Returns: @@ -34,11 +33,10 @@ def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): if self.use_tf: data = self.tf_preprocess(path, audio, indices) else: data = self.preprocess(path, audio, indices) - path, features, input_length, label, label_length, prediction, prediction_length = data + _, features, input_length, label, label_length, prediction, prediction_length = data return ( { - "path": path, "input": features, "input_length": input_length, "prediction": prediction, @@ -59,12 +57,14 @@ def process(self, dataset, batch_size): if self.shuffle: dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + if self.indefinite: + dataset = dataset.repeat() + # PADDED BATCH the dataset dataset = dataset.padded_batch( batch_size=batch_size, padded_shapes=( { - "path": tf.TensorShape([]), "input": tf.TensorShape(self.speech_featurizer.shape), "input_length": tf.TensorShape([]), "prediction": tf.TensorShape(self.text_featurizer.prepand_shape), @@ -77,7 +77,6 @@ def process(self, dataset, batch_size): ), padding_values=( { - "path": "", "input": 0., "input_length": 0, "prediction": self.text_featurizer.blank, @@ -111,21 +110,23 @@ def __init__(self, cache: bool = False, shuffle: bool = False, use_tf: bool = False, + indefinite: bool = False, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, **kwargs): ASRTFRecordDataset.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, tfrecords_dir=tfrecords_dir, augmentations=augmentations, cache=cache, shuffle=shuffle, - tfrecords_shards=tfrecords_shards, drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + tfrecords_shards=tfrecords_shards, drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) ASRDatasetKeras.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, - drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + drop_remainder=drop_remainder, buffer_size=buffer_size, 
use_tf=use_tf, + indefinite=indefinite ) - @tf.function def parse(self, record: tf.Tensor): feature_description = { "path": tf.io.FixedLenFeature([], tf.string), @@ -151,21 +152,23 @@ def __init__(self, cache: bool = False, shuffle: bool = False, use_tf: bool = False, + indefinite: bool = False, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, **kwargs): ASRSliceDataset.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, - drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) ASRDatasetKeras.__init__( self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, - drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf + drop_remainder=drop_remainder, buffer_size=buffer_size, use_tf=use_tf, + indefinite=indefinite ) - @tf.function def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): return ASRDatasetKeras.parse(self, path, audio, indices) diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 554b00063d..7421a61eaa 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -235,6 +235,9 @@ def shape(self) -> list: def update_length(self, length: int): self.max_length = max(self.max_length, length) + def reset_length(self): + self.max_length = 0 + @abc.abstractclassmethod def stft(self, signal): raise NotImplementedError() diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index f9e633087c..4d00ba7965 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -51,6 +51,9 @@ def prepand_shape(self) -> list: def update_length(self, length: int): self.max_length = max(self.max_length, length) + def reset_length(self): + self.max_length = 0 + def preprocess_text(self, text): text = unicodedata.normalize("NFC", text.lower()) return text.strip("\n") # remove trailing newline diff --git a/tensorflow_asr/models/__init__.py b/tensorflow_asr/models/__init__.py index 84955496d0..7f37b4ffb1 100644 --- a/tensorflow_asr/models/__init__.py +++ b/tensorflow_asr/models/__init__.py @@ -12,14 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import abc +import tempfile import tensorflow as tf +from ..utils.utils import is_cloud_path, is_hdf5_filepath + class Model(tf.keras.Model): def __init__(self, name, **kwargs): super(Model, self).__init__(name=name, **kwargs) + def save(self, filepath, overwrite=True, include_optimizer=True, save_format=None, + signatures=None, options=None, save_traces=True): + if is_cloud_path(filepath) and is_hdf5_filepath(filepath): + _, ext = os.path.splitext(filepath) + with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + super(Model, self).save( + tmp.name, overwrite=overwrite, include_optimizer=include_optimizer, + save_format=save_format, signatures=signatures, options=options, save_traces=save_traces + ) + tf.io.gfile.copy(tmp.name, filepath, overwrite=True) + else: + super(Model, self).save( + filepath, overwrite=overwrite, include_optimizer=include_optimizer, + save_format=save_format, signatures=signatures, options=options, save_traces=save_traces + ) + + def save_weights(self, filepath, overwrite=True, save_format=None, options=None): + if is_cloud_path(filepath) and is_hdf5_filepath(filepath): + _, ext = os.path.splitext(filepath) + with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + super(Model, self).save_weights(tmp.name, overwrite=overwrite, save_format=save_format, options=options) + tf.io.gfile.copy(tmp.name, filepath, overwrite=True) + else: + super(Model, self).save_weights(filepath, overwrite=overwrite, save_format=save_format, options=options) + + def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None): + if is_cloud_path(filepath) and is_hdf5_filepath(filepath): + _, ext = os.path.splitext(filepath) + with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + tf.io.gfile.copy(filepath, tmp.name, overwrite=True) + super(Model, self).load_weights(tmp.name, by_name=by_name, skip_mismatch=skip_mismatch, options=options) + else: + super(Model, self).load_weights(filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options) + @abc.abstractmethod def _build(self, *args, **kwargs): raise NotImplementedError() diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index 14562a0b1c..82029da23d 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -33,15 +33,12 @@ def __init__(self, config: RunningConfig): self.config = config # Writers self.writers = { - "train": tf.summary.create_file_writer( - os.path.join(self.config.outdir, "tensorboard", "train")), - "eval": tf.summary.create_file_writer( - os.path.join(self.config.outdir, "tensorboard", "eval")) + "train": tf.summary.create_file_writer(os.path.join(self.config.outdir, "tensorboard", "train")), + "eval": tf.summary.create_file_writer(os.path.join(self.config.outdir, "tensorboard", "eval")) } def add_writer(self, stage: str): - self.writers[stage] = tf.summary.create_file_writer( - os.path.join(self.config.outdir, "tensorboard", stage)) + self.writers[stage] = tf.summary.create_file_writer(os.path.join(self.config.outdir, "tensorboard", stage)) def _write_to_tensorboard(self, list_metrics: dict, @@ -149,7 +146,7 @@ def create_checkpoint_manager(self, max_to_keep=10, **kwargs): with self.strategy.scope(): self.ckpt = tf.train.Checkpoint(steps=self.steps, **kwargs) checkpoint_dir = os.path.join(self.config.outdir, "checkpoints") - if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) + if not tf.io.gfile.exists(checkpoint_dir): tf.io.gfile.makedirs(checkpoint_dir) self.ckpt_manager = 
tf.train.CheckpointManager(self.ckpt, checkpoint_dir, max_to_keep=max_to_keep) def save_checkpoint(self): diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 1ade899227..df509089f6 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import re import os import sys import math @@ -47,11 +49,35 @@ def check_key_in_dict(dictionary, keys): raise ValueError("{} must be defined".format(key)) +def is_hdf5_filepath(filepath): + return (filepath.endswith('.h5') or filepath.endswith('.keras') or filepath.endswith('.hdf5')) + + +def is_cloud_path(path): + """ Check if the path is on cloud (which requires tf.io.gfile) + + Args: + path (str): Path to directory or file + + Returns: + bool: True if path is on cloud, False otherwise + """ + return bool(re.match(r"^[a-z]+://", path)) + + def preprocess_paths(paths: Union[List, str]): + """Expand local paths to absolute paths, leaving cloud paths unchanged + + Args: + paths (Union[List, str]): A path or list of paths + + Returns: + Union[List, str]: A processed path or list of paths, or None if the input is not a path + """ if isinstance(paths, list): - return [path if path.startswith('gs://') else os.path.abspath(os.path.expanduser(path)) for path in paths] + return [path if is_cloud_path(path) else os.path.abspath(os.path.expanduser(path)) for path in paths] elif isinstance(paths, str): - return paths if paths.startswith('gs://') else os.path.abspath(os.path.expanduser(paths)) + return paths if is_cloud_path(paths) else os.path.abspath(os.path.expanduser(paths)) else: return None
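For orientation, here is a small, hypothetical usage sketch of the path helpers added above; the example paths are invented, and the expected results follow directly from the implementations in this diff.

```python
from tensorflow_asr.utils.utils import is_cloud_path, is_hdf5_filepath, preprocess_paths

# Any path with a lowercase URI scheme (gs://, hdfs://, ...) now counts as a cloud
# path, instead of only gs:// as before.
assert is_cloud_path("gs://my-bucket/models/conformer.h5")
assert is_cloud_path("hdfs://namenode/data/transcripts.tsv")
assert not is_cloud_path("/home/user/models/conformer.h5")

# Combined with a cloud path, HDF5-style filenames are the ones that Model.save,
# save_weights and load_weights route through a local temporary file plus tf.io.gfile.copy.
assert is_hdf5_filepath("gs://my-bucket/models/conformer.h5")
assert not is_hdf5_filepath("gs://my-bucket/models/saved_model")

# Cloud paths pass through untouched; local paths are expanded to absolute paths.
print(preprocess_paths(["gs://my-bucket/data/train.tsv", "~/data/eval.tsv"]))
```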