From ac3d642007b1f4ce4c19cc788a25668d106f1edc Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 6 Jan 2025 17:51:48 -0800 Subject: [PATCH] Use `torch.logsumexp` in advanced_tutorial.py `torch.logsumexp` is numerically stabilized: https://pytorch.org/docs/stable/generated/torch.logsumexp.html Found with TorchFix https://github.com/pytorch-labs/torchfix/ --- beginner_source/nlp/advanced_tutorial.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/beginner_source/nlp/advanced_tutorial.py b/beginner_source/nlp/advanced_tutorial.py index a6c6857128..1866142d6a 100644 --- a/beginner_source/nlp/advanced_tutorial.py +++ b/beginner_source/nlp/advanced_tutorial.py @@ -142,8 +142,7 @@ def prepare_sequence(seq, to_ix): def log_sum_exp(vec): max_score = vec[0, argmax(vec)] max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1]) - return max_score + \ - torch.log(torch.sum(torch.exp(vec - max_score_broadcast))) + return max_score + torch.logsumexp(vec - max_score_broadcast) ##################################################################### # Create model