Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implement TCPGen contextual biasing #1669

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 284 additions & 0 deletions egs/librispeech/ASR/zipformer/biasing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple

import math
import torch
import torch.nn as nn

from icefall import ContextGraph
import logging
import random


class TCPGen(nn.Module):
    """Tree-Constrained Pointer Generator (TCPGen) for contextual biasing.

    Combines a symbol-side query (decoder embeddings) and an acoustic-side
    query (encoder outputs) to attend over the token-embedding table
    (vocabulary plus one out-of-list "ool" token), restricted by a prefix
    tree (``ContextGraph``) built from the biasing list.  Produces:

      * ``hptr`` — a context vector fed into the joiner,
      * a TCPGen distribution over ``vocab_size + 1`` tokens,
      * a gating probability (``generator_prob``) that interpolates the
        TCPGen distribution with the model's own output distribution.

    Designed to work with pruned RNN-T: queries and masks are pruned with
    the ``ranges`` tensor from ``get_rnnt_prune_ranges``.
    """

    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        """
        Args:
          encoder_dim:
            Dimension of the encoder output frames.
          embed_dim:
            Dimension of the decoder token embeddings (also used for the
            ool embedding and the attention keys/values).
          joiner_dim:
            Dimension of the joiner activations consumed by ``tcp_gate``.
          decoder:
            The prediction network; must expose ``embedding`` (an
            ``nn.Embedding``) and ``vocab_size``.
          attn_dim:
            Dimension of the shared query/key attention space.
          tcpgen_dropout:
            Dropout applied to the projected keys and to ``hptr``.
        """
        super().__init__()
        # embedding for ool (out of biasing list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project symbol embeddings (vocabulary + ool)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate tcpgen probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size

    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: ContextGraph,
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Walk the biasing prefix tree along each target sequence and build
        the two masks TCPGen needs.

        Mask polarity is "1 == masked": positions left at 1 are suppressed
        later via ``masked_fill``; positions set to 0 are kept.

        Args:
          targets:
            Token ids (padded with blanks), shape (B, S).
          context_graph:
            Prefix tree over the biasing phrases; traversal restarts from
            the root on blank, on a failed match, and when a phrase
            completes (``node.is_end``).
          vocab_size:
            Size of the vocabulary V; masks have V + 1 columns, the last
            one being the ool token.
          blank_id:
            Id of the blank symbol.

        Returns:
          A tuple ``(dist_masks, gen_masks)``:
            - dist_masks: bool, (B, S, V + 1); 0 only at the children of
              the current tree node when the step matched the tree.
              NOTE(review): the ool column (index V) is never unmasked
              here, unlike the commented-out variant below which always
              kept it — confirm this is intended for the final version.
            - gen_masks: bool, (B, S); 0 where the step stayed on the tree
              (or was blank), 1 otherwise and for padding, later forcing
              the TCPGen gate probability to 0.
        """
        batch_size, sql_len = targets.shape
        dist_masks = torch.ones((batch_size, sql_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                not_matched = False
                if y == blank_id:
                    # blank resets the traversal; gate stays enabled
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    # token follows a biasing phrase: advance in the tree
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        # phrase completed; restart from the root
                        node = context_graph.root
                else:
                    # fell off the tree: disable the gate at this step
                    gen_mask.append(1)
                    node = context_graph.root
                    not_matched = True
                # unmask_index = (
                #     [vocab_size]
                #     if node.token == -1
                #     else list(node.next.keys()) + [vocab_size]
                # )
                # logging.info(f"token : {node.token}, keys : {node.next.keys()}")
                # dist_masks[i, j, unmask_index] = 0
                if not not_matched:
                    # allow only the valid continuations of the current node
                    dist_masks[i, j, list(node.next.keys())] = 0

            # pad with 1 (masked) up to the full sequence length
            gen_masks.append(gen_mask + [1] * (sql_len - len(gen_mask)))
        # NOTE(review): torch.tensor(...) would be the conventional
        # constructor here; torch.Tensor works since the input is floats.
        gen_masks = torch.Tensor(gen_masks).to(targets.device).bool()
        dist_masks = dist_masks.to(targets.device).bool()
        # occasional debug logging (~5% of calls)
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks

    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled-dot-product attention of the pruned query over the token
        embedding table, restricted by ``dist_masks``.

        Args:
          query:
            shape : (B, T, s_range, attn_dim)
          dist_masks:
            shape : (B, T, s_range, V + 1); True entries are excluded
            from the softmax.

        Returns:
          A tuple ``(hptr, distribution)``:
            - hptr: (B, T, s_range, attn_dim), the attention output
              computed over the vocabulary rows only (ool column dropped).
            - distribution: (B, T, s_range, V + 1), the TCPGen softmax.
        """
        # From original paper, k, v share same embeddings
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        # push masked logits to the dtype minimum so softmax ~ 0 there
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        # logging.info(f"distribution shape : {distribution.shape}")
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        # occasional debug logging (~5% of calls)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution

    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Prune the queries from symbols and acoustics with ranges
        generated by `get_rnnt_prune_ranges` in pruned rnnt loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we want to
            keep. Its shape is (B, T, s_range), see the docs in
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor.

        Returns:
          A tuple of three tensors:
            - the pruned combined query (symbol + acoustic), shape
              (B, T, s_range, attn_dim);
            - the pruned distribution masks, shape (B, T, s_range, V + 1);
            - the pruned gate masks, shape (B, T, s_range).
        """
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (B, S, attn_dim) = query_sym.shape
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, self.vocab_size + 1), (
            dist_masks.shape,
            (B, S, self.vocab_size + 1),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
        # (B, T, s_range, attn_dim)
        # acoustic query is per-frame, so it is broadcast over s_range
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # logging.info(f"query_sym : {query_sym.shape}")
        # logging.info(f"ranges : {ranges}")
        # (B, T, s_range, attn_dim)
        # gather the s_range symbol positions kept for each frame
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand(
                (B, T * s_range, self.vocab_size + 1)
            ),
        ).reshape(B, T, s_range, self.vocab_size + 1)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(
            gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
        ).reshape(B, T, s_range)
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )

    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: ContextGraph,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the pruned TCPGen context vectors and distribution.

        Args:
          targets:
            The training targets in token ids (padded with blanks). shape : (B, S)
          encoder_embeddings:
            The encoder outputs. shape: (B, T, encoder_dim)
          ranges:
            The prune ranges from pruned rnnt. shape: (B, T, s_range)
          context_graph:
            The context graph (biasing prefix tree) shared by the batch.
            NOTE(review): an earlier docstring mentioned one graph per
            utterance; the current code uses a single graph — confirm.

        Return:
          A tuple of three tensors:
            - hptr: tcpgen embedding with shape (B, T, s_range, attn_dim);
            - tcpgen_dist: tcpgen distribution with shape
              (B, T, s_range, V + 1);
            - gen_masks: pruned gate masks with shape (B, T, s_range),
              to be passed to :meth:`generator_prob`.
        """
        query_sym = self.decoder.embedding(targets)

        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B , T, attn_dim)

        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )

        # occasional debug logging (~5% of calls)
        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )

        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks

    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
        """Compute the TCPGen gate probability from the joiner activations
        and the TCPGen context vector.

        Args:
          hptr:
            TCPGen context vectors, shape (B, T, s_range, attn_dim).
          h_joiner:
            Joiner activations, shape (B, T, s_range, joiner_dim).
          gen_masks:
            Pruned gate masks, shape (B, T, s_range); where True, the
            gate is forced to 0 (fall back to the model distribution).

        Returns:
          Gate probabilities in [0, 1], shape (B, T, s_range, 1).
        """
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        # zero out the gate wherever the step left the biasing tree
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        # occasional debug logging (~5% of calls)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1,2))}, std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
11 changes: 9 additions & 2 deletions egs/librispeech/ASR/zipformer/joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
import torch.nn as nn
from scaling import ScaledLinear
from typing import Optional


class Joiner(nn.Module):
Expand All @@ -37,6 +38,7 @@ def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
tcpgen_hptr: Optional[torch.Tensor] = None,
project_input: bool = True,
) -> torch.Tensor:
"""
Expand All @@ -62,6 +64,11 @@ def forward(
else:
logit = encoder_out + decoder_out

logit = self.output_linear(torch.tanh(logit))
if tcpgen_hptr is not None:
logit += tcpgen_hptr

return logit
activations = torch.tanh(logit)

logit = self.output_linear(activations)

return logit if tcpgen_hptr is None else (logit, activations)
Loading
Loading