[Not for Merge]: Visualize the gradient of each node in the lattice.
csukuangfj committed Mar 14, 2022
1 parent bb7f6ed commit 054e239
Showing 2 changed files with 76 additions and 3 deletions.
21 changes: 20 additions & 1 deletion egs/librispeech/ASR/pruned_transducer_stateless/model.py
@@ -166,4 +166,23 @@ def forward(
reduction="sum",
)

return (simple_loss, pruned_loss)
B = px_grad.size(0)
S = px_grad.size(1)
T = px_grad.size(2) - 1
# px_grad's shape (B, S, T+1)
# py_grad's shape (B, S+1, T)
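        # px_grad[b][s][t] is the gradient of the loss w.r.t. px[b][s][t];
        # in k2's convention this is the occupation probability of the arc
        # leaving lattice node (s, t) that emits symbol s. Likewise,
        # py_grad[b][s][t] is the occupation probability of the arc leaving
        # node (s, t) that advances from frame t to t+1. Padding both to
        # (B, S+1, T+1) below lets them be summed node by node.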

px_grad_pad = torch.zeros(
(B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device
)
py_grad_pad = torch.zeros(
            (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
)

px_grad_padded = torch.cat([px_grad, px_grad_pad], dim=1)
py_grad_padded = torch.cat([py_grad, py_grad_pad], dim=2)

        # tot_grad's shape is (B, S+1, T+1); tot_grad[b][s][t] is the total
        # occupation probability of the arcs leaving lattice node (s, t)
tot_grad = px_grad_padded + py_grad_padded

return (simple_loss, pruned_loss, tot_grad, x_lens, y_lens)
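For context, the px_grad and py_grad used above are returned by k2's smoothed RNN-T loss when return_grad=True is passed. Below is a minimal, self-contained sketch of how they could be obtained and checked; the dummy shapes and argument values are illustrative assumptions, not code from this commit, and k2 defaults may differ across versions:

import k2
import torch

B, S, T, V = 2, 3, 5, 10  # hypothetical batch, symbols, frames, vocab
lm = torch.randn(B, S + 1, V)  # decoder (prediction network) output
am = torch.randn(B, T, V)  # encoder output
symbols = torch.randint(1, V, (B, S), dtype=torch.int64)
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# With return_grad=True, k2 also returns the gradients of the loss w.r.t.
# px/py, i.e. the occupation probability of each arc in the lattice.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
    return_grad=True,
    reduction="sum",
)
assert px_grad.shape == (B, S, T + 1)
assert py_grad.shape == (B, S + 1, T)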
58 changes: 56 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless/train.py
Expand Up @@ -35,7 +35,7 @@
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import k2
import sentencepiece as spm
@@ -434,14 +434,20 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)

with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
        simple_loss, pruned_loss, tot_grad, x_lens, y_lens = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
        cut_ids = [c.id for c in supervisions["cut"]]
        save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)

        # This commit is for visualization only, so stop after the first batch
        import sys

        sys.exit()

loss = params.simple_loss_scale * simple_loss + pruned_loss

assert loss.requires_grad == is_training
@@ -491,6 +497,53 @@ def compute_validation_loss(
return tot_loss


def save_and_plot_tot_grad(
tot_grad: torch.Tensor,
cut_ids: List[str],
x_lens: torch.Tensor,
y_lens: torch.Tensor,
):
"""Save and plot the tot_grad.
Args:
      tot_grad:
        A tensor of shape (B, U+1, T+1). It contains the gradient, i.e. the
        occupation probability, of each node in the lattice. U here is the
        number of tokens, called S in model.py.
cut_ids:
A list of size B, containing the cut ID of each utterance in the batch.
x_lens:
A 1-D tensor of shape (B,), specifying the number of valid acoustic
frames in tot_grad for each utterance in the batch.
y_lens:
A 1-D tensor of shape (B,), specifying the number of valid tokens
in tot_grad for each utterance in the batch.
"""
import matplotlib.pyplot as plt

    # (B, U+1, T+1) -> (B, T+1, U+1): move the time axis before the token
    # axis so that the slices below index frames first, then tokens.
    tot_grad = tot_grad.detach().cpu().permute(0, 2, 1)

ext = "png" # supported types: png, ps, pdf, svg

x_lens = x_lens.tolist()
y_lens = y_lens.tolist()

    tot_grad = tot_grad.unbind(0)  # one (T+1, U+1) tensor per utterance
    for i in range(len(cut_ids)):
        cid = cut_ids[i]
        T = x_lens[i]
        U = y_lens[i]
        # Keep only the valid region of the lattice for this utterance
        grad = tot_grad[i][:T, :U]

        filename = f"{cid}.{ext}"
        logging.info(f"Saving to {filename}")
        # plt.matshow(grad.t(), origin="lower", cmap="gray")  # grayscale variant
        plt.matshow(grad.t(), origin="lower")  # grad.t() has shape (U, T)
        plt.xlabel("t")
        plt.ylabel("u")
        plt.title(cid)
        plt.savefig(filename)
        plt.close()  # close the figure; matshow opens a new figure per call


def train_one_epoch(
params: AttributeDict,
model: nn.Module,
@@ -577,6 +630,7 @@ def maybe_log_param_relative_changes():
batch=batch,
is_training=True,
)

# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

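A minimal sketch of calling save_and_plot_tot_grad by hand, using random stand-in gradients; the shapes, cut IDs, and lengths below are made up for illustration:

import torch

B, U, T = 2, 3, 5  # hypothetical batch size, token count, frame count
tot_grad = torch.rand(B, U + 1, T + 1)  # stand-in for the real node gradients
cut_ids = ["cut-0", "cut-1"]
x_lens = torch.tensor([T, T - 1])  # valid acoustic frames per utterance
y_lens = torch.tensor([U, U - 1])  # valid tokens per utterance

save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)
# Writes cut-0.png and cut-1.png to the current working directory.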
