API improvement of use_softmax and zero_infinity #180
Following a suggestion from the PaddlePaddle community, it is recommended to add `use_softmax` and `zero_infinity` options to warpctc, by comparison with the PyTorch API `torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)`.

This PR is an attempt to add these two options. The logic is as follows:
- `use_softmax`: The `log_probs` argument in PyTorch receives the output of a log-softmax operation. When users want to pass log-softmax output as the input, the softmax operation in our code can therefore be omitted. The proposed behavior is that when `use_softmax=False`, an exponential operation is applied to the input logits instead.
- `zero_infinity`: When `zero_infinity=True`, infinite costs and their associated gradients are set to zero. To make infinite cost values possible in the first place, I omitted the truncation applied to the logits. The zeroing is then done after the cost and gradient have been computed for each batch of the input.

I also added corresponding tests for these two cases. (PS: `inf_test` seems to pass only when the truncation steps are omitted.)
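To make the intended semantics of the two options easier to review, here is a minimal pure-Python sketch. It is not the code in this PR: the helper names `to_probs` and `apply_zero_infinity` are hypothetical, and it assumes the zeroing is applied per sample, with one cost and one gradient row for each sample.

```python
import math


def to_probs(activations, use_softmax=True):
    """Turn one time step's activations into probabilities.

    use_softmax=True:  activations are raw logits, so softmax is applied.
    use_softmax=False: activations are assumed to be log-softmax output,
                       so only an exponential is applied.
    """
    if use_softmax:
        peak = max(activations)  # subtract the max for numerical stability
        exps = [math.exp(a - peak) for a in activations]
        total = sum(exps)
        return [e / total for e in exps]
    return [math.exp(a) for a in activations]


def apply_zero_infinity(costs, grads, zero_infinity=False):
    """Zero infinite costs and the gradient rows that belong to them.

    Assumption: costs[i] and grads[i] belong to the same sample.
    """
    if not zero_infinity:
        return list(costs), [list(g) for g in grads]
    new_costs, new_grads = [], []
    for cost, grad in zip(costs, grads):
        if math.isinf(cost):
            new_costs.append(0.0)
            new_grads.append([0.0] * len(grad))
        else:
            new_costs.append(cost)
            new_grads.append(list(grad))
    return new_costs, new_grads
```

Under these assumptions, feeding raw logits with `use_softmax=True` and feeding their log-softmax with `use_softmax=False` should yield the same probabilities, which is the equivalence the new option relies on.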