Add top-p and top-k sampling to GenerationUtils #2137

Merged
62 changes: 62 additions & 0 deletions test/integration_tests/test_generate.py
@@ -58,3 +58,65 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
        generation_model = GenerationUtils(self.model)
        generation_model.generate(self.transformed_inputs)
        mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")

    def test_get_top_k_restriction(self) -> None:
        generation_model = GenerationUtils(self.model)
        scores = torch.arange(0, 20).reshape(2, -1)

        top_k = 5
        expected = torch.zeros_like(scores)
        expected[0, 5:] = torch.arange(5, 10)
        expected[1, 5:] = torch.arange(15, 20)
        output = generation_model._get_top_k_restriction(scores, top_k)
        self.assertEqual(output, expected)

        top_k = 11
        output = generation_model._get_top_k_restriction(scores, top_k)
        self.assertEqual(output, scores)

        top_k = 0
        with self.assertRaises(ValueError):
            generation_model._get_top_k_restriction(scores, top_k)

    def test_get_top_p_restriction(self) -> None:
        generation_model = GenerationUtils(self.model)
        input_probs = (torch.arange(1, 5) / 10).repeat(2).reshape(2, -1)

        top_p = 0.75
        expected = torch.tensor([0, 0, 0.3, 0.4]).repeat(2).reshape(2, -1)
        output = generation_model._get_top_p_restriction(input_probs, top_p)
        self.assertEqual(output, expected)

        # test the min_tokens_to_keep parameter
        top_p = 0.2

        # min_tokens_to_keep defaults to 1
        expected_default = torch.tensor([0, 0, 0, 0.4]).repeat(2).reshape(2, -1)
        output_default = generation_model._get_top_p_restriction(input_probs, top_p)
        self.assertEqual(output_default, expected_default)

        output_with_min_2 = generation_model._get_top_p_restriction(input_probs, top_p, min_tokens_to_keep=2)
        self.assertEqual(output_with_min_2, expected)

        top_p = -1
        with self.assertRaises(ValueError):
            generation_model._get_top_p_restriction(input_probs, top_p)

    def test_apply_temperature(self) -> None:
        generation_model = GenerationUtils(self.model)
        input_probs = torch.ones((2, 5))

        # testing valid temperature
        temperature = 2
        valid_output = generation_model._apply_temperature(input_probs, temperature)
        expected_output = torch.ones((2, 5)) / 2

        self.assertEqual(valid_output, expected_output)

        # testing invalid temperature
        temperature = 0
        with self.assertRaises(ValueError):
            generation_model._apply_temperature(input_probs, temperature)
        temperature = -1
        with self.assertRaises(ValueError):
            generation_model._apply_temperature(input_probs, temperature)
72 changes: 69 additions & 3 deletions torchtext/prototype/generate.py
@@ -48,7 +48,12 @@ def _prepare_decoder_ids_for_generation(
        return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

    def greedy_search(
        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: int, **model_kwargs
        self,
        input_ids: torch.Tensor,
        max_length: int,
        eos_idx: int,
        pad_idx: int,
        **model_kwargs,
    ) -> torch.Tensor:
        """Greedy search decoding for text generation. Takes the most likely next token every time.

@@ -117,7 +122,6 @@ def generate(
            max_length (int): Max length to generate responses.
            pad_idx (int): Padding index. Defaults to 0.
            eos_idx (int): End of sequence index. Defaults to 1.

        Returns:
            Tensor of Tensors containing output sequences as ids.

@@ -138,8 +142,70 @@
            max_length = DEFAULT_MAX_SEQ_LEN

        if num_beams == 1 or num_beams is None:
            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
            return self.greedy_search(
                inputs,
                max_length,
                eos_idx,
                pad_idx=pad_idx,
                **model_kwargs,
            )
        elif num_beams > 1:
            return self.beam_search(inputs, num_beams, max_length)
        else:
            raise ValueError("`num_beams` must be >= 1.")

    def _get_top_k_restriction(self, scores: torch.Tensor, top_k: int) -> torch.Tensor:
        """Returns a copy of `scores` restricted to its k highest values (meaning every other value is zeroed)

        Args:
            scores (Tensor): typically the output logits or probabilities for a language model's vocabulary.
            top_k (int): the number of highest values to keep.

        Returns:
            A copy of `scores` restricted to its k highest values
        """
        top_k = min(top_k, scores.size(-1))
        if top_k <= 0:
            raise ValueError(f"`top_k` is {top_k} but should be an int greater than 0")
        indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1, None]
        return scores.masked_fill(indices_to_remove, 0)
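
To see what the restriction does on concrete values, here is a minimal standalone sketch (illustrative only, not part of the diff) that mirrors the unit test above: the fifth-largest value in each row becomes the cutoff, and everything below it is zeroed.

```python
import torch

scores = torch.arange(0, 20).reshape(2, -1)               # rows 0..9 and 10..19
top_k = 5
kth_largest = torch.topk(scores, top_k)[0][:, -1, None]   # smallest of the k largest per row
restricted = scores.masked_fill(scores < kth_largest, 0)
# Row 0 keeps 5..9 and row 1 keeps 15..19; all other entries are zeroed.
```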

    def _get_top_p_restriction(self, probs: torch.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
        """Returns a copy of `probs` restricted to the top indices whose values sum up to `top_p`
        (meaning the value at any other index is zeroed)

        Args:
            probs (Tensor): output probabilities for a language model vocabulary.
            top_p (float): the (cumulative) threshold for cutting off top-value indices; between 0 and 1.
            min_tokens_to_keep (int): minimum number of indices to keep regardless of `top_p`. Defaults to 1.

        Returns:
            A copy of `probs` restricted to the top indices whose values sum up to `top_p`
        """
        if top_p < 0 or top_p > 1:
            raise ValueError(f"`top_p` is {top_p} but should be a float between 0 and 1")
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = sorted_probs.cumsum(dim=-1)

        sorted_indices_to_keep = cumulative_probs <= top_p
        sorted_indices_to_keep[:, :min_tokens_to_keep] = True
        indices_to_remove = ~sorted_indices_to_keep.scatter(-1, sorted_indices, sorted_indices_to_keep)

        return probs.masked_fill(indices_to_remove, 0)
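
The nucleus restriction keeps the highest-probability tokens whose cumulative mass stays within `top_p`. A minimal standalone sketch (illustrative only, not part of the diff) walks a toy distribution through the same sort/cumsum/scatter steps:

```python
import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
top_p, min_tokens_to_keep = 0.75, 1

sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
keep = sorted_probs.cumsum(dim=-1) <= top_p        # cumulative mass within the threshold
keep[:, :min_tokens_to_keep] = True                # always keep at least one token
remove = ~keep.scatter(-1, sorted_indices, keep)   # map the sorted mask back to vocabulary order
print(probs.masked_fill(remove, 0))                # tensor([[0.0000, 0.0000, 0.3000, 0.4000]])
```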

    def _apply_temperature(self, probs: torch.Tensor, temperature: float) -> torch.Tensor:
        """Applies temperature scaling to `probs`

        Args:
            probs (Tensor): output probabilities for a language model vocabulary.
            temperature (float): value of temperature applied to the distribution.

        Returns:
            A copy of `probs` with applied `temperature`
        """
        if not temperature > 0:
            raise ValueError(f"`temperature` is {temperature} but should be positive")
        return probs / temperature
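
The helper itself simply divides its input by `temperature`; the effect on sampling is easiest to see on raw logits followed by a softmax. That framing is an assumption made only for this illustration (per its docstring, the method operates on probabilities), not the exact call path in this PR:

```python
import torch

logits = torch.tensor([[1.0, 2.0, 4.0]])
cold = torch.softmax(logits / 0.5, dim=-1)  # temperature < 1 sharpens: ~[0.002, 0.018, 0.980]
hot = torch.softmax(logits / 2.0, dim=-1)   # temperature > 1 flattens: ~[0.140, 0.231, 0.629]
```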

    def _remove_invalid_values(self, scores: torch.Tensor) -> torch.Tensor:
        """Removes nan and inf values to prevent generation from failing when using sampling"""
        scores[scores != scores] = 0.0
        scores[scores == float("inf")] = torch.finfo(scores.dtype).max
        return scores
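
Taken together, the new helpers can be chained into a single sampling step. The sketch below is a hypothetical composition: the function name `sample_next_token`, the default argument values, and the ordering of the calls are illustrative assumptions, not API exposed by GenerationUtils:

```python
import torch

def sample_next_token(utils, logits, top_k=50, top_p=0.95, temperature=1.0):
    """Hypothetical helper: `utils` is a GenerationUtils instance, `logits` the next-token scores."""
    probs = torch.softmax(logits, dim=-1)
    probs = utils._apply_temperature(probs, temperature)
    probs = utils._get_top_k_restriction(probs, top_k)
    probs = utils._get_top_p_restriction(probs, top_p)
    probs = utils._remove_invalid_values(probs)
    # torch.multinomial renormalizes the remaining (non-zero) mass before drawing one token per row.
    return torch.multinomial(probs, num_samples=1)
```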