
warp-transducer - Is rnnt_loss() "internal state" causing wrong loss computation? #36

Closed
stefan-falk opened this issue Jul 10, 2020 · 1 comment


@stefan-falk

So, I have tried to implement the model more or less "from scratch", based on this repository.

For that, I implemented a training loop which I execute eagerly so that I can debug it.

However, in doing so I noticed that my loss is jumping around weirdly, and I still haven't figured out exactly why. To get some insight, I decided to take a closer look at the rnnt_loss() function. While running some simple test examples, I noticed that when calling rnnt_loss() repeatedly on the same input, the loss is different every time. But not just that: it's monotonically increasing.

The code I am running:

from warprnnt_tensorflow import rnnt_loss
import numpy as np


def main():
    # acts has shape (1, 3, 2, 3): batch=1, T=3, U=2, vocab=3
    acts = np.asarray([
        [
            [[0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0]],
        ]
    ])

    labels = np.asarray([[1, 2, 0]])
    label_lengths = [len(t) for t in labels]  # [3]

    for i in range(10):
        loss = rnnt_loss(
            acts=acts,
            labels=labels,
            input_lengths=label_lengths,  # label_lengths reused here; T also happens to be 3
            label_lengths=label_lengths
        )
        print(np.mean(loss))


if __name__ == '__main__':
    main()

Output:

1.0986123
2.1490226
5.274593
6.7222075
9.581686
11.274273
13.95323
15.808798
18.36151
20.329256

I am on tensorflow==2.2.0 and I compiled the warp-transducer with GPU support.

@noahchalifour
Owner

@stefan-falk The third dimension of acts (the U axis) needs to be at least max(label_lengths) + 1, as per HawkAaron/warp-transducer#72. In your example acts has U = 2, but you pass label_lengths = [3].
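
For reference, here is a minimal sketch of the repro with consistent shapes — assuming the usual warp-transducer layout where acts is (batch, time, max_label_length + 1, vocab) and blank is index 0:

from warprnnt_tensorflow import rnnt_loss
import numpy as np


def main():
    # B=1, T=3, U=3, V=3; U is max(label_lengths) + 1, so labels of
    # length 2 fit, and V=3 covers blank (0) plus the symbols 1 and 2.
    acts = np.zeros((1, 3, 3, 3), dtype=np.float32)

    labels = np.asarray([[1, 2]], dtype=np.int32)    # no blank (0) inside the label sequence
    label_lengths = np.asarray([2], dtype=np.int32)  # length of each label sequence
    input_lengths = np.asarray([3], dtype=np.int32)  # T, the time dimension of acts

    for _ in range(10):
        loss = rnnt_loss(
            acts=acts,
            labels=labels,
            input_lengths=input_lengths,
            label_lengths=label_lengths
        )
        print(np.mean(loss))


if __name__ == '__main__':
    main()

With the shapes consistent, repeated calls on the same input should return the same value; the monotonically increasing losses above are presumably the kernel reading past the end of the undersized activation buffer, rather than any internal state in rnnt_loss().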
