diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 79507f6ee4..2d9f842545 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,3 +9,12 @@ repos: stages: [pre-commit] fail_fast: true verbose: true + - id: pylint-check + name: pylint-check + entry: pylint --rcfile=.pylintrc -rn -sn + language: system + types: [python] + stages: [pre-commit] + fail_fast: true + require_serial: true + verbose: true diff --git a/.pylintrc b/.pylintrc index ca5736a5f2..c17e50e122 100644 --- a/.pylintrc +++ b/.pylintrc @@ -114,7 +114,11 @@ disable=too-few-public-methods, consider-using-enumerate, too-many-statements, assignment-from-none, - eval-used + eval-used, + duplicate-code, + redefined-outer-name, + consider-using-f-string, + fixme, # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/examples/inferences/main.py b/examples/inferences/main.py index 78880551fb..393ab20fd5 100644 --- a/examples/inferences/main.py +++ b/examples/inferences/main.py @@ -15,8 +15,10 @@ import os import tensorflow as tf +import keras from tensorflow_asr import schemas, tokenizers +from tensorflow_asr.models import base_model from tensorflow_asr.configs import Config from tensorflow_asr.utils import cli_util, data_util, env_util, file_util @@ -35,7 +37,7 @@ def main( config = Config(config_path, training=False, repodir=repodir) tokenizer = tokenizers.get(config) - model: tf.keras.Model = tf.keras.models.model_from_config(config.model_config) + model: base_model.BaseModel = keras.models.model_from_config(config.model_config) model.make(batch_size=1) model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False) model.summary() @@ -44,7 +46,15 @@ def main( signal = tf.reshape(signal, [1, -1]) signal_length = tf.reshape(tf.shape(signal)[1], [1]) - outputs = model.recognize(schemas.PredictInput(signal, signal_length)) + outputs = model.recognize( + schemas.PredictInput( + inputs=signal, + inputs_length=signal_length, + previous_tokens=model.get_initial_tokens(), + previous_encoder_states=model.get_initial_encoder_states(), + previous_decoder_states=model.get_initial_decoder_states(), + ) + ) print(outputs.tokens) transcript = tokenizer.detokenize(outputs.tokens)[0].numpy().decode("utf-8") diff --git a/examples/inferences/rnn_transducer.py b/examples/inferences/rnn_transducer.py index 443e9a77df..b6d471a6d4 100644 --- a/examples/inferences/rnn_transducer.py +++ b/examples/inferences/rnn_transducer.py @@ -1,89 +1,89 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -from tensorflow_asr.utils import data_util, env_util, math_util - -logger = env_util.setup_environment() -import tensorflow as tf - -parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming") - -parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") - -parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml") - -parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights") - -parser.add_argument("--beam_width", type=int, default=0, help="Beam width") - -parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") - -parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") - -parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") - -parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords") - -parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") - -args = parser.parse_args() - -env_util.setup_devices([args.device], cpu=args.cpu) - -from tensorflow_asr.configs import Config -from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio -from tensorflow_asr.models.transducer.rnnt import RnnTransducer -from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer - -config = Config(args.config) -speech_featurizer = SpeechFeaturizer(config.speech_config) -if args.sentence_piece: - logger.info("Loading SentencePiece model ...") - text_featurizer = SentencePieceTokenizer(config.decoder_config) -elif args.subwords: - logger.info("Loading subwords ...") - text_featurizer = SubwordFeaturizer(config.decoder_config) -else: - text_featurizer = CharTokenizer(config.decoder_config) -text_featurizer.decoder_config.beam_width = args.beam_width - -# build model -rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes) -rnnt.make(speech_featurizer.shape) -rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True) -rnnt.summary() -rnnt.add_featurizers(speech_featurizer, text_featurizer) - -signal = read_raw_audio(args.filename) -features = speech_featurizer.tf_extract(signal) -input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor) - -if args.beam_width: - transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) - logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) -elif args.timestamp: - transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp( - signal=signal, - predicted=tf.constant(text_featurizer.blank, dtype=tf.int32), - encoder_states=rnnt.encoder.get_initial_state(), - prediction_states=rnnt.predict_net.get_initial_state(), - ) - logger.info("Transcript:", transcript) - logger.info("Start time:", stime) - logger.info("End time:", etime) -else: - transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) - logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +# import argparse + +# from tensorflow_asr.utils import data_util, env_util, math_util + +# logger = env_util.setup_environment() +# import tensorflow as tf + +# parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming") + +# parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") + +# parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml") + +# parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights") + +# parser.add_argument("--beam_width", type=int, default=0, help="Beam width") + +# parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") + +# parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") + +# parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") + +# parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords") + +# parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +# args = parser.parse_args() + +# env_util.setup_devices([args.device], cpu=args.cpu) + +# from tensorflow_asr.configs import Config +# from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio +# from tensorflow_asr.models.transducer.rnnt import RnnTransducer +# from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer + +# config = Config(args.config) +# speech_featurizer = SpeechFeaturizer(config.speech_config) +# if args.sentence_piece: +# logger.info("Loading SentencePiece model ...") +# text_featurizer = SentencePieceTokenizer(config.decoder_config) +# elif args.subwords: +# logger.info("Loading subwords ...") +# text_featurizer = SubwordFeaturizer(config.decoder_config) +# else: +# text_featurizer = CharTokenizer(config.decoder_config) +# text_featurizer.decoder_config.beam_width = args.beam_width + +# # build model +# rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes) +# rnnt.make(speech_featurizer.shape) +# rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True) +# rnnt.summary() +# rnnt.add_featurizers(speech_featurizer, text_featurizer) + +# signal = read_raw_audio(args.filename) +# features = speech_featurizer.tf_extract(signal) +# input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor) + +# if args.beam_width: +# transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) +# logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) +# elif args.timestamp: +# transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp( +# signal=signal, +# predicted=tf.constant(text_featurizer.blank, dtype=tf.int32), +# encoder_states=rnnt.encoder.get_initial_state(), +# prediction_states=rnnt.predict_net.get_initial_state(), +# ) +# logger.info("Transcript:", transcript) +# logger.info("Start time:", stime) +# logger.info("End time:", etime) +# else: +# transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) +# logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) diff --git a/examples/inferences/streaming_tflite_conformer.py b/examples/inferences/streaming_tflite_conformer.py index 321f2a9c5f..46c0523a58 100644 --- a/examples/inferences/streaming_tflite_conformer.py +++ b/examples/inferences/streaming_tflite_conformer.py @@ -1,172 +1,172 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. -import argparse -import queue -import sys -from multiprocessing import Event, Manager, Process +# import argparse +# import queue +# import sys +# from multiprocessing import Event, Manager, Process -import numpy as np -import sounddevice as sd -import soundfile as sf -import tensorflow as tf +# import numpy as np +# import sounddevice as sd +# import soundfile as sf +# import tensorflow as tf -def int_or_str(text): - """Helper function for argument parsing.""" - try: - return int(text) - except ValueError: - return text - - -parser = argparse.ArgumentParser(prog="Conformer audio file streaming") +# def int_or_str(text): +# """Helper function for argument parsing.""" +# try: +# return int(text) +# except ValueError: +# return text + + +# parser = argparse.ArgumentParser(prog="Conformer audio file streaming") -parser.add_argument("-l", "--list-devices", action="store_true", help="show list of audio devices and exit") - -args, remaining = parser.parse_known_args() - -if args.list_devices: - print(sd.query_devices()) - parser.exit(0) - -parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") - -parser.add_argument("-d", "--device", type=int_or_str, help="output device (numeric ID or substring)") - -parser.add_argument("-b", "--blocksize", type=int, default=4096, help="block size (default: %(default)s)") - -parser.add_argument("-q", "--buffersize", type=int, default=20, help="number of blocks used for buffering (default: %(default)s)") - -parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite") - -parser.add_argument("--blank", type=int, default=0, help="Path to conformer tflite") - -parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network") - -parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)") - -parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network") - -args = parser.parse_args(remaining) - -if args.blocksize == 0: - parser.error("blocksize must not be zero") -if args.buffersize < 1: - parser.error("buffersize must be at least 1") - -q = queue.Queue(maxsize=args.buffersize) -m = Manager() -Q = m.Queue() -E = Event() - - -def recognizer(Q): - tflitemodel = tf.lite.Interpreter(model_path=args.tflite) - - input_details = tflitemodel.get_input_details() - output_details = tflitemodel.get_output_details() - - tflitemodel.resize_tensor_input(input_details[0]["index"], [args.blocksize]) - tflitemodel.allocate_tensors() - - def recognize(signal, lastid, states): - if signal.shape[0] < args.blocksize: - signal = tf.pad(signal, [[0, args.blocksize - signal.shape[0]]]) - tflitemodel.set_tensor(input_details[0]["index"], signal) - tflitemodel.set_tensor(input_details[1]["index"], lastid) - tflitemodel.set_tensor(input_details[2]["index"], states) - tflitemodel.invoke() - upoints = tflitemodel.get_tensor(output_details[0]["index"]) - lastid = tflitemodel.get_tensor(output_details[1]["index"]) - states = tflitemodel.get_tensor(output_details[2]["index"]) - text = "".join([chr(u) for u in upoints]) - return text, lastid, states - - lastid = args.blank * tf.ones(shape=[], dtype=tf.int32) - states = tf.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32) - transcript = "" - - while True: - try: - data = Q.get() - text, lastid, states = recognize(data, lastid, states) - transcript += text - print(transcript, flush=True) - except queue.Empty: - pass - - -tflite_process = Process(target=recognizer, args=[Q]) -tflite_process.start() - - -def send(q, Q, E): - def callback(outdata, frames, time, status): - assert frames == args.blocksize - if status.output_underflow: - print("Output underflow: increase blocksize?", file=sys.stderr) - raise sd.CallbackAbort - assert not status - try: - data = q.get_nowait() - Q.put(np.frombuffer(data, dtype=np.float32)) - except queue.Empty as e: - print("Buffer is empty: increase buffersize?", file=sys.stderr) - raise sd.CallbackAbort from e - if len(data) < len(outdata): - outdata[: len(data)] = data - outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) - raise sd.CallbackStop - else: - outdata[:] = data - - try: - with sf.SoundFile(args.filename) as f: - for _ in range(args.buffersize): - data = f.buffer_read(args.blocksize, dtype="float32") - if not data: - break - q.put_nowait(data) # Pre-fill queue - stream = sd.RawOutputStream( - samplerate=f.samplerate, - blocksize=args.blocksize, - device=args.device, - channels=f.channels, - dtype="float32", - callback=callback, - finished_callback=E.set, - ) - with stream: - timeout = args.blocksize * args.buffersize / f.samplerate - while data: - data = f.buffer_read(args.blocksize, dtype="float32") - q.put(data, timeout=timeout) - E.wait() - - except KeyboardInterrupt: - parser.exit("\nInterrupted by user") - except queue.Full: - # A timeout occurred, i.e. there was an error in the callback - parser.exit(1) - except Exception as e: - parser.exit(type(e).__name__ + ": " + str(e)) - - -send_process = Process(target=send, args=[q, Q, E]) -send_process.start() -send_process.join() -send_process.close() - -tflite_process.terminate() +# parser.add_argument("-l", "--list-devices", action="store_true", help="show list of audio devices and exit") + +# args, remaining = parser.parse_known_args() + +# if args.list_devices: +# print(sd.query_devices()) +# parser.exit(0) + +# parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") + +# parser.add_argument("-d", "--device", type=int_or_str, help="output device (numeric ID or substring)") + +# parser.add_argument("-b", "--blocksize", type=int, default=4096, help="block size (default: %(default)s)") + +# parser.add_argument("-q", "--buffersize", type=int, default=20, help="number of blocks used for buffering (default: %(default)s)") + +# parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite") + +# parser.add_argument("--blank", type=int, default=0, help="Path to conformer tflite") + +# parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network") + +# parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)") + +# parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network") + +# args = parser.parse_args(remaining) + +# if args.blocksize == 0: +# parser.error("blocksize must not be zero") +# if args.buffersize < 1: +# parser.error("buffersize must be at least 1") + +# q = queue.Queue(maxsize=args.buffersize) +# m = Manager() +# Q = m.Queue() +# E = Event() + + +# def recognizer(Q): +# tflitemodel = tf.lite.Interpreter(model_path=args.tflite) + +# input_details = tflitemodel.get_input_details() +# output_details = tflitemodel.get_output_details() + +# tflitemodel.resize_tensor_input(input_details[0]["index"], [args.blocksize]) +# tflitemodel.allocate_tensors() + +# def recognize(signal, lastid, states): +# if signal.shape[0] < args.blocksize: +# signal = tf.pad(signal, [[0, args.blocksize - signal.shape[0]]]) +# tflitemodel.set_tensor(input_details[0]["index"], signal) +# tflitemodel.set_tensor(input_details[1]["index"], lastid) +# tflitemodel.set_tensor(input_details[2]["index"], states) +# tflitemodel.invoke() +# upoints = tflitemodel.get_tensor(output_details[0]["index"]) +# lastid = tflitemodel.get_tensor(output_details[1]["index"]) +# states = tflitemodel.get_tensor(output_details[2]["index"]) +# text = "".join([chr(u) for u in upoints]) +# return text, lastid, states + +# lastid = args.blank * tf.ones(shape=[], dtype=tf.int32) +# states = tf.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32) +# transcript = "" + +# while True: +# try: +# data = Q.get() +# text, lastid, states = recognize(data, lastid, states) +# transcript += text +# print(transcript, flush=True) +# except queue.Empty: +# pass + + +# tflite_process = Process(target=recognizer, args=[Q]) +# tflite_process.start() + + +# def send(q, Q, E): +# def callback(outdata, frames, time, status): +# assert frames == args.blocksize +# if status.output_underflow: +# print("Output underflow: increase blocksize?", file=sys.stderr) +# raise sd.CallbackAbort +# assert not status +# try: +# data = q.get_nowait() +# Q.put(np.frombuffer(data, dtype=np.float32)) +# except queue.Empty as e: +# print("Buffer is empty: increase buffersize?", file=sys.stderr) +# raise sd.CallbackAbort from e +# if len(data) < len(outdata): +# outdata[: len(data)] = data +# outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) +# raise sd.CallbackStop +# else: +# outdata[:] = data + +# try: +# with sf.SoundFile(args.filename) as f: +# for _ in range(args.buffersize): +# data = f.buffer_read(args.blocksize, dtype="float32") +# if not data: +# break +# q.put_nowait(data) # Pre-fill queue +# stream = sd.RawOutputStream( +# samplerate=f.samplerate, +# blocksize=args.blocksize, +# device=args.device, +# channels=f.channels, +# dtype="float32", +# callback=callback, +# finished_callback=E.set, +# ) +# with stream: +# timeout = args.blocksize * args.buffersize / f.samplerate +# while data: +# data = f.buffer_read(args.blocksize, dtype="float32") +# q.put(data, timeout=timeout) +# E.wait() + +# except KeyboardInterrupt: +# parser.exit("\nInterrupted by user") +# except queue.Full: +# # A timeout occurred, i.e. there was an error in the callback +# parser.exit(1) +# except Exception as e: +# parser.exit(type(e).__name__ + ": " + str(e)) + + +# send_process = Process(target=send, args=[q, Q, E]) +# send_process.start() +# send_process.join() +# send_process.close() + +# tflite_process.terminate() diff --git a/examples/models/transducer/conformer/inference/gen_saved_model.py b/examples/models/transducer/conformer/inference/gen_saved_model.py index 0048351bd0..c9cc875950 100644 --- a/examples/models/transducer/conformer/inference/gen_saved_model.py +++ b/examples/models/transducer/conformer/inference/gen_saved_model.py @@ -1,56 +1,56 @@ -# pylint: disable=no-member -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import fire -import tensorflow as tf - -from tensorflow_asr.configs import Config -from tensorflow_asr.helpers import featurizer_helpers -from tensorflow_asr.models.transducer.conformer import Conformer -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - - -def main( - config_path: str = DEFAULT_YAML, - saved: str = None, - output_dir: str = None, -): - assert saved and output_dir - tf.random.set_seed(0) - tf.keras.backend.clear_session() - - logger.info("Load config and featurizers ...") - config = Config(config_path) - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - logger.info("Build and load model ...") - conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) - conformer.make(speech_featurizer.shape) - conformer.add_featurizers(speech_featurizer, text_featurizer) - conformer.load_weights(saved, by_name=True) - conformer.summary() - - logger.info("Save model ...") - tf.saved_model.save(conformer, export_dir=output_dir, signatures=conformer.recognize_from_signal.get_concrete_function()) - - -if __name__ == "__main__": - fire.Fire(main) +# # pylint: disable=no-member +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +# import os + +# import fire +# from tensorflow_asr import tf, keras + +# from tensorflow_asr.configs import Config +# from tensorflow_asr.helpers import featurizer_helpers +# from tensorflow_asr.models.transducer.conformer import Conformer +# from tensorflow_asr.utils import env_util + +# logger = env_util.setup_environment() + +# DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") + + +# def main( +# config_path: str = DEFAULT_YAML, +# saved: str = None, +# output_dir: str = None, +# ): +# assert saved and output_dir +# tf.random.set_seed(0) +# keras.backend.clear_session() + +# logger.info("Load config and featurizers ...") +# config = Config(config_path) +# speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) + +# logger.info("Build and load model ...") +# conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) +# conformer.make(speech_featurizer.shape) +# conformer.add_featurizers(speech_featurizer, text_featurizer) +# conformer.load_weights(saved, by_name=True) +# conformer.summary() + +# logger.info("Save model ...") +# tf.saved_model.save(conformer, export_dir=output_dir, signatures=conformer.recognize_from_signal.get_concrete_function()) + + +# if __name__ == "__main__": +# fire.Fire(main) diff --git a/examples/models/transducer/conformer/inference/run_saved_model.py b/examples/models/transducer/conformer/inference/run_saved_model.py index eb00912d9d..56da5da980 100644 --- a/examples/models/transducer/conformer/inference/run_saved_model.py +++ b/examples/models/transducer/conformer/inference/run_saved_model.py @@ -1,43 +1,43 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. -import os +# import os -import fire -import tensorflow as tf +# import fire +# from tensorflow_asr import tf, keras -from tensorflow_asr.features.speech_featurizers import read_raw_audio -from tensorflow_asr.utils import env_util +# from tensorflow_asr.features.speech_featurizers import read_raw_audio +# from tensorflow_asr.utils import env_util -logger = env_util.setup_environment() +# logger = env_util.setup_environment() -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") +# DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") -def main( - saved_model: str = None, - filename: str = None, -): - tf.keras.backend.clear_session() +# def main( +# saved_model: str = None, +# filename: str = None, +# ): +# keras.backend.clear_session() - module = tf.saved_model.load(export_dir=saved_model) +# module = tf.saved_model.load(export_dir=saved_model) - signal = read_raw_audio(filename) - transcript = module.pred(signal) +# signal = read_raw_audio(filename) +# transcript = module.pred(signal) - print("Transcript: ", "".join([chr(u) for u in transcript])) +# print("Transcript: ", "".join([chr(u) for u in transcript])) -if __name__ == "__main__": - fire.Fire(main) +# if __name__ == "__main__": +# fire.Fire(main) diff --git a/examples/models/transducer/rnnt/results/sentencepiece/README.md b/examples/models/transducer/rnnt/results/sentencepiece/README.md new file mode 100644 index 0000000000..03b137f092 --- /dev/null +++ b/examples/models/transducer/rnnt/results/sentencepiece/README.md @@ -0,0 +1,57 @@ +- [SentencePiece 256 + Tiny + LibriSpeech](#sentencepiece-256--tiny--librispeech) + - [Training Loss](#training-loss) + - [1. Epoch Loss](#1-epoch-loss) + - [2. Batch Loss](#2-batch-loss) + - [Results](#results) + + +# SentencePiece 256 + Tiny + LibriSpeech + +| Category | Description | +| :---------------- | :------------------------------- | +| Config | [tiny.yml.j2](../../tiny.yml.j2) | +| Tensorflow | **2.15.x** | +| Device | NVIDIA GeForce GTX 1650 | +| Global Batch Size | 3 | +| Max Epochs | 300 | + + +### Training Loss + +#### 1. Epoch Loss + +![Epoch Loss](./figs/rnnt-tiny-sp256-epoch-loss.svg) + +#### 2. Batch Loss + +![Batch Loss](./figs/rnnt-tiny-sp256-batch-loss.svg) + + +### Results + +Pretrain Model here: [link](https://drive.google.com/drive/folders/1h0BrCzZo8JTz_MUU5bJPJ3UBqroBnsuv?usp=sharing) + +```json +[ + { + "epoch": 136, + "test-clean": { + "greedy": { + "wer": 0.15853241022519782, + "cer": 0.07179696657549817, + "mer": 0.15537908021549876, + "wil": 0.2587056704145151, + "wip": 0.7412943295854849 + } + }, + "test-other": { + "greedy": { + "wer": 0.3457577899623636, + "cer": 0.18733822655980759, + "mer": 0.33391759995571874, + "wil": 0.5185365485613327, + "wip": 0.48146345143866726 + } + } + }, +] \ No newline at end of file diff --git a/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-batch-loss.svg b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-batch-loss.svg new file mode 100644 index 0000000000..c90c689ff3 --- /dev/null +++ b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-batch-loss.svg @@ -0,0 +1 @@ +303540455055606570-100k0100k200k300k400k500k600k700k800k \ No newline at end of file diff --git a/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-epoch-loss.svg b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-epoch-loss.svg new file mode 100644 index 0000000000..21b438c108 --- /dev/null +++ b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-epoch-loss.svg @@ -0,0 +1 @@ +3436384042444648505254565860-20020406080100120140 \ No newline at end of file diff --git a/examples/models/transducer/rnnt/small.yml.j2 b/examples/models/transducer/rnnt/small.yml.j2 index 8053bb560c..1c336564e2 100644 --- a/examples/models/transducer/rnnt/small.yml.j2 +++ b/examples/models/transducer/rnnt/small.yml.j2 @@ -22,6 +22,7 @@ model_config: prob: 1.0 num_masks: 1 mask_factor: 27 + encoder_reduction_positions: [ post, post, post, post ] encoder_reduction_factors: [ 3, 0, 2, 0 ] # downsampled to 30ms and add 2 reduction after second layer encoder_dmodel: 320 encoder_rnn_type: lstm diff --git a/examples/models/transducer/rnnt/tiny.yml.j2 b/examples/models/transducer/rnnt/tiny.yml.j2 index 86d2c57057..1853790a46 100644 --- a/examples/models/transducer/rnnt/tiny.yml.j2 +++ b/examples/models/transducer/rnnt/tiny.yml.j2 @@ -22,6 +22,7 @@ model_config: prob: 1.0 num_masks: 1 mask_factor: 27 + encoder_reduction_positions: [ pre, pre, pre, pre ] encoder_reduction_factors: [ 3, 0, 2, 0 ] # downsampled to 30ms and add 2 reduction after second layer encoder_dmodel: 128 encoder_rnn_type: lstm diff --git a/examples/save.py b/examples/save.py new file mode 100644 index 0000000000..a7ab0d5512 --- /dev/null +++ b/examples/save.py @@ -0,0 +1,51 @@ +# Copyright 2024 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from tensorflow_asr import tf, keras +from tensorflow_asr import tokenizers +from tensorflow_asr.configs import Config +from tensorflow_asr.models.base_model import BaseModel +from tensorflow_asr.utils import cli_util, env_util, file_util + + +def main( + config_path: str, + output: str, + h5: str = None, + bs: int = 2, + repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")), +): + assert output + keras.backend.clear_session() + env_util.setup_seed() + + config = Config(config_path, training=False, repodir=repodir) + tokenizer = tokenizers.get(config) + + model: BaseModel = keras.models.model_from_config(config.model_config) + model.tokenizer = tokenizer + model.make(batch_size=bs) + if h5 and tf.io.gfile.exists(h5): + model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5)) + model.summary() + + model.save(output) + print(model.to_json()) + keras.utils.plot_model(model, to_file=f"{output}.png", show_shapes=True, show_dtype=True, expand_nested=True, show_layer_activations=True) + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/examples/test.py b/examples/test.py index b427ba1e3b..30c87ac930 100644 --- a/examples/test.py +++ b/examples/test.py @@ -16,7 +16,7 @@ import json import os -from tensorflow_asr import datasets, tf, tokenizers # import to aid logging messages +from tensorflow_asr import datasets, tf, keras, tokenizers # import to aid logging messages from tensorflow_asr.callbacks import PredictLogger from tensorflow_asr.configs import Config from tensorflow_asr.models.base_model import BaseModel @@ -50,7 +50,7 @@ def main( tokenizer = tokenizers.get(config) - model: BaseModel = tf.keras.models.model_from_config(config.model_config) + model: BaseModel = keras.models.model_from_config(config.model_config) model.tokenizer = tokenizer model.make(batch_size=batch_size) model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False) diff --git a/examples/tflite.py b/examples/tflite.py index f2d46c56ff..d8b3b9cfd5 100644 --- a/examples/tflite.py +++ b/examples/tflite.py @@ -14,7 +14,7 @@ import os -from tensorflow_asr import tf # import to aid logging messages +from tensorflow_asr import tf, keras # import to aid logging messages from tensorflow_asr import tokenizers from tensorflow_asr.configs import Config from tensorflow_asr.models.base_model import BaseModel @@ -30,13 +30,13 @@ def main( repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")), ): assert output - tf.keras.backend.clear_session() + keras.backend.clear_session() env_util.setup_seed() config = Config(config_path, training=False, repodir=repodir) tokenizer = tokenizers.get(config) - model: BaseModel = tf.keras.models.model_from_config(config.model_config) + model: BaseModel = keras.models.model_from_config(config.model_config) model.tokenizer = tokenizer model.make(batch_size=bs) if h5 and tf.io.gfile.exists(h5): diff --git a/examples/train.py b/examples/train.py index 62b88413c9..f20c704a83 100644 --- a/examples/train.py +++ b/examples/train.py @@ -15,7 +15,7 @@ import json import os -from tensorflow_asr import callbacks, datasets, tf, tokenizers # import to aid logging messages +from tensorflow_asr import callbacks, datasets, tf, keras, tokenizers # import to aid logging messages from tensorflow_asr.configs import Config from tensorflow_asr.models.base_model import BaseModel from tensorflow_asr.utils import cli_util, env_util, file_util @@ -36,7 +36,7 @@ def main( ga_steps: int = None, repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")), ): - tf.keras.backend.clear_session() + keras.backend.clear_session() env_util.setup_seed() strategy = env_util.setup_strategy(devices) env_util.setup_mxp(mxp=mxp) @@ -73,7 +73,7 @@ def main( logger.info(f"eval_data_loader.element_spec = {json.dumps(eval_data_loader.element_spec, indent=2, default=str)}") with strategy.scope(): - model: BaseModel = tf.keras.models.model_from_config(config.model_config) + model: BaseModel = keras.models.model_from_config(config.model_config) model.tokenizer = tokenizer output_shapes = model.make(**shapes) if config.learning_config.pretrained: @@ -83,7 +83,7 @@ def main( skip_mismatch=True, ) model.compile( - optimizer=tf.keras.optimizers.get(config.learning_config.optimizer_config), + optimizer=keras.optimizers.get(config.learning_config.optimizer_config), output_shapes=output_shapes, steps_per_execution=spx, jit_compile=jit_compile, diff --git a/requirements.txt b/requirements.txt index eb9076cf47..91b58751cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,9 +14,11 @@ pytest~=7.4.1 black~=24.3.0 pylint~=3.2.1 matplotlib~=3.7.2 -pydot~=1.4.2 +pydot-ng~=1.4.2 graphviz~=0.20.1 pre-commit~=3.7.0 +tf2onnx~=1.16.1 +netron~=7.6.8 # extra=tf2-12 tensorflow~=2.12.0 diff --git a/scripts/create_mls_trans.py b/scripts/create_mls_trans.py index 3e1a82f3ac..b118820f44 100644 --- a/scripts/create_mls_trans.py +++ b/scripts/create_mls_trans.py @@ -16,29 +16,16 @@ import os import librosa -import tensorflow as tf +import keras import tqdm # example usage: python create_mls_trans.py -dataset-home /mnt/datasets/mls --language polish --opus base_url = "https://dl.fbaipublicfiles.com/mls/" -langs = [ - "dutch", - "english", - "german", - "french", - "italian", - "portuguese", - "polish", - "spanish" -] - -splits = [ - "dev", - "test", - "train" -] +langs = ["dutch", "english", "german", "french", "italian", "portuguese", "polish", "spanish"] + +splits = ["dev", "test", "train"] chars = set() @@ -46,17 +33,17 @@ def prepare_split(dataset_dir, split, opus=False): # Setup necessary paths split_home = os.path.join(dataset_dir, split) - transcripts_infile = os.path.join(split_home, 'transcripts.txt') - transcripts_outfile = os.path.join(split_home, 'transcripts_tfasr.tsv') + transcripts_infile = os.path.join(split_home, "transcripts.txt") + transcripts_outfile = os.path.join(split_home, "transcripts_tfasr.tsv") audio_home = os.path.join(split_home, "audio") extension = ".opus" if opus else ".flac" transcripts = [] # Make paths absolute, get durations and read chars to form alphabet later on - with open(transcripts_infile, 'r', encoding='utf8') as infile: + with open(transcripts_infile, "r", encoding="utf8") as infile: for line in tqdm.tqdm(infile.readlines(), desc=f"Reading from {transcripts_infile}..."): - file_id, transcript = line.strip().split('\t') - speaker_id, book_id, _ = file_id.split('_') + file_id, transcript = line.strip().split("\t") + speaker_id, book_id, _ = file_id.split("_") audio_path = os.path.join(audio_home, speaker_id, book_id, f"{file_id}{extension}") y, sr = librosa.load(audio_path, sr=None) duration = librosa.get_duration(y, sr) @@ -65,7 +52,7 @@ def prepare_split(dataset_dir, split, opus=False): chars.add(char) # Write transcripts to file - with open(transcripts_outfile, 'w', encoding='utf8') as outfile: + with open(transcripts_outfile, "w", encoding="utf8") as outfile: outfile.write("PATH\tDURATION\tTRANSCRIPT\n") for t in tqdm.tqdm(transcripts, desc=f"Writing to {transcripts_outfile}"): outfile.write(t) @@ -73,7 +60,7 @@ def prepare_split(dataset_dir, split, opus=False): def make_alphabet_file(filepath, chars_list, lang): print(f"Writing alphabet to {filepath}...") - with open(filepath, 'w', encoding='utf8') as outfile: + with open(filepath, "w", encoding="utf8") as outfile: outfile.write(f"# Alphabet file for language {lang}\n") outfile.write("Automatically generated. Do not edit\n#\n") for char in sorted(list(chars_list)): @@ -84,10 +71,10 @@ def make_alphabet_file(filepath, chars_list, lang): if __name__ == "__main__": ap = argparse.ArgumentParser(description="Download and prepare MLS dataset in a given language") - ap.add_argument("--dataset-home", "-d", default=None, required=False, - help="Path to home directory to download and prepare dataset. Default to ~/.keras") - ap.add_argument("--language", "-l", type=str, choices=langs, default=None, required=True, - help="Any name of language included in MLS") + ap.add_argument( + "--dataset-home", "-d", default=None, required=False, help="Path to home directory to download and prepare dataset. Default to ~/.keras" + ) + ap.add_argument("--language", "-l", type=str, choices=langs, default=None, required=True, help="Any name of language included in MLS") ap.add_argument("--opus", default=False, action="store_true", help="Whether to use dataset in opus format or not") args = ap.parse_args() @@ -97,12 +84,7 @@ def make_alphabet_file(filepath, chars_list, lang): dataset_dir = os.path.join(dataset_home, subdir) full_url = base_url + fname - downloaded_file = tf.keras.utils.get_file( - fname, - full_url, - cache_subdir=dataset_home, - extract=True - ) + downloaded_file = keras.utils.get_file(fname, full_url, cache_subdir=dataset_home, extract=True) print(f"Dataset extracted to {dataset_dir}. Preparing...") diff --git a/tensorflow_asr/callbacks.py b/tensorflow_asr/callbacks.py index e7dd439f05..012dd73869 100644 --- a/tensorflow_asr/callbacks.py +++ b/tensorflow_asr/callbacks.py @@ -16,6 +16,7 @@ import numpy as np import tensorflow as tf +import keras from tensorflow_asr.datasets import ASRDataset from tensorflow_asr.utils import file_util @@ -24,8 +25,8 @@ serialization_lib = importlib.import_module(f"{KERAS_SRC}.saving.serialization_lib") -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class TestLogger(tf.keras.callbacks.Callback): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class TestLogger(keras.callbacks.Callback): def __init__(self): super().__init__() self.wer = {"numer": 0, "denom": 0} @@ -80,8 +81,8 @@ def from_config(cls, config): return cls(**config) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class PredictLogger(tf.keras.callbacks.Callback): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class PredictLogger(keras.callbacks.Callback): def __init__(self, test_dataset: ASRDataset, output_file_path: str): super().__init__() self.test_dataset = test_dataset @@ -123,8 +124,8 @@ def from_config(cls, config): return cls(**config) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class TensorBoard(tf.keras.callbacks.TensorBoard): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class TensorBoard(keras.callbacks.TensorBoard): def __init__( self, log_dir="logs", @@ -165,8 +166,8 @@ def from_config(cls, config): return cls(**config) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class TerminateOnNaN(tf.keras.callbacks.TerminateOnNaN): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class TerminateOnNaN(keras.callbacks.TerminateOnNaN): def get_config(self): return {} @@ -175,8 +176,8 @@ def from_config(cls, config): return cls(**config) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class ModelCheckpoint(tf.keras.callbacks.ModelCheckpoint): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class ModelCheckpoint(keras.callbacks.ModelCheckpoint): def __init__( self, filepath, @@ -203,8 +204,8 @@ def from_config(cls, config): return cls(**config) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class BackupAndRestore(tf.keras.callbacks.BackupAndRestore): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class BackupAndRestore(keras.callbacks.BackupAndRestore): def __init__( self, backup_dir, @@ -223,8 +224,8 @@ def from_config(cls, config): return cls(**config) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.callbacks") -class EarlyStopping(tf.keras.callbacks.EarlyStopping): +@keras.utils.register_keras_serializable("tensorflow_asr.callbacks") +class EarlyStopping(keras.callbacks.EarlyStopping): def get_config(self): return {} diff --git a/tensorflow_asr/losses/ctc_loss.py b/tensorflow_asr/losses/ctc_loss.py index 5e8e52bf26..f46519547a 100644 --- a/tensorflow_asr/losses/ctc_loss.py +++ b/tensorflow_asr/losses/ctc_loss.py @@ -13,14 +13,15 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.utils import env_util, math_util logger = tf.get_logger() -class CtcLoss(tf.keras.losses.Loss): - def __init__(self, blank=0, reduction=tf.keras.losses.Reduction.AUTO, name=None): +class CtcLoss(keras.losses.Loss): + def __init__(self, blank=0, reduction=keras.losses.Reduction.AUTO, name=None): super().__init__(reduction=reduction, name=name) self.blank = blank self.use_tpu = env_util.has_devices("TPU") diff --git a/tensorflow_asr/losses/rnnt_loss.py b/tensorflow_asr/losses/rnnt_loss.py index 11000c3b7c..86f5febb28 100644 --- a/tensorflow_asr/losses/rnnt_loss.py +++ b/tensorflow_asr/losses/rnnt_loss.py @@ -19,6 +19,7 @@ import numpy as np import tensorflow as tf +import keras from tensorflow_asr.utils import env_util, math_util, shape_util @@ -29,11 +30,11 @@ logger = tf.get_logger() -class RnntLoss(tf.keras.losses.Loss): +class RnntLoss(keras.losses.Loss): def __init__( self, blank, - reduction=tf.keras.losses.Reduction.AUTO, + reduction=keras.losses.Reduction.AUTO, output_shapes=None, name=None, ): diff --git a/tensorflow_asr/metrics/error_rates.py b/tensorflow_asr/metrics/error_rates.py index a3bc72b62c..0ad18001cf 100644 --- a/tensorflow_asr/metrics/error_rates.py +++ b/tensorflow_asr/metrics/error_rates.py @@ -13,9 +13,10 @@ # limitations under the License. import tensorflow as tf +import keras -class ErrorRate(tf.keras.metrics.Metric): +class ErrorRate(keras.metrics.Metric): """Metric for WER or CER""" def __init__(self, name="error_rate", **kwargs): diff --git a/tensorflow_asr/models/activations/glu.py b/tensorflow_asr/models/activations/glu.py index 4bf049ce7b..a2a7ea26a8 100644 --- a/tensorflow_asr/models/activations/glu.py +++ b/tensorflow_asr/models/activations/glu.py @@ -30,3 +30,8 @@ def call(self, inputs): def compute_output_shape(self, input_shape): B, T, V = input_shape return (B, T, V // 2) + + def get_config(self): + config = super().get_config() + config.update({"axis": self.axis}) + return config diff --git a/tensorflow_asr/models/base_layer.py b/tensorflow_asr/models/base_layer.py index 54c331aaac..c59183936d 100644 --- a/tensorflow_asr/models/base_layer.py +++ b/tensorflow_asr/models/base_layer.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +import keras from tensorflow_asr.utils import math_util -class Layer(tf.keras.layers.Layer): +class Layer(keras.layers.Layer): def __init__( self, trainable=True, diff --git a/tensorflow_asr/models/base_model.py b/tensorflow_asr/models/base_model.py index 731fc68d02..25c595c6a4 100644 --- a/tensorflow_asr/models/base_model.py +++ b/tensorflow_asr/models/base_model.py @@ -17,6 +17,7 @@ import importlib import tensorflow as tf +import keras from keras import callbacks as callbacks_module from keras.optimizers import Optimizer from tensorflow.python.eager import context # pylint: disable=no-name-in-module @@ -42,7 +43,7 @@ logger = tf.get_logger() -class BaseModel(tf.keras.Model): +class BaseModel(keras.Model): def __init__(self, speech_config: dict, *args, **kwargs): super().__init__(*args, **kwargs) self.feature_extraction = FeatureExtraction(**speech_config, dtype=self.dtype) @@ -68,22 +69,11 @@ def save( self, filepath, overwrite=True, - include_optimizer=True, save_format=None, - signatures=None, - options=None, - save_traces=True, + **kwargs, ): with file_util.save_file(filepath) as path: - super().save( - filepath=path, - overwrite=overwrite, - include_optimizer=include_optimizer, - save_format=save_format, - signatures=signatures, - options=options, - save_traces=save_traces, - ) + super().save(filepath=path, overwrite=overwrite, save_format=save_format, **kwargs) def save_weights( self, @@ -105,7 +95,7 @@ def load_weights( with file_util.read_file(filepath) as path: super().load_weights(filepath=path, by_name=by_name, skip_mismatch=skip_mismatch, options=options) - def add_custom_metric(self, metric: tf.keras.metrics.Metric): + def add_custom_metric(self, metric: keras.metrics.Metric): if not hasattr(self, "_tfasr_metrics"): self._tfasr_metrics = {} self._tfasr_metrics[metric.name] = metric @@ -124,10 +114,10 @@ def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, cac Batch size, by default None """ assert batch_size is not None and batch_size > 0 - signals = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) - signals_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) - predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) - predictions_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + signals = keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) + signals_length = keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + predictions = keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) + predictions_length = keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) self._per_replica_batch_size = int(batch_size / self.distribute_strategy.num_replicas_in_sync) self._batch_size = batch_size outputs = self( @@ -156,13 +146,13 @@ def compile( gradn_config=None, **kwargs, ): - optimizer = tf.keras.optimizers.get(optimizer) + optimizer = keras.optimizers.get(optimizer) if env_util.has_devices("TPU"): self.use_loss_scale = False else: self.use_loss_scale = mxp != "none" if self.use_loss_scale: - optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer) + optimizer = keras.mixed_precision.LossScaleOptimizer(optimizer) logger.info("Using loss scale") if isinstance(ga_steps, int) and ga_steps > 1: self.use_ga = True @@ -171,7 +161,7 @@ def compile( else: self.use_ga = False self.gwn_config = gwn_config - self.gradn = tf.keras.regularizers.get(gradn_config) if gradn_config else None + self.gradn = keras.regularizers.get(gradn_config) if gradn_config else None self.distribute_reduction_method = "sum" super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs) @@ -278,12 +268,13 @@ def test_step(self, data): def predict_step(self, data): x, y_true = data + batch_size, *_ = shape_util.shape_list(x["inputs"]) inputs = schemas.PredictInput( inputs=x["inputs"], inputs_length=x["inputs_length"], - previous_tokens=self.get_initial_tokens(), - previous_encoder_states=self.get_initial_encoder_states(), - previous_decoder_states=self.get_initial_decoder_states(), + previous_tokens=self.get_initial_tokens(batch_size=batch_size), + previous_encoder_states=self.get_initial_encoder_states(batch_size=batch_size), + previous_decoder_states=self.get_initial_decoder_states(batch_size=batch_size), ) _tokens = self.recognize(inputs=inputs).tokens _beam_tokens = self.recognize_beam(inputs=inputs).tokens @@ -508,7 +499,7 @@ def fit( steps_per_execution=self._steps_per_execution, ) - # Container that configures and calls `tf.keras.Callback`s. + # Container that configures and calls `keras.Callback`s. if not isinstance(callbacks, callbacks_module.CallbackList): callbacks = callbacks_module.CallbackList( callbacks, @@ -570,45 +561,46 @@ def fit( "`Model.compile(..., run_eagerly=True)`, or " "`tf.config.run_functions_eagerly(True)` for more " "information of where went wrong, or file a " - "issue/bug to `tf.keras`." + "issue/bug to `keras`." ) # Override with model metrics instead of last step logs logs = self._validate_and_get_metrics_result(logs) epoch_logs = copy.copy(logs) # Run validation. - if validation_data and self._should_eval(epoch, validation_freq): - # Create data_handler for evaluation and cache it. - if getattr(self, "_eval_data_handler", None) is None: - self._eval_data_handler = data_adapter.get_data_handler( + if validation_data: + if self._should_eval(epoch, validation_freq): + # Create data_handler for evaluation and cache it. + if getattr(self, "_eval_data_handler", None) is None: + self._eval_data_handler = data_adapter.get_data_handler( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps_per_epoch=validation_steps, + initial_epoch=0, + epochs=1, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + model=self, + steps_per_execution=self._steps_per_execution, + ) + val_logs = self.evaluate( x=val_x, y=val_y, sample_weight=val_sample_weight, batch_size=validation_batch_size or batch_size, - steps_per_epoch=validation_steps, - initial_epoch=0, - epochs=1, + steps=validation_steps, + callbacks=callbacks, max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, - model=self, - steps_per_execution=self._steps_per_execution, + return_dict=True, + _use_cached_eval_dataset=True, ) - val_logs = self.evaluate( - x=val_x, - y=val_y, - sample_weight=val_sample_weight, - batch_size=validation_batch_size or batch_size, - steps=validation_steps, - callbacks=callbacks, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - return_dict=True, - _use_cached_eval_dataset=True, - ) - val_logs = {"val_" + name: val for name, val in val_logs.items()} - epoch_logs.update(val_logs) + val_logs = {"val_" + name: val for name, val in val_logs.items()} + epoch_logs.update(val_logs) callbacks.on_epoch_end(epoch, epoch_logs) training_logs = epoch_logs diff --git a/tensorflow_asr/models/ctc/base_ctc.py b/tensorflow_asr/models/ctc/base_ctc.py index cd48572d83..be59b9b6bc 100644 --- a/tensorflow_asr/models/ctc/base_ctc.py +++ b/tensorflow_asr/models/ctc/base_ctc.py @@ -14,6 +14,7 @@ import tensorflow as tf +import keras from tensorflow_asr import schemas from tensorflow_asr.losses.ctc_loss import CtcLoss @@ -26,8 +27,8 @@ def __init__( self, blank: int, speech_config: dict, - encoder: tf.keras.layers.Layer, - decoder: tf.keras.layers.Layer, + encoder: keras.layers.Layer, + decoder: keras.layers.Layer, **kwargs, ): super().__init__(speech_config=speech_config, **kwargs) @@ -94,9 +95,6 @@ def call_next( outputs, outputs_length, next_decoder_states = self.decoder.call_next(outputs, outputs_length, previous_decoder_states) return outputs, outputs_length, next_encoder_states, next_decoder_states - def get_initial_tokens(self, batch_size=1): - return super().get_initial_tokens(batch_size) - def get_initial_encoder_states(self, batch_size=1): return tf.zeros([], dtype=self.dtype) diff --git a/tensorflow_asr/models/ctc/conformer.py b/tensorflow_asr/models/ctc/conformer.py index cd1a5f2834..057e37aff8 100644 --- a/tensorflow_asr/models/ctc/conformer.py +++ b/tensorflow_asr/models/ctc/conformer.py @@ -13,6 +13,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel @@ -29,7 +30,7 @@ def __init__( ): super().__init__(**kwargs) self._vocab_size = vocab_size - self.vocab = tf.keras.layers.Dense( + self.vocab = keras.layers.Dense( units=vocab_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, @@ -51,8 +52,19 @@ def compute_output_shape(self, input_shape): outputs_shape = logits_shape[:-1] + (self._vocab_size,) return tuple(outputs_shape), tuple(logits_length_shape) + def get_config(self): + config = super().get_config() + config.update( + { + "vocab_size": self._vocab_size, + "kernel_regularizer": self.vocab.kernel_regularizer, + "bias_regularizer": self.vocab.bias_regularizer, + } + ) + return config + -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") class Conformer(CtcModel): def __init__( self, @@ -141,7 +153,7 @@ def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, **k None if self.encoder._memory_length is None else [ - tf.keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) + keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) for _ in range(self.encoder._num_blocks) ] ) diff --git a/tensorflow_asr/models/ctc/deepspeech2.py b/tensorflow_asr/models/ctc/deepspeech2.py index 7fbf62afd8..0660439fb4 100644 --- a/tensorflow_asr/models/ctc/deepspeech2.py +++ b/tensorflow_asr/models/ctc/deepspeech2.py @@ -13,6 +13,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel @@ -29,7 +30,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.vocab = tf.keras.layers.Dense( + self.vocab = keras.layers.Dense( vocab_size, name="logits", kernel_regularizer=kernel_regularizer, @@ -53,8 +54,20 @@ def compute_output_shape(self, input_shape): output_shape = self.vocab.compute_output_shape(output_shape) return output_shape, output_length_shape + def get_config(self): + config = super().get_config() + config.update( + { + "vocab_size": self.vocab.units, + "kernel_regularizer": self.vocab.kernel_regularizer, + "bias_regularizer": self.vocab.bias_regularizer, + "initializer": self.vocab.kernel_initializer, + } + ) + return config + -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") class DeepSpeech2(CtcModel): def __init__( self, diff --git a/tensorflow_asr/models/ctc/jasper.py b/tensorflow_asr/models/ctc/jasper.py index db5b2da088..54f9b33874 100644 --- a/tensorflow_asr/models/ctc/jasper.py +++ b/tensorflow_asr/models/ctc/jasper.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel @@ -57,7 +57,7 @@ def compute_output_shape(self, input_shape): return tuple(outputs_shape), tuple(logits_length_shape) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") class Jasper(CtcModel): def __init__( self, diff --git a/tensorflow_asr/models/ctc/transformer.py b/tensorflow_asr/models/ctc/transformer.py index a37ef36f85..ca0214d922 100644 --- a/tensorflow_asr/models/ctc/transformer.py +++ b/tensorflow_asr/models/ctc/transformer.py @@ -13,6 +13,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel @@ -29,7 +30,7 @@ def __init__( ): super().__init__(**kwargs) self._vocab_size = vocab_size - self.vocab = tf.keras.layers.Dense( + self.vocab = keras.layers.Dense( vocab_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, @@ -52,7 +53,7 @@ def compute_output_shape(self, input_shape): return tuple(outputs_shape), tuple(logits_length_shape) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") class Transformer(CtcModel): def __init__( self, @@ -127,7 +128,7 @@ def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, **k None if self.encoder._memory_length is None else [ - tf.keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) + keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) for _ in range(self.encoder._num_blocks) ] ) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index 516b5e3d8d..c91b33921e 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -15,6 +15,7 @@ """ http://arxiv.org/abs/2005.08100 """ import tensorflow as tf +import keras from tensorflow_asr.models.activations.glu import GLU from tensorflow_asr.models.base_layer import Identity, Layer @@ -26,7 +27,7 @@ from tensorflow_asr.models.layers.residual import Residual from tensorflow_asr.models.layers.subsampling import Conv1dSubsampling, Conv2dSubsampling, VggSubsampling -L2 = tf.keras.regularizers.l2(1e-6) +L2 = keras.regularizers.l2(1e-6) class FFModule(Layer): @@ -61,12 +62,21 @@ def __init__( ): super().__init__(name=name, **kwargs) assert norm_position in ("pre", "post", "none") + self._config = { + "input_dim": input_dim, + "dropout": dropout, + "scale_factor": scale_factor, + "residual_factor": residual_factor, + "norm_position": norm_position, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } self.pre_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if norm_position == "pre" else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype) ) - self.ffn1 = tf.keras.layers.Dense( + self.ffn1 = keras.layers.Dense( units=scale_factor * input_dim, name="dense_1", kernel_regularizer=kernel_regularizer, @@ -74,17 +84,17 @@ def __init__( activation="swish", dtype=self.dtype, ) - self.do1 = tf.keras.layers.Dropout(rate=dropout, name="dropout_1", dtype=self.dtype) - self.ffn2 = tf.keras.layers.Dense( + self.do1 = keras.layers.Dropout(rate=dropout, name="dropout_1", dtype=self.dtype) + self.ffn2 = keras.layers.Dense( units=input_dim, name="dense_2", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, dtype=self.dtype, ) - self.do2 = tf.keras.layers.Dropout(rate=dropout, name="dropout_2", dtype=self.dtype) + self.do2 = keras.layers.Dropout(rate=dropout, name="dropout_2", dtype=self.dtype) self.post_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if norm_position == "post" else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype) ) @@ -103,6 +113,11 @@ def call(self, inputs, training=False): def compute_output_shape(self, input_shape): return input_shape + def get_config(self): + config = super().get_config() + config.update(self._config) + return config + class MHSAModule(Layer): r""" @@ -138,8 +153,22 @@ def __init__( super().__init__(name=name, **kwargs) assert norm_position in ("pre", "post", "none") assert mha_type in ("relmha", "mha") + self._config = { + "dmodel": dmodel, + "head_size": head_size, + "num_heads": num_heads, + "residual_factor": residual_factor, + "dropout": dropout, + "mha_type": mha_type, + "relmha_causal": relmha_causal, + "norm_position": norm_position, + "memory_length": memory_length, + "use_attention_bias": use_attention_bias, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } self.pre_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if norm_position == "pre" else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype) ) @@ -167,9 +196,9 @@ def __init__( name="mhsa", dtype=self.dtype, ) - self.do = tf.keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) + self.do = keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) self.post_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if norm_position == "post" else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype) ) @@ -201,6 +230,11 @@ def compute_output_shape(self, input_shape): output_shape, caching_shape, *_ = input_shape return output_shape, caching_shape + def get_config(self): + config = super().get_config() + config.update(self._config) + return config + class ConvModule(Layer): r""" @@ -240,8 +274,20 @@ def __init__( ): super().__init__(name=name, **kwargs) assert norm_position in ("pre", "post", "none") + self._config = { + "input_dim": input_dim, + "kernel_size": kernel_size, + "dropout": dropout, + "padding": padding, + "scale_factor": scale_factor, + "residual_factor": residual_factor, + "norm_position": norm_position, + "use_group_conv": use_group_conv, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } self.pre_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if norm_position == "pre" else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype) ) @@ -278,10 +324,10 @@ def __init__( bias_regularizer=bias_regularizer, dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) - self.swish = tf.keras.layers.Activation(tf.nn.swish, name="swish", dtype=self.dtype) + self.swish = keras.layers.Activation(tf.nn.swish, name="swish", dtype=self.dtype) self.pw_conv_2 = Conv1D( filters=input_dim, kernel_size=1, @@ -292,9 +338,9 @@ def __init__( bias_regularizer=bias_regularizer, dtype=self.dtype, ) - self.do = tf.keras.layers.Dropout(rate=dropout, name="dropout", dtype=self.dtype) + self.do = keras.layers.Dropout(rate=dropout, name="dropout", dtype=self.dtype) self.post_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if norm_position == "post" else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype) ) @@ -316,6 +362,11 @@ def call(self, inputs, training=False): def compute_output_shape(self, input_shape): return input_shape + def get_config(self): + config = super().get_config() + config.update(self._config) + return config + class ConformerBlock(Layer): r""" @@ -354,8 +405,30 @@ def __init__( ): super().__init__(name=name, **kwargs) assert block_norm_position in ("pre", "post", "none") + self._config = { + "input_dim": input_dim, + "dropout": dropout, + "ffm_scale_factor": ffm_scale_factor, + "ffm_residual_factor": ffm_residual_factor, + "head_size": head_size, + "num_heads": num_heads, + "mha_type": mha_type, + "mhsam_residual_factor": mhsam_residual_factor, + "mhsam_use_attention_bias": mhsam_use_attention_bias, + "mhsam_causal": mhsam_causal, + "kernel_size": kernel_size, + "padding": padding, + "convm_scale_factor": convm_scale_factor, + "convm_residual_factor": convm_residual_factor, + "convm_use_group_conv": convm_use_group_conv, + "module_norm_position": module_norm_position, + "block_norm_position": block_norm_position, + "memory_length": memory_length, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } self.pre_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if block_norm_position == "pre" else Identity(name="preiden" if block_norm_position == "none" else "iden", dtype=self.dtype) ) @@ -412,7 +485,7 @@ def __init__( dtype=self.dtype, ) self.post_norm = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) if block_norm_position == "post" else Identity(name="postiden" if block_norm_position == "none" else "iden", dtype=self.dtype) ) @@ -444,6 +517,11 @@ def compute_output_shape(self, input_shape): output_shape, caching_shape, *_ = input_shape return output_shape, caching_shape + def get_config(self): + config = super().get_config() + config.update(self._config) + return config + class ConformerEncoder(Layer): def __init__( @@ -478,6 +556,33 @@ def __init__( ): super().__init__(name=name, **kwargs) assert mha_type in ("relmha", "mha") + self._config = { + "subsampling": subsampling, + "dmodel": dmodel, + "num_blocks": num_blocks, + "mha_type": mha_type, + "head_size": head_size, + "num_heads": num_heads, + "kernel_size": kernel_size, + "padding": padding, + "interleave_relpe": interleave_relpe, + "use_attention_causal_mask": use_attention_causal_mask, + "use_attention_auto_mask": use_attention_auto_mask, + "ffm_scale_factor": ffm_scale_factor, + "ffm_residual_factor": ffm_residual_factor, + "mhsam_residual_factor": mhsam_residual_factor, + "mhsam_use_attention_bias": mhsam_use_attention_bias, + "mhsam_causal": mhsam_causal, + "convm_scale_factor": convm_scale_factor, + "convm_residual_factor": convm_residual_factor, + "convm_use_group_conv": convm_use_group_conv, + "dropout": dropout, + "module_norm_position": module_norm_position, + "block_norm_position": block_norm_position, + "memory_length": memory_length, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } self._dmodel = dmodel self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer @@ -502,10 +607,10 @@ def __init__( ) self.time_reduction_factor = self.conv_subsampling.time_reduction_factor - self.linear = tf.keras.layers.Dense( + self.linear = keras.layers.Dense( dmodel, name="linear", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, dtype=self.dtype ) - self.do = tf.keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) + self.do = keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) self._mha_type = mha_type self._num_heads = num_heads @@ -638,3 +743,8 @@ def compute_output_shape(self, input_shape): for cblock in self.conformer_blocks: output_shape, caching_shape = cblock.compute_output_shape((output_shape, caching_shape, relative_position_encoding_shape, None, None)) return output_shape, output_length_shape, caching_shape + + def get_config(self): + config = super().get_config() + config.update(self._config) + return config diff --git a/tensorflow_asr/models/encoders/contextnet.py b/tensorflow_asr/models/encoders/contextnet.py index 5804b17255..660d2e0e39 100644 --- a/tensorflow_asr/models/encoders/contextnet.py +++ b/tensorflow_asr/models/encoders/contextnet.py @@ -16,11 +16,12 @@ from typing import List import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer, Reshape from tensorflow_asr.utils import math_util -L2 = tf.keras.regularizers.l2(1e-6) +L2 = keras.regularizers.l2(1e-6) def get_activation( @@ -32,7 +33,7 @@ def get_activation( if activation == "relu": return tf.nn.relu if activation == "linear": - return tf.keras.activations.linear + return keras.activations.linear raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'") @@ -50,7 +51,7 @@ def __init__( ): super().__init__(**kwargs) self.strides = strides - self.conv = tf.keras.layers.SeparableConv1D( + self.conv = keras.layers.SeparableConv1D( filters=filters, kernel_size=kernel_size, strides=strides, @@ -61,7 +62,7 @@ def __init__( name="conv", dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) self.activation = get_activation(activation) @@ -116,16 +117,16 @@ def __init__( name="conv_module", dtype=self.dtype, ) - self.global_avg_pool = tf.keras.layers.GlobalAveragePooling1D(keepdims=True, name="global_avg_pool", dtype=self.dtype) + self.global_avg_pool = keras.layers.GlobalAveragePooling1D(keepdims=True, name="global_avg_pool", dtype=self.dtype) self.activation = get_activation(activation) - self.fc1 = tf.keras.layers.Dense( + self.fc1 = keras.layers.Dense( filters // 8, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="fc1", dtype=self.dtype, ) - self.fc2 = tf.keras.layers.Dense( + self.fc2 = keras.layers.Dense( filters, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, diff --git a/tensorflow_asr/models/encoders/deepspeech2.py b/tensorflow_asr/models/encoders/deepspeech2.py index f3f81f14c7..b67e261236 100644 --- a/tensorflow_asr/models/encoders/deepspeech2.py +++ b/tensorflow_asr/models/encoders/deepspeech2.py @@ -13,6 +13,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Identity, Layer, Reshape from tensorflow_asr.models.layers.convolution import DepthwiseConv1D @@ -44,13 +45,13 @@ def __init__( name="conv", dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=regularizer, beta_regularizer=regularizer, dtype=self.dtype, ) - self.activation = tf.keras.activations.get(activation) + self.activation = keras.activations.get(activation) def call(self, inputs, training=False): outputs = self.conv(inputs, training=training) @@ -92,11 +93,11 @@ def __init__( bias_initializer=initializer, dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) - self.act = tf.keras.layers.Activation(activation=activation, dtype=self.dtype) - self.do = tf.keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) + self.act = keras.layers.Activation(activation=activation, dtype=self.dtype) + self.do = keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) self.time_reduction_factor = self.conv.strides[0] def call(self, inputs, training=False): @@ -236,8 +237,8 @@ def __init__( ) self._bidirectional = bidirectional if bidirectional: - self.rnn = tf.keras.layers.Bidirectional(self.rnn, name=f"b{rnn_type}", dtype=self.dtype) - self.bn = tf.keras.layers.BatchNormalization( + self.rnn = keras.layers.Bidirectional(self.rnn, name=f"b{rnn_type}", dtype=self.dtype) + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) self.rowconv = None @@ -371,7 +372,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.fc = tf.keras.layers.Dense( + self.fc = keras.layers.Dense( units, kernel_regularizer=kernel_regularizer, kernel_initializer=initializer, @@ -380,11 +381,11 @@ def __init__( name="fc", dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) - self.act = tf.keras.layers.Activation(activation=activation, dtype=self.dtype) - self.do = tf.keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) + self.act = keras.layers.Activation(activation=activation, dtype=self.dtype) + self.do = keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) def call(self, inputs, training=False): outputs, outputs_length = inputs diff --git a/tensorflow_asr/models/encoders/jasper.py b/tensorflow_asr/models/encoders/jasper.py index 723c02525e..7597af24ba 100644 --- a/tensorflow_asr/models/encoders/jasper.py +++ b/tensorflow_asr/models/encoders/jasper.py @@ -13,13 +13,14 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer, Reshape from tensorflow_asr.models.layers.convolution import Conv1D from tensorflow_asr.utils import math_util -class JasperSubBlock(tf.keras.layers.Layer): +class JasperSubBlock(keras.layers.Layer): def __init__( self, channels: int = 256, @@ -44,11 +45,11 @@ def __init__( name="conv1d", dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) - self.relu = tf.keras.layers.ReLU(name="relu", dtype=self.dtype) - self.do = tf.keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) + self.relu = keras.layers.ReLU(name="relu", dtype=self.dtype) + self.do = keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) self.reduction_factor = strides def call(self, inputs, training=False): @@ -60,7 +61,7 @@ def call(self, inputs, training=False): return outputs -class JasperResidual(tf.keras.layers.Layer): +class JasperResidual(keras.layers.Layer): def __init__( self, channels: int = 256, @@ -80,7 +81,7 @@ def __init__( name="pointwise_conv1d", dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization( + self.bn = keras.layers.BatchNormalization( name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) @@ -128,7 +129,7 @@ def __init__( for i in range(nresiduals) ] - self.add = tf.keras.layers.Add(name="add") + self.add = keras.layers.Add(name="add") def call(self, inputs, training=False): outputs, residuals = inputs @@ -142,7 +143,7 @@ def call(self, inputs, training=False): return outputs -class JasperBlock(tf.keras.layers.Layer): +class JasperBlock(keras.layers.Layer): def __init__( self, nsubblocks: int = 3, diff --git a/tensorflow_asr/models/encoders/rnnt.py b/tensorflow_asr/models/encoders/rnnt.py index bdea97ba7e..7dcd24880b 100644 --- a/tensorflow_asr/models/encoders/rnnt.py +++ b/tensorflow_asr/models/encoders/rnnt.py @@ -14,6 +14,7 @@ """ http://arxiv.org/abs/1811.06621 """ import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer, Reshape from tensorflow_asr.models.layers.subsampling import TimeReduction @@ -23,6 +24,7 @@ class RnnTransducerBlock(Layer): def __init__( self, + reduction_position: str = "pre", reduction_factor: int = 0, dmodel: int = 640, rnn_type: str = "lstm", @@ -34,6 +36,8 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + assert reduction_position in ["post", "pre"], "reduction_position must be 'post' or 'pre'" + self._reduction_position = reduction_position self.reduction = TimeReduction(reduction_factor, name="reduction", dtype=self.dtype) if reduction_factor > 0 else None self.rnn = layer_util.get_rnn(rnn_type)( units=rnn_units, @@ -47,11 +51,11 @@ def __init__( dtype=self.dtype, ) self.ln = ( - tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype) + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype) if layer_norm else None ) - self.projection = tf.keras.layers.Dense( + self.projection = keras.layers.Dense( dmodel, name="projection", kernel_regularizer=kernel_regularizer, @@ -61,12 +65,16 @@ def __init__( def call(self, inputs, training=False): outputs, outputs_length = inputs - if self.reduction is not None: - outputs, outputs_length = self.reduction((outputs, outputs_length)) + if self._reduction_position == "pre": + if self.reduction is not None: + outputs, outputs_length = self.reduction((outputs, outputs_length)) outputs, *_ = self.rnn(outputs, training=training) if self.ln is not None: outputs = self.ln(outputs, training=training) outputs = self.projection(outputs, training=training) + if self._reduction_position == "post": + if self.reduction is not None: + outputs, outputs_length = self.reduction((outputs, outputs_length)) return outputs, outputs_length def compute_mask(self, inputs, mask=None): @@ -89,6 +97,9 @@ def call_next(self, inputs, inputs_length, previous_encoder_states): """ with tf.name_scope(f"{self.name}_call_next"): outputs, outputs_length = inputs, inputs_length + if self._reduction_position == "pre": + if self.reduction is not None: + outputs, outputs_length = self.reduction([outputs, outputs_length]) outputs, *_states = self.rnn( outputs, training=False, @@ -98,9 +109,10 @@ def call_next(self, inputs, inputs_length, previous_encoder_states): new_states = tf.stack(_states, axis=0) if self.ln is not None: outputs = self.ln(outputs, training=False) - if self.reduction is not None: - outputs, outputs_length = self.reduction([outputs, outputs_length]) outputs = self.projection(outputs, training=False) + if self._reduction_position == "post": + if self.reduction is not None: + outputs, outputs_length = self.reduction([outputs, outputs_length]) return outputs, outputs_length, new_states def compute_output_shape(self, input_shape): @@ -114,6 +126,7 @@ def compute_output_shape(self, input_shape): class RnnTransducerEncoder(Layer): def __init__( self, + reduction_positions: list = ["pre", "pre", "pre", "pre", "pre", "pre", "pre", "pre"], reduction_factors: list = [6, 0, 0, 0, 0, 0, 0, 0], dmodel: int = 640, nlayers: int = 8, @@ -126,6 +139,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + assert len(reduction_positions) == nlayers, "reduction_positions length must be equal to nlayers" assert len(reduction_factors) == nlayers, "reduction_factors length must be equal to nlayers" self.reshape = Reshape(name="reshape", dtype=self.dtype) @@ -133,6 +147,7 @@ def __init__( self.blocks = [] for i in range(nlayers): block = RnnTransducerBlock( + reduction_position=reduction_positions[i], reduction_factor=reduction_factors[i], dmodel=dmodel, rnn_type=rnn_type, @@ -151,12 +166,13 @@ def get_initial_state(self, batch_size=1): """Get zeros states Returns: - tf.Tensor: states having shape [num_rnns, 1 or 2, 1, P] + tf.Tensor, shape [B, num_rnns, nstates, state_size] + Zero initialized states """ states = [] for block in self.blocks: states.append(tf.stack(block.rnn.get_initial_state(tf.zeros([batch_size, 1, 1], dtype=self.dtype)), axis=0)) - return tf.stack(states, axis=0) + return tf.transpose(tf.stack(states, axis=0), perm=[2, 0, 1, 3]) def call(self, inputs, training=False): outputs, outputs_length, caching = inputs diff --git a/tensorflow_asr/models/encoders/transformer.py b/tensorflow_asr/models/encoders/transformer.py index 25d961ee2a..e4e2b0c38e 100644 --- a/tensorflow_asr/models/encoders/transformer.py +++ b/tensorflow_asr/models/encoders/transformer.py @@ -14,6 +14,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.layers.multihead_attention import MultiHeadAttention, MultiHeadRelativeAttention @@ -33,7 +34,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.ffn1 = tf.keras.layers.Dense( + self.ffn1 = keras.layers.Dense( units=dff, activation=activation, kernel_regularizer=kernel_regularizer, @@ -41,7 +42,7 @@ def __init__( name="ffn_1", dtype=self.dtype, ) - self.ffn2 = tf.keras.layers.Dense( + self.ffn2 = keras.layers.Dense( units=dmodel, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, @@ -82,7 +83,7 @@ def __init__( self.norm1 = ( None if self._norm_position == "none" - else tf.keras.layers.LayerNormalization( + else keras.layers.LayerNormalization( beta_regularizer=kernel_regularizer, gamma_regularizer=bias_regularizer, name="ln_1", dtype=self.dtype ) ) @@ -111,12 +112,12 @@ def __init__( dtype=self.dtype, ) ) - self.do1 = tf.keras.layers.Dropout(dropout, name="do_1", dtype=self.dtype) + self.do1 = keras.layers.Dropout(dropout, name="do_1", dtype=self.dtype) self.residual1 = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual_1", dtype=self.dtype) self.norm2 = ( None if self._norm_position == "none" - else tf.keras.layers.LayerNormalization( + else keras.layers.LayerNormalization( beta_regularizer=kernel_regularizer, gamma_regularizer=bias_regularizer, name="ln_2", dtype=self.dtype ) ) @@ -129,7 +130,7 @@ def __init__( name="pwffn", dtype=self.dtype, ) - self.do2 = tf.keras.layers.Dropout(dropout, name="do_2", dtype=self.dtype) + self.do2 = keras.layers.Dropout(dropout, name="do_2", dtype=self.dtype) self.residual2 = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual_2", dtype=self.dtype) def call( @@ -213,14 +214,14 @@ def __init__( dtype=self.dtype, ) self.time_reduction_factor = self.subsampling.time_reduction_factor - self.linear = tf.keras.layers.Dense( + self.linear = keras.layers.Dense( units=dmodel, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="linear", dtype=self.dtype, ) - self.do = tf.keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) + self.do = keras.layers.Dropout(dropout, name="dropout", dtype=self.dtype) if mha_type == "relmha": self.relpe = RelativeSinusoidalPositionalEncoding( diff --git a/tensorflow_asr/models/layers/embedding.py b/tensorflow_asr/models/layers/embedding.py index cba7362aa2..2bc27c7e0f 100644 --- a/tensorflow_asr/models/layers/embedding.py +++ b/tensorflow_asr/models/layers/embedding.py @@ -13,11 +13,12 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer -class Embedding(tf.keras.layers.Embedding): +class Embedding(keras.layers.Embedding): def __init__( self, vocab_size, diff --git a/tensorflow_asr/models/layers/feature_extraction.py b/tensorflow_asr/models/layers/feature_extraction.py index f651bfe78e..74b4173acf 100644 --- a/tensorflow_asr/models/layers/feature_extraction.py +++ b/tensorflow_asr/models/layers/feature_extraction.py @@ -131,7 +131,8 @@ def __init__( self.padding = padding self.nfft = self.frame_length if nfft is None else nfft - self.augmentations = Augmentation(augmentation_config) + self._augmentation_config = augmentation_config + self.augmentations = Augmentation(self._augmentation_config) # ---------------------------------- signals --------------------------------- # @@ -293,3 +294,29 @@ def compute_output_shape(self, input_shape): else: output_shape = [B, self.get_nframes(nsamples + self.padding), self.num_feature_bins, 1] return tf.TensorShape(output_shape), tf.TensorShape(signal_length_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "sample_rate": self.sample_rate, + "frame_ms": self.frame_ms, + "stride_ms": self.stride_ms, + "num_feature_bins": self.num_feature_bins, + "feature_type": self.feature_type, + "preemphasis": self.preemphasis, + "pad_end": self.pad_end, + "use_librosa_like_stft": self.use_librosa_like_stft, + "output_floor": self.output_floor, + "lower_edge_hertz": self.lower_edge_hertz, + "upper_edge_hertz": self.upper_edge_hertz, + "log_base": self.log_base, + "nfft": self.nfft, + "normalize_signal": self._normalize_signal, + "normalize_zscore": self._normalize_zscore, + "normalize_min_max": self._normalize_min_max, + "padding": self.padding, + "augmentation_config": self._augmentations.config, + } + ) + return config diff --git a/tensorflow_asr/models/layers/memory.py b/tensorflow_asr/models/layers/memory.py index f473213b73..dc624fbaac 100644 --- a/tensorflow_asr/models/layers/memory.py +++ b/tensorflow_asr/models/layers/memory.py @@ -121,7 +121,7 @@ def call(self, inputs, memories=None): # memory = tf.zeros(shape=(self.batch_size, self.memory_length, self.dmodel), dtype=self.dtype) # if memory_mask is None: # memory_mask = tf.zeros(shape=(self.batch_size, self.memory_length), dtype=tf.bool) - # self.add_update([tf.keras.backend.update(self.memory, memory), tf.keras.backend.update(self.memory_mask, memory_mask)]) + # self.add_update([keras.backend.update(self.memory, memory), keras.backend.update(self.memory_mask, memory_mask)]) # def call(self, inputs): # inputs, inputs_mask = self._get_inputs(inputs) @@ -147,7 +147,7 @@ def call(self, inputs, memories=None): # begin=[0, tf.shape(new_memory_mask)[1] - self.memory_length], # size=[-1, self.memory_length], # ) - # self.add_update([tf.keras.backend.update(self.memory, new_memory), tf.keras.backend.update(self.memory_mask, new_memory_mask)]) + # self.add_update([keras.backend.update(self.memory, new_memory), keras.backend.update(self.memory_mask, new_memory_mask)]) # new_memory._keras_mask = new_memory_mask # pylint: disable=protected-access # return new_memory diff --git a/tensorflow_asr/models/layers/multihead_attention.py b/tensorflow_asr/models/layers/multihead_attention.py index 70e3d54388..7720a7ae65 100644 --- a/tensorflow_asr/models/layers/multihead_attention.py +++ b/tensorflow_asr/models/layers/multihead_attention.py @@ -17,6 +17,7 @@ import math import tensorflow as tf +import keras from keras.layers import EinsumDense from keras.layers import MultiHeadAttention as KerasMultiHeadAttention @@ -244,8 +245,8 @@ def _build_attention(self, rank): attn_scores_rank, ) = mha_module._build_attention_equation(rank, attn_axes=self._attention_axes) norm_axes = tuple(range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) - self._softmax = tf.keras.layers.Softmax(axis=norm_axes, dtype=self.dtype) # stable training - self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout, dtype=self.dtype) + self._softmax = keras.layers.Softmax(axis=norm_axes, dtype=self.dtype) # stable training + self._dropout_layer = keras.layers.Dropout(rate=self._dropout, dtype=self.dtype) def _masked_softmax(self, attention_scores, attention_mask=None): # Normalize the attention scores to probabilities. diff --git a/tensorflow_asr/models/layers/positional_encoding.py b/tensorflow_asr/models/layers/positional_encoding.py index 53f0b05543..db97c19481 100755 --- a/tensorflow_asr/models/layers/positional_encoding.py +++ b/tensorflow_asr/models/layers/positional_encoding.py @@ -14,6 +14,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.utils import shape_util @@ -61,7 +62,7 @@ def __init__( **kwargs, ): super().__init__(trainable=False, **kwargs) - self.do = tf.keras.layers.Dropout(dropout, dtype=self.dtype, name="dropout") + self.do = keras.layers.Dropout(dropout, dtype=self.dtype, name="dropout") self._scale = scale self._interleave = interleave diff --git a/tensorflow_asr/models/layers/residual.py b/tensorflow_asr/models/layers/residual.py index 63da2e7eda..5de0b4fc18 100644 --- a/tensorflow_asr/models/layers/residual.py +++ b/tensorflow_asr/models/layers/residual.py @@ -15,6 +15,7 @@ from typing import Optional import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer @@ -30,8 +31,8 @@ class Residual(Layer): def __init__( self, factor="rezero", - initializer: tf.keras.initializers.Initializer = "zeros", - regularizer: Optional[tf.keras.regularizers.Regularizer] = None, + initializer: keras.initializers.Initializer = "zeros", + regularizer: Optional[keras.regularizers.Regularizer] = None, name="residual", **kwargs, ): diff --git a/tensorflow_asr/models/layers/sequence_wise_bn.py b/tensorflow_asr/models/layers/sequence_wise_bn.py index 96c6469324..e03f8e382a 100644 --- a/tensorflow_asr/models/layers/sequence_wise_bn.py +++ b/tensorflow_asr/models/layers/sequence_wise_bn.py @@ -12,16 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tensorflow as tf +import keras # https://arxiv.org/abs/1510.01378 -class SequenceBatchNorm(tf.keras.layers.Layer): +class SequenceBatchNorm(keras.layers.Layer): def __init__(self, name, time_major=False, gamma_regularizer=None, beta_regularizer=None, **kwargs): super().__init__(name=name, **kwargs) self.time_major = time_major - self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer) - self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer) + self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) + self.beta_regularizer = keras.regularizers.get(beta_regularizer) def build( self, @@ -53,12 +55,12 @@ def call( ): mean, variance = tf.nn.moments(inputs, axes=[0, 1], keepdims=False) if self.time_major: - total_padded_frames = tf.cast(tf.shape(inputs)[0], tf.keras.backend.dtype(mean)) - batch_size = tf.cast(tf.shape(inputs)[1], tf.keras.backend.dtype(mean)) + total_padded_frames = tf.cast(tf.shape(inputs)[0], keras.backend.dtype(mean)) + batch_size = tf.cast(tf.shape(inputs)[1], keras.backend.dtype(mean)) else: - total_padded_frames = tf.cast(tf.shape(inputs)[1], tf.keras.backend.dtype(mean)) - batch_size = tf.cast(tf.shape(inputs)[0], tf.keras.backend.dtype(mean)) - total_unpadded_frames_batch = tf.math.count_nonzero(inputs, axis=[0, 1], keepdims=False, dtype=tf.keras.backend.dtype(mean)) + total_padded_frames = tf.cast(tf.shape(inputs)[1], keras.backend.dtype(mean)) + batch_size = tf.cast(tf.shape(inputs)[0], keras.backend.dtype(mean)) + total_unpadded_frames_batch = tf.math.count_nonzero(inputs, axis=[0, 1], keepdims=False, dtype=keras.backend.dtype(mean)) mean = (mean * total_padded_frames * batch_size) / total_unpadded_frames_batch variance = (variance * total_padded_frames * batch_size) / total_unpadded_frames_batch return tf.nn.batch_normalization( @@ -67,5 +69,5 @@ def call( variance=variance, offset=self.beta, scale=self.gamma, - variance_epsilon=tf.keras.backend.epsilon(), + variance_epsilon=keras.backend.epsilon(), ) diff --git a/tensorflow_asr/models/layers/subsampling.py b/tensorflow_asr/models/layers/subsampling.py index 80d976d039..3328e4f0d9 100644 --- a/tensorflow_asr/models/layers/subsampling.py +++ b/tensorflow_asr/models/layers/subsampling.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing import tensorflow as tf +import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.layers.convolution import Conv1D, Conv2D @@ -53,10 +55,10 @@ def compute_output_shape(self, input_shape): class VggSubsampling(Layer): def __init__( self, - filters: tuple or list = (32, 64), - kernel_size: int or list or tuple = 3, - pool_size: int or list or tuple = 2, - strides: int or list or tuple = 2, + filters: typing.Union[tuple, list] = (32, 64), + kernel_size: typing.Union[int, list, tuple] = 3, + pool_size: typing.Union[int, list, tuple] = 2, + strides: typing.Union[int, list, tuple] = 2, padding: str = "same", activation: str = "relu", kernel_regularizer=None, @@ -87,7 +89,7 @@ def __init__( activation=activation, dtype=self.dtype, ) - self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding=padding, dtype=self.dtype, name="maxpool_1") + self.maxpool1 = keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding=padding, dtype=self.dtype, name="maxpool_1") self.conv3 = Conv2D( filters=filters[1], kernel_size=kernel_size, @@ -110,7 +112,7 @@ def __init__( activation=activation, dtype=self.dtype, ) - self.maxpool2 = tf.keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding=padding, dtype=self.dtype, name="maxpool_2") + self.maxpool2 = keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding=padding, dtype=self.dtype, name="maxpool_2") self.time_reduction_factor = self.maxpool1.pool_size[0] * self.maxpool2.pool_size[0] def call(self, inputs, training=False): @@ -169,7 +171,7 @@ def __init__( self.convs = [] self.time_reduction_factor = 1 for i in range(len(filters)): - subblock = tf.keras.Sequential(name=f"block_{i}") + subblock = keras.Sequential(name=f"block_{i}") subblock.add( Conv2D( filters=filters[i], @@ -184,7 +186,7 @@ def __init__( ) if norms[i] == "batch": subblock.add( - tf.keras.layers.BatchNormalization( + keras.layers.BatchNormalization( name=f"bn_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, @@ -193,14 +195,14 @@ def __init__( ) elif norms[i] == "layer": subblock.add( - tf.keras.layers.LayerNormalization( + keras.layers.LayerNormalization( name=f"ln_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype, ) ) - subblock.add(tf.keras.layers.Activation(activations[i], name=f"{activations[i]}_{i}", dtype=self.dtype)) + subblock.add(keras.layers.Activation(activations[i], name=f"{activations[i]}_{i}", dtype=self.dtype)) self.convs.append(subblock) self.time_reduction_factor *= subblock.layers[0].strides[0] @@ -257,7 +259,7 @@ def __init__( self.convs = [] self.time_reduction_factor = 1 for i in range(len(filters)): - subblock = tf.keras.Sequential(name=f"block_{i}") + subblock = keras.Sequential(name=f"block_{i}") subblock.add( Conv1D( filters=filters[i], @@ -272,7 +274,7 @@ def __init__( ) if norms[i] == "batch": subblock.add( - tf.keras.layers.BatchNormalization( + keras.layers.BatchNormalization( name=f"bn_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, @@ -281,14 +283,14 @@ def __init__( ) elif norms[i] == "layer": subblock.add( - tf.keras.layers.LayerNormalization( + keras.layers.LayerNormalization( name=f"ln_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype, ) ) - subblock.add(tf.keras.layers.Activation(activations[i], name=f"{activations[i]}_{i}", dtype=self.dtype)) + subblock.add(keras.layers.Activation(activations[i], name=f"{activations[i]}_{i}", dtype=self.dtype)) self.convs.append(subblock) self.time_reduction_factor *= subblock.layers[0].strides[0] diff --git a/tensorflow_asr/models/transducer/base_transducer.py b/tensorflow_asr/models/transducer/base_transducer.py index b1eb075e9e..11084ee762 100644 --- a/tensorflow_asr/models/transducer/base_transducer.py +++ b/tensorflow_asr/models/transducer/base_transducer.py @@ -17,6 +17,7 @@ import collections import tensorflow as tf +import keras from tensorflow_asr import schemas from tensorflow_asr.losses.rnnt_loss import RnntLoss @@ -53,6 +54,21 @@ def __init__( ): super().__init__(name=name, **kwargs) assert label_encoder_mode in ("one_hot", "embedding"), "label_encode_mode must be either 'one_hot' or 'embedding'" + self._config = { + "blank": blank, + "vocab_size": vocab_size, + "label_encoder_mode": label_encoder_mode, + "embed_dim": embed_dim, + "num_rnns": num_rnns, + "rnn_units": rnn_units, + "rnn_type": rnn_type, + "rnn_implementation": rnn_implementation, + "rnn_unroll": rnn_unroll, + "layer_norm": layer_norm, + "projection_units": projection_units, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } self.label_encoder = ( Embedding(vocab_size, embed_dim, regularizer=kernel_regularizer, name=label_encoder_mode, dtype=self.dtype) if label_encoder_mode == "embedding" @@ -77,14 +93,14 @@ def __init__( dtype=self.dtype, ) ln = ( - tf.keras.layers.LayerNormalization( + keras.layers.LayerNormalization( name=f"ln_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, dtype=self.dtype ) if layer_norm else None ) projection = ( - tf.keras.layers.Dense( + keras.layers.Dense( projection_units, name=f"projection_{i}", kernel_regularizer=kernel_regularizer, @@ -164,6 +180,11 @@ def compute_output_shape(self, input_shape): ) return tuple(output_shape), tuple(output_length_shape) + def get_config(self): + config = super().get_config() + config.update(self._config) + return config + class TransducerJointMerge(Layer): def __init__(self, joint_mode: str = "add", name="transducer_joint_merge", **kwargs): @@ -201,6 +222,11 @@ def compute_output_shape(self, input_shape): enc_shape, pred_shape = input_shape return enc_shape[0], enc_shape[1], pred_shape[1], enc_shape[-1] + def get_config(self): + config = super().get_config() + config.update({"joint_mode": self.joint_mode}) + return config + class TransducerJoint(Layer): def __init__( @@ -219,12 +245,24 @@ def __init__( ): super().__init__(name=name, **kwargs) + self._config = { + "vocab_size": vocab_size, + "joint_dim": joint_dim, + "activation": activation, + "prejoint_encoder_linear": prejoint_encoder_linear, + "prejoint_prediction_linear": prejoint_prediction_linear, + "postjoint_linear": postjoint_linear, + "joint_mode": joint_mode, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } + self.prejoint_encoder_linear = prejoint_encoder_linear self.prejoint_prediction_linear = prejoint_prediction_linear self.postjoint_linear = postjoint_linear if self.prejoint_encoder_linear: - self.ffn_enc = tf.keras.layers.Dense( + self.ffn_enc = keras.layers.Dense( joint_dim, name="enc", kernel_regularizer=kernel_regularizer, @@ -232,7 +270,7 @@ def __init__( dtype=self.dtype, ) if self.prejoint_prediction_linear: - self.ffn_pred = tf.keras.layers.Dense( + self.ffn_pred = keras.layers.Dense( joint_dim, use_bias=False, name="pred", @@ -243,10 +281,10 @@ def __init__( self.joint = TransducerJointMerge(joint_mode=joint_mode, name="merge", dtype=self.dtype) activation = activation.lower() - self.activation = tf.keras.layers.Activation(activation, name=activation, dtype=self.dtype) + self.activation = keras.layers.Activation(activation, name=activation, dtype=self.dtype) if self.postjoint_linear: - self.ffn = tf.keras.layers.Dense( + self.ffn = keras.layers.Dense( joint_dim, name="ffn", kernel_regularizer=kernel_regularizer, @@ -254,7 +292,7 @@ def __init__( dtype=self.dtype, ) - self.ffn_out = tf.keras.layers.Dense( + self.ffn_out = keras.layers.Dense( vocab_size, name="vocab", kernel_regularizer=kernel_regularizer, @@ -286,6 +324,11 @@ def compute_output_shape(self, input_shape): encoder_time_shape, prediction_time_shape = encoder_shape[1], prediction_shape[1] return batch_shape, encoder_time_shape, prediction_time_shape, self.ffn_out.units + def get_config(self): + config = super().get_config() + config.update(self._config) + return config + class Transducer(BaseModel): """Transducer Model Warper""" @@ -295,7 +338,7 @@ def __init__( blank: int, vocab_size: int, speech_config: dict, - encoder: tf.keras.layers.Layer, + encoder: keras.layers.Layer, prediction_label_encoder_mode: str = "embedding", prediction_embed_dim: int = 512, prediction_num_rnns: int = 1, @@ -444,9 +487,6 @@ def call_next( ytu = tf.nn.log_softmax(ytu) return ytu, new_states - def get_initial_tokens(self, batch_size=1): - return super().get_initial_tokens(batch_size) - def get_initial_encoder_states(self, batch_size=1): return tf.zeros([], dtype=self.dtype) diff --git a/tensorflow_asr/models/transducer/conformer.py b/tensorflow_asr/models/transducer/conformer.py index 1b90e02099..3c87a50137 100644 --- a/tensorflow_asr/models/transducer/conformer.py +++ b/tensorflow_asr/models/transducer/conformer.py @@ -14,12 +14,13 @@ import tensorflow as tf +import keras from tensorflow_asr.models.encoders.conformer import L2, ConformerEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") class Conformer(Transducer): def __init__( self, @@ -71,6 +72,53 @@ def __init__( name: str = "conformer", **kwargs, ): + self._config = { + "blank": blank, + "vocab_size": vocab_size, + "speech_config": speech_config, + "encoder_subsampling": encoder_subsampling, + "encoder_dmodel": encoder_dmodel, + "encoder_num_blocks": encoder_num_blocks, + "encoder_head_size": encoder_head_size, + "encoder_num_heads": encoder_num_heads, + "encoder_mha_type": encoder_mha_type, + "encoder_interleave_relpe": encoder_interleave_relpe, + "encoder_use_attention_causal_mask": encoder_use_attention_causal_mask, + "encoder_use_attention_auto_mask": encoder_use_attention_auto_mask, + "encoder_kernel_size": encoder_kernel_size, + "encoder_padding": encoder_padding, + "encoder_ffm_scale_factor": encoder_ffm_scale_factor, + "encoder_ffm_residual_factor": encoder_ffm_residual_factor, + "encoder_mhsam_residual_factor": encoder_mhsam_residual_factor, + "encoder_mhsam_use_attention_bias": encoder_mhsam_use_attention_bias, + "encoder_convm_scale_factor": encoder_convm_scale_factor, + "encoder_convm_residual_factor": encoder_convm_residual_factor, + "encoder_convm_use_group_conv": encoder_convm_use_group_conv, + "encoder_dropout": encoder_dropout, + "encoder_module_norm_position": encoder_module_norm_position, + "encoder_block_norm_position": encoder_block_norm_position, + "encoder_memory_length": encoder_memory_length, + "encoder_trainable": encoder_trainable, + "prediction_label_encode_mode": prediction_label_encode_mode, + "prediction_embed_dim": prediction_embed_dim, + "prediction_num_rnns": prediction_num_rnns, + "prediction_rnn_units": prediction_rnn_units, + "prediction_rnn_type": prediction_rnn_type, + "prediction_rnn_implementation": prediction_rnn_implementation, + "prediction_rnn_unroll": prediction_rnn_unroll, + "prediction_layer_norm": prediction_layer_norm, + "prediction_projection_units": prediction_projection_units, + "prediction_trainable": prediction_trainable, + "joint_dim": joint_dim, + "joint_activation": joint_activation, + "prejoint_encoder_linear": prejoint_encoder_linear, + "prejoint_prediction_linear": prejoint_prediction_linear, + "postjoint_linear": postjoint_linear, + "joint_mode": joint_mode, + "joint_trainable": joint_trainable, + "kernel_regularizer": kernel_regularizer, + "bias_regularizer": bias_regularizer, + } super().__init__( speech_config=speech_config, encoder=ConformerEncoder( @@ -136,8 +184,13 @@ def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, **k None if self.encoder._memory_length is None else [ - tf.keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) + keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) for _ in range(self.encoder._num_blocks) ] ) return super().make(input_shape, prediction_shape, batch_size, caching, **kwargs) + + def get_config(self): + config = super().get_config() + config.update(self._config) + return config diff --git a/tensorflow_asr/models/transducer/contextnet.py b/tensorflow_asr/models/transducer/contextnet.py index 1dc1dad7ab..57cc9a923c 100644 --- a/tensorflow_asr/models/transducer/contextnet.py +++ b/tensorflow_asr/models/transducer/contextnet.py @@ -14,13 +14,13 @@ from typing import List -import tensorflow as tf +import keras from tensorflow_asr.models.encoders.contextnet import L2, ContextNetEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") class ContextNet(Transducer): def __init__( self, diff --git a/tensorflow_asr/models/transducer/rnnt.py b/tensorflow_asr/models/transducer/rnnt.py index 248d83a537..9b147a717d 100644 --- a/tensorflow_asr/models/transducer/rnnt.py +++ b/tensorflow_asr/models/transducer/rnnt.py @@ -13,19 +13,20 @@ # limitations under the License. """ http://arxiv.org/abs/1811.06621 """ -import tensorflow as tf +import keras from tensorflow_asr.models.encoders.rnnt import RnnTransducerEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") class RnnTransducer(Transducer): def __init__( self, blank: int, vocab_size: int, speech_config: dict, + encoder_reduction_positions: list = ["pre", "pre", "pre", "pre", "pre", "pre", "pre", "pre"], encoder_reduction_factors: list = [6, 0, 0, 0, 0, 0, 0, 0], encoder_dmodel: int = 640, encoder_nlayers: int = 8, @@ -59,6 +60,7 @@ def __init__( super().__init__( speech_config=speech_config, encoder=RnnTransducerEncoder( + reduction_positions=encoder_reduction_positions, reduction_factors=encoder_reduction_factors, dmodel=encoder_dmodel, nlayers=encoder_nlayers, diff --git a/tensorflow_asr/models/transducer/transformer.py b/tensorflow_asr/models/transducer/transformer.py index 0f0bda8c65..dfdfb9ebd6 100644 --- a/tensorflow_asr/models/transducer/transformer.py +++ b/tensorflow_asr/models/transducer/transformer.py @@ -13,12 +13,13 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.encoders.transformer import TransformerEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") class Transformer(Transducer): def __init__( self, @@ -121,7 +122,7 @@ def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, **k None if self.encoder._memory_length is None else [ - tf.keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) + keras.Input(shape=[self.encoder._memory_length, self.encoder._dmodel], batch_size=batch_size, dtype=tf.float32) for _ in range(self.encoder._num_blocks) ] ) diff --git a/tensorflow_asr/optimizers/accumulation.py b/tensorflow_asr/optimizers/accumulation.py index b7287cf56c..ed94e2236e 100644 --- a/tensorflow_asr/optimizers/accumulation.py +++ b/tensorflow_asr/optimizers/accumulation.py @@ -4,6 +4,7 @@ """ import tensorflow as tf +import keras class GradientAccumulator: @@ -14,7 +15,7 @@ class GradientAccumulator: def __init__( self, ga_steps, - model: tf.keras.Model, + model: keras.Model, name="ga", ): self.name = name diff --git a/tensorflow_asr/optimizers/regularizers.py b/tensorflow_asr/optimizers/regularizers.py index efab9d00f6..2397dcb879 100644 --- a/tensorflow_asr/optimizers/regularizers.py +++ b/tensorflow_asr/optimizers/regularizers.py @@ -1,10 +1,11 @@ from typing import List import tensorflow as tf +import keras -@tf.keras.utils.register_keras_serializable("tensorflow_asr.optimizers.regularizers") -class TimeDependentGaussianGradientNoise(tf.keras.regularizers.Regularizer): +@keras.utils.register_keras_serializable("tensorflow_asr.optimizers.regularizers") +class TimeDependentGaussianGradientNoise(keras.regularizers.Regularizer): """ Reference: https://openreview.net/pdf/ZY9xxQDMMu5Pk8ELfEz4.pdf """ diff --git a/tensorflow_asr/optimizers/schedules.py b/tensorflow_asr/optimizers/schedules.py index cb90d1d40d..a33ef66bec 100755 --- a/tensorflow_asr/optimizers/schedules.py +++ b/tensorflow_asr/optimizers/schedules.py @@ -13,10 +13,11 @@ # limitations under the License. import tensorflow as tf +import keras -@tf.keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") -class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): +@keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") +class TransformerSchedule(keras.optimizers.schedules.LearningRateSchedule): def __init__(self, dmodel, scale=1.0, warmup_steps=4000, max_lr=None, min_lr=None): super().__init__() self.dmodel = tf.convert_to_tensor(dmodel, dtype=tf.float32) @@ -46,8 +47,8 @@ def get_config(self): } -@tf.keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") -class CyclicTransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): +@keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") +class CyclicTransformerSchedule(keras.optimizers.schedules.LearningRateSchedule): """This callback implements a cyclical learning rate policy (CLR) to the square root decay generally used to train transformers. The method cycles the learning rate around the square root decay LR with an amplitude diff --git a/tensorflow_asr/utils/env_util.py b/tensorflow_asr/utils/env_util.py index 0c235e3526..3fcb8cedf6 100644 --- a/tensorflow_asr/utils/env_util.py +++ b/tensorflow_asr/utils/env_util.py @@ -18,6 +18,7 @@ import numpy as np import tensorflow as tf +import keras from packaging import version logger = tf.get_logger() @@ -126,11 +127,11 @@ def setup_mxp( raise ValueError(f"mxp must be in {options}") if mxp == "strict": policy = "mixed_bfloat16" if has_devices("TPU") else "mixed_float16" - tf.keras.mixed_precision.set_global_policy(policy) + keras.mixed_precision.set_global_policy(policy) logger.info(f"USING mixed precision policy {policy}") elif mxp == "strict_auto": policy = "mixed_bfloat16" if has_devices("TPU") else "mixed_float16" - tf.keras.mixed_precision.set_global_policy(policy) + keras.mixed_precision.set_global_policy(policy) tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) logger.info(f"USING auto mixed precision policy {policy}") elif mxp == "auto": @@ -156,5 +157,5 @@ def setup_seed( random.seed(seed) np.random.seed(seed) tf.random.set_seed(seed) - tf.keras.backend.experimental.enable_tf_random_generator() - tf.keras.utils.set_random_seed(seed) + keras.backend.experimental.enable_tf_random_generator() + keras.utils.set_random_seed(seed) diff --git a/tensorflow_asr/utils/layer_util.py b/tensorflow_asr/utils/layer_util.py index 8e050c02a4..33b9bdd5a0 100644 --- a/tensorflow_asr/utils/layer_util.py +++ b/tensorflow_asr/utils/layer_util.py @@ -13,6 +13,7 @@ # limitations under the License. import tensorflow as tf +import keras from tensorflow_asr.models.layers import convolution @@ -22,10 +23,10 @@ def get_rnn( ): assert rnn_type in ["lstm", "gru", "rnn"] if rnn_type == "lstm": - return tf.keras.layers.LSTM + return keras.layers.LSTM if rnn_type == "gru": - return tf.keras.layers.GRU - return tf.keras.layers.SimpleRNN + return keras.layers.GRU + return keras.layers.SimpleRNN def get_conv(