Showing 5 changed files with 454 additions and 51 deletions.
@@ -0,0 +1,284 @@
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import random
from typing import Tuple

import torch
import torch.nn as nn

from icefall import ContextGraph


class TCPGen(nn.Module):
    """Tree-constrained pointer generator (TCPGen) for contextual biasing,
    following Sun et al., "Tree-constrained Pointer Generator for End-to-end
    Contextual Speech Recognition".
    """

    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        super().__init__()
        # embedding for the OOL (out-of-biasing-list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project symbol embeddings (vocabulary + OOL)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate the TCPGen copy probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size

    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: ContextGraph,
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
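        """Build the boolean TCPGen masks (True means masked).

        The target sequence is matched against the prefix tree of biasing
        phrases in ``context_graph``.  After consuming the symbol at each
        position, ``dist_masks`` unmasks exactly the tokens that continue a
        biasing phrase from the current tree node, while ``gen_masks`` marks
        the positions whose symbol fell off the tree (the copy probability
        is later forced to zero there).  Blanks and completed phrases reset
        the walk to the root.

        Example (hypothetical inputs): with a single biasing phrase [3, 4]
        and yseq = [0, 3, 4, 7] (blank_id = 0), positions 0-2 get
        gen_mask = 0 and unmask {3}, {4}, {3} respectively (the walk resets
        once the phrase completes), while position 3 (7 continues nothing)
        gets gen_mask = 1 with every token left masked.

        Returns:
          dist_masks of shape (B, S, vocab_size + 1) and gen_masks of
          shape (B, S).
        """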
        batch_size, seq_len = targets.shape
        dist_masks = torch.ones((batch_size, seq_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                not_matched = False
                if y == blank_id:
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        node = context_graph.root
                else:
                    gen_mask.append(1)
                    node = context_graph.root
                    not_matched = True
                # unmask_index = (
                #     [vocab_size]
                #     if node.token == -1
                #     else list(node.next.keys()) + [vocab_size]
                # )
                # logging.info(f"token : {node.token}, keys : {node.next.keys()}")
                # dist_masks[i, j, unmask_index] = 0
                if not not_matched:
                    dist_masks[i, j, list(node.next.keys())] = 0

            gen_masks.append(gen_mask + [1] * (seq_len - len(gen_mask)))
        gen_masks = torch.Tensor(gen_masks).to(targets.device).bool()
        dist_masks = dist_masks.to(targets.device).bool()
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks

    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over the symbol embeddings with the pruned query.

        Args:
          query:
            The combined symbol + acoustic query, with shape
            (B, T, s_range, attn_dim).
          dist_masks:
            The distribution masks, with shape (B, T, s_range, V + 1);
            True entries are excluded from the attention.

        Returns:
          hptr of shape (B, T, s_range, attn_dim) and the TCPGen
          distribution of shape (B, T, s_range, V + 1).
        """
        # From the original paper, keys and values share the same embeddings.
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        # logging.info(f"distribution shape : {distribution.shape}")
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, "
                f"std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, "
                f"max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution

    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ):
        """Prune the symbol and acoustic queries with the ranges generated
        by `get_rnnt_prune_ranges` in the pruned RNN-T loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we
            want to keep. Its shape is (B, T, s_range); see the docs of
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details.

        Returns:
          The summed pruned query of shape (B, T, s_range, attn_dim), the
          pruned distribution masks of shape (B, T, s_range, V + 1) and the
          pruned probability masks of shape (B, T, s_range).
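
        Example (hypothetical shapes): with S = 4 and s_range = 2,
        ranges[b, t] = [1, 2] keeps the symbol queries and masks of
        positions 1 and 2 for frame t, so every frame only covers the
        s_range symbols the pruned loss considers there.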
""" | ||
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (B, S, attn_dim) = query_sym.shape
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, self.vocab_size + 1), (
            dist_masks.shape,
            (B, S, self.vocab_size + 1),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
        # (B, T, s_range, attn_dim)
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # logging.info(f"query_sym : {query_sym.shape}")
        # logging.info(f"ranges : {ranges}")
        # (B, T, s_range, attn_dim)
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand(
                (B, T * s_range, self.vocab_size + 1)
            ),
        ).reshape(B, T, s_range, self.vocab_size + 1)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(
            gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
        ).reshape(B, T, s_range)
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )

    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: ContextGraph,
    ):
        """
        Args:
          targets:
            The training targets in token ids (padded with blanks), with
            shape (B, S).
          encoder_embeddings:
            The encoder outputs, with shape (B, T, encoder_dim).
          ranges:
            The prune ranges from the pruned RNN-T loss, with shape
            (B, T, s_range).
          context_graph:
            The context graph holding the biasing phrases; it is shared by
            all utterances in the batch.

        Returns:
          The TCPGen embedding of shape (B, T, s_range, attn_dim), the
          TCPGen distribution of shape (B, T, s_range, V + 1) and the
          pruned probability masks of shape (B, T, s_range).
        """
        query_sym = self.decoder.embedding(targets)

        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B, T, attn_dim)

        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )

        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )

        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks

    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
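        """Compute the TCPGen copy probability for each pruned position.

        Args:
          hptr:
            The TCPGen embedding from forward(), with shape
            (B, T, s_range, attn_dim).
          h_joiner:
            The joiner activations, with shape (B, T, s_range, joiner_dim).
          gen_masks:
            The pruned probability masks of shape (B, T, s_range); masked
            positions get probability 0.

        Returns:
          The copy probability of shape (B, T, s_range, 1); the caller is
          expected to interpolate the model and TCPGen distributions with
          it, as in the TCPGen paper.
        """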
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1, 2))}, "
                f"std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, "
                f"max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
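
# A minimal smoke-test sketch (hypothetical, for illustration only): the stub
# decoder below provides just the two attributes TCPGen reads (`embedding`
# and `vocab_size`), `ranges` is hand-built instead of coming from k2's
# get_rnnt_prune_ranges, and ContextGraph is assumed to expose
# `build(token_ids)` taking a list of token-id lists, as in icefall's
# context_graph.py.
if __name__ == "__main__":
    class StubDecoder(nn.Module):
        def __init__(self, vocab_size: int, embed_dim: int):
            super().__init__()
            self.vocab_size = vocab_size
            self.embedding = nn.Embedding(vocab_size, embed_dim)

    B, T, S, s_range = 2, 8, 5, 3
    vocab_size, encoder_dim, embed_dim, joiner_dim, attn_dim = 50, 32, 16, 24, 8

    tcpgen = TCPGen(
        encoder_dim=encoder_dim,
        embed_dim=embed_dim,
        joiner_dim=joiner_dim,
        decoder=StubDecoder(vocab_size, embed_dim),
        attn_dim=attn_dim,
    )
    context_graph = ContextGraph(context_score=1.0)
    context_graph.build([[3, 4], [5, 6, 7]])  # two biasing phrases

    targets = torch.randint(0, vocab_size, (B, S))
    encoder_out = torch.rand(B, T, encoder_dim)
    # Monotone indices into [0, S); real ranges come from the pruned loss.
    ranges = (
        torch.clamp(torch.arange(T).unsqueeze(1) + torch.arange(s_range), max=S - 1)
        .unsqueeze(0)
        .expand(B, T, s_range)
        .contiguous()
    )

    hptr, tcpgen_dist, gen_masks = tcpgen(
        targets=targets,
        encoder_embeddings=encoder_out,
        ranges=ranges,
        context_graph=context_graph,
    )
    # Expect (B, T, s_range, attn_dim), (B, T, s_range, V + 1), (B, T, s_range)
    print(hptr.shape, tcpgen_dist.shape, gen_masks.shape)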