Commit
Initial tcpgen biasing
pkufool committed Jun 26, 2024
1 parent 1f5c0a8 commit cfb0cab
Showing 5 changed files with 454 additions and 51 deletions.
284 changes: 284 additions & 0 deletions egs/librispeech/ASR/zipformer/biasing.py
@@ -0,0 +1,284 @@
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import random
from typing import Tuple

import torch
import torch.nn as nn

from icefall import ContextGraph
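
# TCPGen (tree-constrained pointer generator) contextual biasing: a prefix tree
# (icefall's ContextGraph) is built from the biasing phrases, and at every
# (frame, symbol) position the model attends over the tree branches that are
# still valid continuations, producing a pointer distribution over the
# vocabulary plus an out-of-list (OOL) token, and an embedding (hptr) that is
# added inside the joiner.  A sigmoid gate (generator_prob) gives the mixing
# weight; the final interpolation with the model distribution happens outside
# this file.  Shape comments below use B (batch), S (symbols), T (frames),
# s_range (pruned symbol range) and V (vocab size).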


class TCPGen(nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        super().__init__()
        # embedding for the OOL (out of biasing list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project symbol embeddings (vocabulary + OOL)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate the TCPGen probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
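
    # How get_tcpgen_masks below walks the biasing tree, on a toy example with
    # assumed token ids (not from the recipe): biasing list = {[5, 7]},
    # blank_id = 0, one target row [5, 7, 3, 0, 5]:
    #   j=0, y=5: matches at root      -> gen_mask 0, token 7 unmasked
    #   j=1, y=7: completes the phrase -> gen_mask 0, back to root, 5 unmasked
    #   j=2, y=3: no match             -> gen_mask 1, row stays fully masked
    #   j=3, y=0: blank resets to root -> gen_mask 0, token 5 unmasked
    #   j=4, y=5: matches at root      -> gen_mask 0, token 7 unmasked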

    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: ContextGraph,
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Walk the biasing tree along each target sequence to build the masks.

        Args:
          targets:
            The training targets in token ids (padded with blanks), shape (B, S).
          context_graph:
            The biasing-phrase tree built from the biasing list.
          vocab_size:
            Vocabulary size (the OOL token adds one extra column).
          blank_id:
            The id of the blank token; a blank resets the walk to the root.

        Returns:
          dist_masks:
            Bool tensor of shape (B, S, V + 1); True entries are masked out of
            the TCPGen distribution, so only tokens that continue a biasing
            phrase from the current tree node stay unmasked.
          gen_masks:
            Bool tensor of shape (B, S); True at positions where the token does
            not match the biasing tree, used later to zero the generation
            probability.
        """
        batch_size, sql_len = targets.shape
        dist_masks = torch.ones((batch_size, sql_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                not_matched = False
                if y == blank_id:
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        node = context_graph.root
                else:
                    gen_mask.append(1)
                    node = context_graph.root
                    not_matched = True
                # unmask_index = (
                #     [vocab_size]
                #     if node.token == -1
                #     else list(node.next.keys()) + [vocab_size]
                # )
                # logging.info(f"token : {node.token}, keys : {node.next.keys()}")
                # dist_masks[i, j, unmask_index] = 0
                if not not_matched:
                    dist_masks[i, j, list(node.next.keys())] = 0

            gen_masks.append(gen_mask + [1] * (sql_len - len(gen_mask)))
        gen_masks = torch.Tensor(gen_masks).to(targets.device).bool()
        dist_masks = dist_masks.to(targets.device).bool()
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks

    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over the (masked) token embeddings to get the TCPGen outputs.

        Args:
          query:
            shape : (B, T, s_range, attn_dim)
          dist_masks:
            shape : (B, T, s_range, V + 1)

        Returns:
          hptr:
            The TCPGen output embedding, shape (B, T, s_range, attn_dim).
          distribution:
            The TCPGen distribution over the vocabulary plus the OOL token,
            shape (B, T, s_range, V + 1).
        """
        # From the original paper, keys and values share the same embeddings.
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        # The OOL column is excluded when forming hptr.
        # logging.info(f"distribution shape : {distribution.shape}")
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution

    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ):
        """Prune the queries from symbols and acoustics with the ranges
        generated by `get_rnnt_prune_ranges` in the pruned RNN-T loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we want to
            keep. Its shape is (B, T, s_range), see the docs in
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor.

        Returns:
          Return the pruned query (the sum of the symbol and acoustic queries)
          with shape (B, T, s_range, attn_dim), the pruned distribution masks
          with shape (B, T, s_range, V + 1) and the pruned probability masks
          with shape (B, T, s_range).
        """
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (B, S, attn_dim) = query_sym.shape
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, self.vocab_size + 1), (
            dist_masks.shape,
            (B, S, self.vocab_size + 1),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
        # (B, T, s_range, attn_dim)
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # logging.info(f"query_sym : {query_sym.shape}")
        # logging.info(f"ranges : {ranges}")
        # (B, T, s_range, attn_dim)
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand(
                (B, T * s_range, self.vocab_size + 1)
            ),
        ).reshape(B, T, s_range, self.vocab_size + 1)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(
            gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
        ).reshape(B, T, s_range)
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )
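
    # forward() below ties the pieces together: embed and project the targets
    # and the encoder output, build the per-symbol masks from the context
    # graph, prune everything to the (B, T, s_range) ranges of the pruned
    # RNN-T loss, then run the tree attention to obtain hptr and the TCPGen
    # distribution.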

    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: ContextGraph,
    ):
        """
        Args:
          targets:
            The training targets in token ids (padded with blanks), shape (B, S).
          encoder_embeddings:
            The encoder outputs, shape (B, T, encoder_dim).
          ranges:
            The prune ranges from the pruned RNN-T loss, shape (B, T, s_range).
          context_graph:
            The context graph (biasing-phrase tree) shared by the utterances
            in this batch.

        Returns:
          Return the TCPGen embedding with shape (B, T, s_range, attn_dim), the
          TCPGen distribution with shape (B, T, s_range, V + 1) and the pruned
          generation-probability masks with shape (B, T, s_range).
        """
        query_sym = self.decoder.embedding(targets)

        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B, T, attn_dim)

        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )

        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )

        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks

    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1,2))}, std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
11 changes: 9 additions & 2 deletions egs/librispeech/ASR/zipformer/joiner.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 from scaling import ScaledLinear
+from typing import Optional
 
 
 class Joiner(nn.Module):
@@ -37,6 +38,7 @@ def forward(
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        tcpgen_hptr: Optional[torch.Tensor] = None,
         project_input: bool = True,
     ) -> torch.Tensor:
         """
@@ -62,6 +64,11 @@ def forward(
         else:
             logit = encoder_out + decoder_out
 
-        logit = self.output_linear(torch.tanh(logit))
+        if tcpgen_hptr is not None:
+            logit += tcpgen_hptr
 
-        return logit
+        activations = torch.tanh(logit)
+
+        logit = self.output_linear(activations)
+
+        return logit if tcpgen_hptr is None else (logit, activations)
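
The model-side code that connects biasing.py to the modified Joiner is presumably among the three changed files not shown above. The following is only a rough sketch of the usual TCPGen combination under stated assumptions: tcpgen_proj (a projection of hptr from attn_dim to joiner_dim), am_pruned and lm_pruned are hypothetical names, and the actual recipe may handle the out-of-list (OOL) column and log-probabilities differently.

# Hypothetical wiring sketch; not the commit's actual model.py changes.
hptr, tcpgen_dist, gen_masks = tcpgen(
    targets=targets,
    encoder_embeddings=encoder_out,
    ranges=ranges,
    context_graph=context_graph,
)
# Assumed projection from attn_dim to joiner_dim so the Joiner can add it to
# its pre-activation; the modified Joiner then also returns its tanh
# activations for the gate.
logits, activations = joiner(
    am_pruned, lm_pruned, tcpgen_hptr=tcpgen_proj(hptr), project_input=False
)
# Gate in [0, 1], forced to 0 where the history fell out of the biasing tree.
p_gen = tcpgen.generator_prob(hptr, activations, gen_masks)  # (B, T, s_range, 1)
p_model = logits.softmax(dim=-1)                             # (B, T, s_range, V)
# Interpolate the model and pointer distributions (OOL column dropped here).
p_final = (1.0 - p_gen) * p_model + p_gen * tcpgen_dist[..., :-1]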