RNNT loss in pure TF #95

Merged
merged 20 commits on Jan 4, 2021
10 changes: 5 additions & 5 deletions tensorflow_asr/datasets/asr_dataset.py
@@ -157,14 +157,14 @@ def __init__(self,
data_paths, augmentations, cache, shuffle
)
self.tfrecords_dir = tfrecords_dir
-        if not os.path.exists(self.tfrecords_dir):
-            os.makedirs(self.tfrecords_dir)
+        if not tf.io.gfile.exists(self.tfrecords_dir):
+            tf.io.gfile.makedirs(self.tfrecords_dir)

def create_tfrecords(self):
-        if not os.path.exists(self.tfrecords_dir):
-            os.makedirs(self.tfrecords_dir)
+        if not tf.io.gfile.exists(self.tfrecords_dir):
+            tf.io.gfile.makedirs(self.tfrecords_dir)

-        if glob.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")):
+        if tf.io.gfile.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")):
            print(f"TFRecords already exist: {self.stage}")
return True

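For context, a minimal sketch (not from this PR) of why tf.io.gfile replaces os/glob here: the same calls work for local directories and for remote filesystems such as GCS, which the gs:// handling added to preprocess_paths below also relies on. The bucket name is hypothetical:

import tensorflow as tf

tfrecords_dir = "gs://my-bucket/tfrecords"  # hypothetical bucket; a local path behaves the same

# Create the directory if needed, then look for existing shards.
if not tf.io.gfile.exists(tfrecords_dir):
    tf.io.gfile.makedirs(tfrecords_dir)
print(tf.io.gfile.glob(tfrecords_dir + "/train*.tfrecord"))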
222 changes: 220 additions & 2 deletions tensorflow_asr/losses/rnnt_losses.py
@@ -11,13 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# RNNT loss implementation in pure TensorFlow is borrowed from [iamjanvijay's repo](https://github.com/iamjanvijay/rnnt)

import tensorflow as tf
-from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss
+try:
+    from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss
+    use_warprnnt = True
+except ImportError:
+    print("Cannot import RNNT loss from warprnnt. Falling back to RNNT loss in TensorFlow")
+    from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
+    use_warprnnt = False


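# Dispatcher: use the warp-transducer binding when it imported successfully,
# otherwise fall back to the pure-TensorFlow implementation defined below.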
def rnnt_loss(logits, labels, label_length, logit_length, blank=0, name=None):
if use_warprnnt:
return rnnt_loss_warprnnt(logits=logits, labels=labels, label_length=label_length, logit_length=logit_length, blank=blank)
else:
return rnnt_loss_tf(logits=logits, labels=labels, label_length=label_length, logit_length=logit_length, name=name)

@tf.function
-def rnnt_loss(logits, labels, label_length, logit_length, blank=0):
+def rnnt_loss_warprnnt(logits, labels, label_length, logit_length, blank=0):
if not tf.config.list_physical_devices('GPU'):
logits = tf.nn.log_softmax(logits)
loss = warp_rnnt_loss(
@@ -28,3 +41,208 @@ def rnnt_loss(logits, labels, label_length, logit_length, blank=0):
blank_label=blank
)
return loss


LOG_0 = float("-inf")


def nan_to_zero(input_tensor):
return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor)


# logsumexp variant that avoids NaNs when an entire slice equals LOG_0: subtracting a
# -inf maximum yields NaN, which nan_to_zero maps back to zero so the result stays LOG_0.
def reduce_logsumexp(input_tensor, axis):
maximum = tf.reduce_max(input_tensor, axis=axis)
input_tensor = nan_to_zero(input_tensor - maximum)
return tf.math.log(tf.reduce_sum(tf.exp(input_tensor), axis=axis)) + maximum


# Re-arranges a batch of (time x target) transition log-prob grids so that each
# anti-diagonal of the RNN-T lattice comes out as one slice, padded with LOG_0,
# ready for the diagonal-wise forward/backward scans below.
def extract_diagonals(log_probs):
time_steps = tf.shape(log_probs)[1] # T
output_steps = tf.shape(log_probs)[2] # U + 1
reverse_log_probs = tf.reverse(log_probs, axis=[-1])
paddings = [[0, 0], [0, 0], [time_steps - 1, 0]]
padded_reverse_log_probs = tf.pad(reverse_log_probs, paddings,
'CONSTANT', constant_values=LOG_0)
diagonals = matrix_diag_part_v2(padded_reverse_log_probs, k=(0, time_steps + output_steps - 2),
padding_value=LOG_0)

return tf.transpose(diagonals, perm=[1, 0, 2])


def transition_probs(one_hot_labels, log_probs):
"""
:return: blank_probs with shape batch_size x input_max_len x target_max_len
truth_probs with shape batch_size x input_max_len x (target_max_len-1)
"""
blank_probs = log_probs[:, :, :, 0]
truth_probs = tf.reduce_sum(tf.multiply(log_probs[:, :, :-1, :], one_hot_labels), axis=-1)

return blank_probs, truth_probs


def forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len):
"""
:return: forward variable alpha with shape batch_size x input_max_len x target_max_len
"""

def next_state(x, trans_probs):
blank_probs = trans_probs[0]
truth_probs = trans_probs[1]

x_b = tf.concat([LOG_0 * tf.ones(shape=[batch_size, 1]), x[:, :-1] + blank_probs], axis=1)
x_t = x + truth_probs

x = tf.math.reduce_logsumexp(tf.stack([x_b, x_t], axis=0), axis=0)
return x

initial_alpha = tf.concat(
[tf.zeros(shape=[batch_size, 1]), tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0], axis=1)

fwd = tf.scan(next_state, (bp_diags[:-1, :, :-1], tp_diags), initializer=initial_alpha)

alpha = tf.transpose(
tf.concat([tf.expand_dims(initial_alpha, axis=0), fwd], axis=0), perm=[1, 2, 0])
alpha = matrix_diag_part_v2(alpha, k=(0, target_max_len - 1), padding_value=LOG_0)
alpha = tf.transpose(tf.reverse(alpha, axis=[1]), perm=[0, 2, 1])

return alpha


def backward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len, label_length, logit_length, blank_sl):
"""
:return: backward variable beta with shape batch_size x input_max_len x target_max_len
"""

def next_state(x, mask_and_trans_probs):
mask_s, blank_probs_s, truth_probs = mask_and_trans_probs

beta_b = tf.concat([x[:, 1:] + blank_probs_s, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1)
beta_t = tf.concat([x[:, :-1] + truth_probs, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1)

beta_next = reduce_logsumexp(tf.stack([beta_b, beta_t], axis=0), axis=0)
masked_beta_next = nan_to_zero(beta_next * tf.expand_dims(mask_s, axis=1)) + nan_to_zero(x * tf.expand_dims((1.0 - mask_s), axis=1))
return masked_beta_next

# Initial beta for batches.
initial_beta_mask = tf.one_hot(logit_length - 1, depth=input_max_len + 1)
initial_beta = tf.expand_dims(blank_sl, axis=1) * initial_beta_mask + nan_to_zero(LOG_0 * (1.0 - initial_beta_mask))

# Mask for scan iterations.
mask = tf.sequence_mask(logit_length + label_length - 1, input_max_len + target_max_len - 2,
dtype=tf.dtypes.float32)
mask = tf.transpose(mask, perm=[1, 0])

bwd = tf.scan(next_state, (mask, bp_diags[:-1, :, :], tp_diags), initializer=initial_beta, reverse=True)

beta = tf.transpose(tf.concat([bwd, tf.expand_dims(initial_beta, axis=0)], axis=0), perm=[1, 2, 0])[:, :-1, :]
beta = matrix_diag_part_v2(beta, k=(0, target_max_len - 1), padding_value=LOG_0)
beta = tf.transpose(tf.reverse(beta, axis=[1]), perm=[0, 2, 1])

return beta


# Computes the RNN-T negative log-likelihood and, analytically, its gradients w.r.t. the logits.
def compute_rnnt_loss_and_grad_helper(logits, labels, label_length, logit_length):
batch_size = tf.shape(logits)[0]
input_max_len = tf.shape(logits)[1]
target_max_len = tf.shape(logits)[2]
vocab_size = tf.shape(logits)[3]

one_hot_labels = tf.one_hot(tf.tile(tf.expand_dims(labels, axis=1),
multiples=[1, input_max_len, 1]), depth=vocab_size)

log_probs = tf.nn.log_softmax(logits)
blank_probs, truth_probs = transition_probs(one_hot_labels, log_probs)
bp_diags = extract_diagonals(blank_probs)
tp_diags = extract_diagonals(truth_probs)

label_mask = tf.expand_dims(tf.sequence_mask(
label_length + 1, maxlen=target_max_len, dtype=tf.float32), axis=1)
small_label_mask = tf.expand_dims(tf.sequence_mask(
label_length, maxlen=target_max_len, dtype=tf.float32), axis=1)
input_mask = tf.expand_dims(tf.sequence_mask(
logit_length, maxlen=input_max_len, dtype=tf.float32), axis=2)
small_input_mask = tf.expand_dims(tf.sequence_mask(
logit_length - 1, maxlen=input_max_len, dtype=tf.float32), axis=2)
mask = label_mask * input_mask
grad_blank_mask = (label_mask * small_input_mask)[:, :-1, :]
grad_truth_mask = (small_label_mask * input_mask)[:, :, :-1]

alpha = forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len) * mask

indices = tf.stack([logit_length - 1, label_length], axis=1)
blank_sl = tf.gather_nd(blank_probs, indices, batch_dims=1)

beta = backward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len, label_length, logit_length,
blank_sl) * mask
beta = tf.where(tf.math.is_nan(beta), tf.zeros_like(beta), beta)
final_state_probs = beta[:, 0, 0]

# Compute gradients of loss w.r.t. blank log-probabilities.
    grads_blank = -tf.exp((alpha[:, :-1, :] + beta[:, 1:, :] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + blank_probs[:, :-1, :]) * grad_blank_mask) * grad_blank_mask
grads_blank = tf.concat([grads_blank, tf.zeros(shape=(batch_size, 1, target_max_len))], axis=1)
last_grads_blank = -1 * tf.scatter_nd(
tf.concat([tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=[batch_size, 1]), tf.cast(indices, dtype=tf.int64)], axis=1),
tf.ones(batch_size, dtype=tf.float32), [batch_size, input_max_len, target_max_len])
grads_blank = grads_blank + last_grads_blank

# Compute gradients of loss w.r.t. truth log-probabilities.
    grads_truth = -tf.exp((alpha[:, :, :-1] + beta[:, :, 1:] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + truth_probs) * grad_truth_mask) * grad_truth_mask

# Compute gradients of loss w.r.t. activations.
a = tf.tile(tf.reshape(tf.range(target_max_len - 1, dtype=tf.int64), shape=(1, 1, target_max_len - 1, 1)),
multiples=[batch_size, 1, 1, 1])
b = tf.cast(tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), dtype=tf.int64)
c = tf.concat([a, b], axis=3)
d = tf.tile(c, multiples=(1, input_max_len, 1, 1))
e = tf.tile(tf.reshape(tf.range(input_max_len, dtype=tf.int64), shape=(1, input_max_len, 1, 1)),
multiples=(batch_size, 1, target_max_len - 1, 1))
f = tf.concat([e, d], axis=3)
g = tf.tile(tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=(batch_size, 1, 1, 1)),
multiples=[1, input_max_len, target_max_len - 1, 1])
scatter_idx = tf.concat([g, f], axis=3)
# TODO - improve the part of code for scatter_idx computation.
probs = tf.exp(log_probs)
grads_truth_scatter = tf.scatter_nd(scatter_idx, grads_truth,
[batch_size, input_max_len, target_max_len, vocab_size - 1])
grads = tf.concat(
[tf.reshape(grads_blank, shape=(batch_size, input_max_len, target_max_len, -1)), grads_truth_scatter], axis=3)
grads_logits = grads - probs * (tf.reduce_sum(grads, axis=3, keepdims=True))

loss = -final_state_probs
return loss, grads_logits


def rnnt_loss_tf(logits, labels, label_length, logit_length, name=None):
name = "rnnt_loss" if name is None else name
with tf.name_scope(name):
logits = tf.convert_to_tensor(logits, name="logits")
logits = tf.nn.log_softmax(logits)
labels = tf.convert_to_tensor(labels, name="labels")
label_length = tf.convert_to_tensor(label_length, name="label_length")
logit_length = tf.convert_to_tensor(logit_length, name="logit_length")

args = [logits, labels, label_length, logit_length]

@tf.custom_gradient
def compute_rnnt_loss_and_grad(logits_t, labels_t, label_length_t, logit_length_t):
"""Compute RNN-T loss and gradients."""
logits_t.set_shape(logits.shape)
labels_t.set_shape(labels.shape)
label_length_t.set_shape(label_length.shape)
logit_length_t.set_shape(logit_length.shape)
kwargs = dict(logits=logits_t, labels=labels_t, label_length=label_length_t, logit_length=logit_length_t)
result = compute_rnnt_loss_and_grad_helper(**kwargs)

def grad(grad_loss):
grads = [tf.reshape(grad_loss, [-1, 1, 1, 1]) * result[1]]
grads += [None] * (len(args) - len(grads))
return grads

return result[0], grad


return compute_rnnt_loss_and_grad(*args)
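For orientation (not from the diff), a rough usage sketch of the new loss with toy shapes. rnnt_loss dispatches to warp-transducer when it is installed and to the pure-TF path otherwise, and expects logits of shape [batch, time, target_length + 1, vocab] with the blank label at index 0:

import tensorflow as tf
from tensorflow_asr.losses.rnnt_losses import rnnt_loss

batch, time, target, vocab = 2, 5, 3, 10  # toy sizes
logits = tf.random.normal([batch, time, target + 1, vocab])
labels = tf.random.uniform([batch, target], minval=1, maxval=vocab, dtype=tf.int32)
label_length = tf.constant([target, target - 1], dtype=tf.int32)
logit_length = tf.constant([time, time - 1], dtype=tf.int32)

loss = rnnt_loss(logits=logits, labels=labels, label_length=label_length,
                 logit_length=logit_length, blank=0)
print(loss)  # one negative log-likelihood per utterance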
16 changes: 8 additions & 8 deletions tensorflow_asr/utils/__init__.py
@@ -61,11 +61,11 @@ def setup_strategy(devices):
return tf.distribute.MirroredStrategy()


-# def setup_tpu(tpu_address):
-#     import tensorflow as tf
-
-#     resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + tpu_address)
-#     tf.config.experimental_connect_to_cluster(resolver)
-#     tf.tpu.experimental.initialize_tpu_system(resolver)
-#     print("All TPUs: ", tf.config.list_logical_devices('TPU'))
-#     return resolver
+def setup_tpu(tpu_address=None):
+    import tensorflow as tf
+
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if tpu_address is None else tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + tpu_address)
+    tf.config.experimental_connect_to_cluster(resolver)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    print("All TPUs: ", tf.config.list_logical_devices('TPU'))
+    return tf.distribute.experimental.TPUStrategy(resolver)
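A sketch of how the re-enabled helper would typically be used; the TPU address and model factory below are hypothetical, and the returned TPUStrategy scopes model construction:

from tensorflow_asr.utils import setup_tpu

strategy = setup_tpu()  # or setup_tpu("10.240.1.2:8470") with an explicit (hypothetical) address

with strategy.scope():
    model = build_model()  # build_model is a hypothetical model factory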
7 changes: 5 additions & 2 deletions tensorflow_asr/utils/utils.py
@@ -49,8 +49,11 @@ def check_key_in_dict(dictionary, keys):

def preprocess_paths(paths: Union[List, str]):
if isinstance(paths, list):
-        return [os.path.abspath(os.path.expanduser(path)) for path in paths]
-    return os.path.abspath(os.path.expanduser(paths)) if paths else None
+        return [path if path.startswith('gs://') else os.path.abspath(os.path.expanduser(path)) for path in paths]
+    elif isinstance(paths, str):
+        return paths if paths.startswith('gs://') else os.path.abspath(os.path.expanduser(paths))
+    else:
+        return None


def nan_to_zero(input_tensor):
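Roughly how the updated preprocess_paths behaves (the paths below are hypothetical): GCS URIs pass through untouched, local paths are expanded to absolute paths, and anything else yields None:

from tensorflow_asr.utils.utils import preprocess_paths

print(preprocess_paths("~/data/train.tsv"))                    # expanded to an absolute local path
print(preprocess_paths("gs://my-bucket/train.tsv"))            # returned unchanged
print(preprocess_paths(["~/a.tsv", "gs://my-bucket/b.tsv"]))   # lists are handled element-wise
print(preprocess_paths(None))                                  # None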