lifted T >= S for regular case
durson committed Aug 23, 2023
1 parent 82983b5 commit b48d47e
Showing 2 changed files with 143 additions and 19 deletions.
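What changed, in brief: the unconditional length assertion `T >= S` is replaced by a helper, `validate_st_lengths`, which enforces the constraint only for the modified/constrained transducer variants. A regular transducer may emit more than one symbol per frame, so S > T is a legal configuration there. Below is a minimal standalone sketch of the new rule; `check_lengths` is a hypothetical name used only for this example, while the real helper appears in the diff that follows.

def check_lengths(S: int, T: int, rnnt_type: str) -> None:
    # Regular RNN-T may emit several symbols per frame, so S > T is allowed;
    # modified/constrained RNN-T emits at most one symbol per frame.
    assert S >= 1, S
    assert rnnt_type == "regular" or T >= S, (
        f"Modified transducer requires T >= S, but got T={T} and S={S}"
    )

check_lengths(S=10, T=4, rnnt_type="regular")  # accepted after this commit
try:
    check_lengths(S=10, T=4, rnnt_type="modified")
except AssertionError as err:
    print(err)  # Modified transducer requires T >= S, but got T=4 and S=10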
66 changes: 47 additions & 19 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
@@ -22,6 +22,22 @@
 from .mutual_information import mutual_information_recursion
 
 
+def validate_st_lengths(
+    S: int, T: int, is_rnnt_type_regular: bool, boundary: Optional[Tensor] = None
+):
+    if boundary is None:
+        assert S >= 1, S
+        assert (
+            is_rnnt_type_regular or T >= S
+        ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
+    else:
+        Ss, Ts = boundary[:, 2], boundary[:, 3]
+        assert (Ss >= 1).all(), Ss
+        assert (
+            is_rnnt_type_regular or (Ts >= Ss).all()
+        ), f"Modified transducer requires T >= S, but got T={Ts} and S={Ss}"
+
+
 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
@@ -145,8 +161,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     # subtracting am_max and lm_max is to ensure the probs are in a good range
@@ -389,8 +405,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -616,7 +632,7 @@ def _adjust_pruning_lower_bound(
 def get_rnnt_prune_ranges(
     px_grad: torch.Tensor,
     py_grad: torch.Tensor,
-    boundary: torch.Tensor,
+    boundary: Tensor,
     s_range: int,
 ) -> torch.Tensor:
     """Get the pruning ranges of normal rnnt loss according to the grads
@@ -662,28 +678,40 @@ def get_rnnt_prune_ranges(
"""
(B, S, T1) = px_grad.shape
T = py_grad.shape[-1]

is_regular = T1 != T

assert T1 in [T, T + 1], T1
S1 = S + 1
assert py_grad.shape == (B, S + 1, T), py_grad.shape
assert boundary.shape == (B, 4), boundary.shape

assert S >= 1, S
assert T >= S, (T, S)
validate_st_lengths(S, T, is_regular, boundary)

+    # adjust s_range if S >> T for regular case
+    if is_regular:
+        Ss, Ts = boundary[:, 2], boundary[:, 3]
+        s_range_min = (Ss + 2 * Ts - 2).div(Ts, rounding_mode="trunc").max().item()
+        if s_range < s_range_min:
+            print(
+                f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+                f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
+            )
+            s_range = s_range_min
 
     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1
 
-    if T1 == T:
-        assert (
-            s_range >= 1
-        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
-    else:
+    if is_regular:
         assert (
             s_range >= 2
         ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
+    else:
+        assert (
+            s_range >= 1
+        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
 
     (B_stride, S_stride, T_stride) = py_grad.stride()
     blk_grad = torch.as_strided(
@@ -822,7 +850,7 @@ def get_rnnt_logprobs_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Construct px, py for mutual_information_recursion with pruned output.
@@ -893,8 +921,8 @@ def get_rnnt_logprobs_pruned(
     (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range), f"{ranges.shape} == ({B}, {T}, {s_range})"
     (B, S) = symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -989,7 +1017,7 @@ def rnnt_loss_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor = None,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
@@ -1208,8 +1236,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     # Caution: some parts of this code are a little less clear than they could
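Between the two file diffs, a worked example of the new pruning-range floor may help: per sequence, get_rnnt_prune_ranges above computes floor((S + 2T - 2) / T) and takes the maximum over the batch, raising a too-small s_range so the pruning windows can still cover every symbol when S > T. A standalone sketch, with boundary values made up for illustration:

import torch

Ss = torch.tensor([10, 3])  # per-sequence symbol counts (illustrative)
Ts = torch.tensor([4, 8])   # per-sequence frame counts (illustrative)
s_range_min = (Ss + 2 * Ts - 2).div(Ts, rounding_mode="trunc").max().item()
print(s_range_min)  # 4: (10 + 8 - 2) // 4 = 4 dominates (3 + 16 - 2) // 8 = 2

When T >= S the formula never exceeds 2, so requests for typical modified-style workloads are left untouched.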
96 changes: 96 additions & 0 deletions fast_rnnt/python/tests/rnnt_loss_test.py
File mode changed: 100644 → 100755
@@ -609,6 +609,102 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
                     )
                     print(f"Pruned loss with range {r} : {pruned_loss}")
 
+    # Test low s_range values with large S and small T. In this situation
+    # s_range is not large enough to cover the whole symbol sequence
+    # (in regular rnnt mode), which previously resulted in an inf loss.
+    def test_rnnt_loss_pruned_small_s_range(self):
+        B = 2
+        T = 2
+        S = 10
+        C = 10
+
+        frames = torch.randint(1, T, (B,))
+        seq_lengths = torch.randint(1, S, (B,))
+        T = torch.max(frames)
+        S = torch.max(seq_lengths)
+
+        am_ = torch.randn((B, T, C), dtype=torch.float64)
+        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+        symbols_ = torch.randint(0, C, (B, S))
+        terminal_symbol = C - 1
+
+        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+        boundary_[:, 2] = seq_lengths
+        boundary_[:, 3] = frames
+
+        print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+        for rnnt_type in ["regular"]:
+            for device in self.devices:
+                # normal rnnt
+                am = am_.to(device)
+                lm = lm_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
+
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
+
+                # nonlinear transform
+                logits = torch.sigmoid(logits)
+
+                loss = fast_rnnt.rnnt_loss(
+                    logits=logits,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    reduction="none",
+                )
+
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
+
+                # pruning
+                simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    return_grad=True,
+                    reduction="none",
+                )
+
+                S0 = 2
+
+                for r in range(S0, S + 2):
+                    ranges = fast_rnnt.get_rnnt_prune_ranges(
+                        px_grad=px_grad,
+                        py_grad=py_grad,
+                        boundary=boundary,
+                        s_range=r,
+                    )
+                    # (B, T, r, C)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
+
+                    logits = pruned_am + pruned_lm
+
+                    # nonlinear transform
+                    logits = torch.sigmoid(logits)
+
+                    pruned_loss = fast_rnnt.rnnt_loss_pruned(
+                        logits=logits,
+                        symbols=symbols,
+                        ranges=ranges,
+                        termination_symbol=terminal_symbol,
+                        boundary=boundary,
+                        rnnt_type=rnnt_type,
+                        reduction="none",
+                    )
+                    assert (
+                        not pruned_loss.isinf().any()
+                    ), f"Pruned loss is inf for r={r}, S={S}, T={T}: {pruned_loss}"
+                    print(f"Pruned loss with range {r} : {pruned_loss}")
+
 
 if __name__ == "__main__":
     unittest.main()
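Assuming a local checkout with fast_rnnt installed, the updated test file can be run directly, since it ends in unittest.main() (path as shown in the diff header; adjust to your checkout):

python3 fast_rnnt/python/tests/rnnt_loss_test.py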
