diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 660fa04bb08d..3daf93ac0ed2 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -303,7 +303,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'), started else: logits[indices_to_remove] = filter_value - if top_p > 0.0: + if 0.0 < top_p < 1.0: # Cconvert to 1D sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)