This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1fc72a2

Add top-p and top-k sampling
Address PR comments and add Temperature
1 parent f5bc19d

File tree

2 files changed (+131, -3 lines)

test/integration_tests/test_generate.py

Lines changed: 62 additions & 0 deletions
@@ -58,3 +58,65 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtils(self.model)
         generation_model.generate(self.transformed_inputs)
         mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
+
+    def test_get_top_k_restriction(self) -> None:
+        generation_model = GenerationUtils(self.model)
+        scores = torch.arange(0, 20).reshape(2, -1)
+
+        top_k = 5
+        expected = torch.zeros_like(scores)
+        expected[0, 5:] = torch.arange(5, 10)
+        expected[1, 5:] = torch.arange(15, 20)
+        output = generation_model._get_top_k_restriction(scores, top_k)
+        self.assertEqual(output, expected)
+
+        top_k = 11
+        output = generation_model._get_top_k_restriction(scores, top_k)
+        self.assertEqual(output, scores)
+
+        top_k = 0
+        with self.assertRaises(ValueError):
+            generation_model._get_top_k_restriction(scores, top_k)
+
+    def test_get_top_p_restriction(self) -> None:
+        generation_model = GenerationUtils(self.model)
+        input_probs = (torch.arange(1, 5) / 10).repeat(2).reshape(2, -1)
+
+        top_p = 0.75
+        expected = torch.tensor([0, 0, 0.3, 0.4]).repeat(2).reshape(2, -1)
+        output = generation_model._get_top_p_restriction(input_probs, top_p)
+        self.assertEqual(output, expected)
+
+        # test the min_tokens_to_keep parameter
+        top_p = 0.2
+
+        # min_tokens_to_keep defaults to 1
+        expected_default = torch.tensor([0, 0, 0, 0.4]).repeat(2).reshape(2, -1)
+        output_default = generation_model._get_top_p_restriction(input_probs, top_p)
+        self.assertEqual(output_default, expected_default)
+
+        output_with_min_2 = generation_model._get_top_p_restriction(input_probs, top_p, min_tokens_to_keep=2)
+        self.assertEqual(output_with_min_2, expected)
+
+        top_p = -1
+        with self.assertRaises(ValueError):
+            generation_model._get_top_p_restriction(input_probs, top_p)
+
+    def test_apply_temperature(self) -> None:
+        generation_model = GenerationUtils(self.model)
+        input_probs = torch.ones((2, 5))
+
+        # testing valid temperature
+        temperature = 2
+        valid_output = generation_model._apply_temperature(input_probs, temperature)
+        expected_output = torch.ones((2, 5)) / 2
+
+        self.assertEqual(valid_output, expected_output)
+
+        # testing invalid temperature
+        temperature = 0
+        with self.assertRaises(ValueError):
+            generation_model._apply_temperature(input_probs, temperature)
+        temperature = -1
+        with self.assertRaises(ValueError):
+            generation_model._apply_temperature(input_probs, temperature)
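
For context on the expected values in test_get_top_p_restriction: each row of input_probs is [0.1, 0.2, 0.3, 0.4], so the cumulative sum over the sorted values is [0.4, 0.7, 0.9, 1.0]. With top_p = 0.75 only 0.4 and 0.3 stay under the threshold, which is why the expected row is [0, 0, 0.3, 0.4]; with top_p = 0.2 nothing fits under the threshold, so only the single best value (or min_tokens_to_keep values) survives. A minimal standalone sketch of that arithmetic, assuming only torch and not part of this commit:

import torch

# One row of the distribution used in the test above.
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
cumulative = sorted_probs.cumsum(dim=-1)        # [[0.4, 0.7, 0.9, 1.0]]

top_p = 0.75
keep_sorted = cumulative <= top_p               # keeps 0.4 and 0.3 only
keep_sorted[:, :1] = True                       # min_tokens_to_keep = 1
keep = keep_sorted.scatter(-1, sorted_idx, keep_sorted)
print(probs.masked_fill(~keep, 0))              # tensor([[0.0, 0.0, 0.3, 0.4]])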

torchtext/prototype/generate.py

Lines changed: 69 additions & 3 deletions
@@ -48,7 +48,12 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: int, **model_kwargs
+        self,
+        input_ids: torch.Tensor,
+        max_length: int,
+        eos_idx: int,
+        pad_idx: int,
+        **model_kwargs,
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -117,7 +122,6 @@ def generate(
             max_length (int): Max length to generate responses.
             pad_idx (int): Padding index. Defaults to 0.
            eos_idx (int): End of sequence index. Defaults to 1.
-
        Returns:
             Tensor of Tensors containing output sequences as ids.
 
@@ -138,8 +142,70 @@ def generate(
             max_length = DEFAULT_MAX_SEQ_LEN
 
         if num_beams == 1 or num_beams is None:
-            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
+            return self.greedy_search(
+                inputs,
+                max_length,
+                eos_idx,
+                pad_idx=pad_idx,
+                **model_kwargs,
+            )
         elif num_beams > 1:
             return self.beam_search(inputs, num_beams, max_length)
         else:
             raise ValueError("`num_beams` must be >= 1.")
+
+    def _get_top_k_restriction(self, scores: torch.Tensor, top_k: int) -> torch.Tensor:
+        """Returns a copy of `scores` restricted to its k highest values (meaning every other value is zeroed).
+
+        Args:
+            scores (Tensor): typically the output logits or probabilities for a language model's vocabulary.
+            top_k (int): the number of highest values to keep.
+
+        Returns:
+            A copy of `scores` restricted to its k highest values.
+        """
+        top_k = min(top_k, scores.size(-1))
+        if top_k <= 0:
+            raise ValueError(f"`top_k` is {top_k} but should be an int greater than 0")
+        indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1, None]
+        return scores.masked_fill(indices_to_remove, 0)
+
+    def _get_top_p_restriction(self, probs: torch.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
+        """Returns a copy of `probs` restricted to the top indices whose values sum up to `top_p`
+        (meaning the value at any other index is zeroed).
+
+        Args:
+            probs (Tensor): output probabilities for a language model's vocabulary.
+            top_p (float): the (cumulative) threshold for cutting off top-value indices; between 0 and 1.
+
+        Returns:
+            A copy of `probs` restricted to the top indices whose values sum up to `top_p`.
+        """
+        if top_p < 0 or top_p > 1:
+            raise ValueError(f"`top_p` is {top_p} but should be a float between 0 and 1")
+        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+        cumulative_probs = sorted_probs.cumsum(dim=-1)
+
+        sorted_indices_to_keep = cumulative_probs <= top_p
+        sorted_indices_to_keep[:, :min_tokens_to_keep] = True
+        indices_to_remove = ~sorted_indices_to_keep.scatter(-1, sorted_indices, sorted_indices_to_keep)
+
+        return probs.masked_fill(indices_to_remove, 0)
+
+    def _apply_temperature(self, probs: torch.Tensor, temperature: float) -> torch.Tensor:
+        """Applies temperature scaling to `probs`.
+        Args:
+            probs (Tensor): output probabilities for a language model's vocabulary.
+            temperature (float): value of temperature applied to the distribution.
+        Returns:
+            A copy of `probs` with `temperature` applied.
+        """
+        if not temperature > 0:
+            raise ValueError(f"`temperature` is {temperature} but should be positive")
+        return probs / temperature
+
+    def _remove_invalid_values(self, scores: torch.Tensor) -> torch.Tensor:
+        """Removes NaN and inf values to prevent generation from failing when using sampling."""
+        scores[scores != scores] = 0.0
+        scores[scores == float("inf")] = torch.finfo(scores.dtype).max
+        return scores
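
Taken together, the new private helpers are the pieces needed to move from greedy decoding to sampling: restrict the distribution with top-k and/or top-p, rescale with temperature, clean up invalid values, then draw from the result. A hedged sketch of how they might be chained for a single generation step, not part of this commit: GenerationUtils and the helper names come from this file, while the wrapper function, its defaults, and the softmax/renormalization choices are illustrative assumptions only.

import torch
import torch.nn.functional as F

def sample_next_token(utils, logits, top_k=50, top_p=0.95, temperature=1.0):
    # `utils` is assumed to be a GenerationUtils instance; `logits` has shape
    # (batch, vocab). This wrapper is hypothetical, not part of torchtext.
    scaled = utils._apply_temperature(logits, temperature)   # divide by temperature
    probs = F.softmax(scaled, dim=-1)                        # logits -> probabilities
    probs = utils._remove_invalid_values(probs)              # guard against NaN/inf
    probs = utils._get_top_k_restriction(probs, top_k)       # zero all but the k best
    probs = utils._get_top_p_restriction(probs, top_p)       # zero the low-probability tail
    probs = probs / probs.sum(dim=-1, keepdim=True)          # renormalize the surviving mass
    return torch.multinomial(probs, num_samples=1)           # (batch, 1) sampled token ids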
