Add top-p and top-k sampling

Address PR comments and add Temperature

yohann-benchetrit committed Apr 7, 2023
1 parent f5bc19d commit 1fc72a2

Showing 2 changed files with 131 additions and 3 deletions.
62 changes: 62 additions & 0 deletions test/integration_tests/test_generate.py
@@ -58,3 +58,65 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
generation_model = GenerationUtils(self.model)
generation_model.generate(self.transformed_inputs)
mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")

def test_get_top_k_restriction(self) -> None:
generation_model = GenerationUtils(self.model)
scores = torch.arange(0, 20).reshape(2, -1)

top_k = 5
expected = torch.zeros_like(scores)
expected[0, 5:] = torch.arange(5, 10)
expected[1, 5:] = torch.arange(15, 20)
output = generation_model._get_top_k_restriction(scores, top_k)
self.assertEqual(output, expected)

top_k = 11
output = generation_model._get_top_k_restriction(scores, top_k)
self.assertEqual(output, scores)

top_k = 0
with self.assertRaises(ValueError):
generation_model._get_top_k_restriction(scores, top_k)

def test_get_top_p_restriction(self) -> None:
generation_model = GenerationUtils(self.model)
input_probs = (torch.arange(1, 5) / 10).repeat(2).reshape(2, -1)

top_p = 0.75
expected = torch.tensor([0, 0, 0.3, 0.4]).repeat(2).reshape(2, -1)
output = generation_model._get_top_p_restriction(input_probs, top_p)
self.assertEqual(output, expected)

# test min_tokens_to_keep parameter
top_p = 0.2

# min_tokens_to_keep defaults to 1
expected_default = torch.tensor([0, 0, 0, 0.4]).repeat(2).reshape(2, -1)
output_default = generation_model._get_top_p_restriction(input_probs, top_p)
self.assertEqual(output_default, expected_default)

output_with_min_2 = generation_model._get_top_p_restriction(input_probs, top_p, min_tokens_to_keep=2)
self.assertEqual(output_with_min_2, expected)

top_p = -1
with self.assertRaises(ValueError):
generation_model._get_top_p_restriction(input_probs, top_p)

def test_apply_temperature(self) -> None:
generation_model = GenerationUtils(self.model)
input_probs = torch.ones((2, 5))

# testing valid temperature
temperature = 2
valid_output = generation_model._apply_temperature(input_probs, temperature)
expected_output = torch.ones((2, 5)) / 2

self.assertEqual(valid_output, expected_output)

# testing invalid temperature
temperature = 0
with self.assertRaises(ValueError):
generation_model._apply_temperature(input_probs, temperature)
temperature = -1
with self.assertRaises(ValueError):
generation_model._apply_temperature(input_probs, temperature)
72 changes: 69 additions & 3 deletions torchtext/prototype/generate.py
@@ -48,7 +48,12 @@ def _prepare_decoder_ids_for_generation(
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

def greedy_search(
self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: int, **model_kwargs
self,
input_ids: torch.Tensor,
max_length: int,
eos_idx: int,
pad_idx: int,
**model_kwargs,
) -> torch.Tensor:
"""Greedy search decoding for text generation. Takes the most likely next token every time.
@@ -117,7 +122,6 @@ def generate(
max_length (int): Max length to generate responses.
pad_idx (int): Padding index. Defaults to 0.
eos_idx (int): End of sequence index. Defaults to 1.
Returns:
Tensor of Tensors containing output sequences as ids.
@@ -138,8 +142,70 @@
max_length = DEFAULT_MAX_SEQ_LEN

if num_beams == 1 or num_beams is None:
return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
return self.greedy_search(
inputs,
max_length,
eos_idx,
pad_idx=pad_idx,
**model_kwargs,
)
elif num_beams > 1:
return self.beam_search(inputs, num_beams, max_length)
else:
raise ValueError("`num_beams` must be >= 1.")
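
For orientation (not in the diff): with the dispatch above, leaving `num_beams` at 1 or None routes generation through greedy_search. A hedged sketch of a call, where `model` and `inputs` are placeholders prepared elsewhere and only `num_beams` and `max_length` come from the code shown:

from torchtext.prototype.generate import GenerationUtils

generation_model = GenerationUtils(model)  # `model`: any seq2seq model accepted by GenerationUtils (placeholder)
output_ids = generation_model.generate(inputs, num_beams=1, max_length=20)  # `inputs`: encoded input ids (placeholder)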

def _get_top_k_restriction(self, scores: torch.Tensor, top_k: int) -> torch.Tensor:
"""Returns a copy of `scores` restricted to its k highest values (meaning every other value is zeroed)
Args:
scores (Tensor): typically the output logits or probabilities for a language model's vocabulary.
top_k (int): the number of highest values to keep.
Returns:
A copy of `scores` restricted to its k highest values
"""
top_k = min(top_k, scores.size(-1))
if top_k <= 0:
raise ValueError(f"`top_k` is {top_k} but should be an int greater than 0")
indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1, None]
return scores.masked_fill(indices_to_remove, 0)
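
For orientation only (not in the diff): a minimal, self-contained sketch of what the top-k restriction above does to one row of scores, assuming plain logits as input:

import torch

scores = torch.tensor([[0.1, 0.4, 0.2, 0.3]])
top_k = 2
kth_value = torch.topk(scores, top_k)[0][:, -1, None]   # smallest value still inside the top k, here 0.3
restricted = scores.masked_fill(scores < kth_value, 0)  # tensor([[0.0000, 0.4000, 0.0000, 0.3000]])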

def _get_top_p_restriction(self, probs: torch.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
"""Returns a copy of `probs` restricted to the top indices whose values sum up to `top_p`
(meaning the value at any other index is zeroed)
Args:
probs (Tensor): output probabilities for a language model vocabulary.
top_p (float): the (cumulative) threshold for cutting off top-value indices; between 0 and 1.
min_tokens_to_keep (int): minimum number of highest-value indices to keep regardless of `top_p`. Defaults to 1.
Returns:
A copy of `probs` restricted to the top indices whose values sum up to `top_p`
"""
if top_p < 0 or top_p > 1:
raise ValueError(f"`top_p` is {top_p} but should be a float between 0 and 1")
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
cumulative_probs = sorted_probs.cumsum(dim=-1)

sorted_indices_to_keep = cumulative_probs <= top_p
sorted_indices_to_keep[:, :min_tokens_to_keep] = True
indices_to_remove = ~sorted_indices_to_keep.scatter(-1, sorted_indices, sorted_indices_to_keep)

return probs.masked_fill(indices_to_remove, 0)
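
For orientation only (not in the diff): a minimal sketch of the nucleus (top-p) restriction above. With probabilities [0.1, 0.2, 0.3, 0.4] and top_p = 0.75, the sorted cumulative sums are [0.4, 0.7, 0.9, 1.0], so only 0.4 and 0.3 survive:

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
top_p = 0.75
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
keep = sorted_probs.cumsum(dim=-1) <= top_p       # [[True, True, False, False]]
keep[:, :1] = True                                # always keep at least one token
remove = ~keep.scatter(-1, sorted_indices, keep)  # map the keep mask back to vocabulary order
restricted = probs.masked_fill(remove, 0)         # tensor([[0.0000, 0.0000, 0.3000, 0.4000]])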

def _apply_temperature(self, probs: torch.Tensor, temperature: float) -> torch.Tensor:
"""Applies temperature scaling to `probs`
Args:
probs (Tensor): output probabilities for a language model vocabulary.
temperature (float): strictly positive scaling factor; `probs` is divided by this value.
Returns:
A copy of `probs` with applied `temperature`
"""
if not temperature > 0:
raise ValueError(f"`temperature` is {temperature} but should be positive")
return probs / temperature
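
As an aside (an assumption about intended use, not something this diff shows): temperature is commonly divided into the logits before a softmax, where values above 1 flatten the distribution and values below 1 sharpen it — a minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 4.0]])
print(F.softmax(logits, dim=-1))        # ~[[0.042, 0.114, 0.844]]  (temperature 1.0)
print(F.softmax(logits / 2.0, dim=-1))  # ~[[0.140, 0.231, 0.629]]  temperature > 1 flattens
print(F.softmax(logits / 0.5, dim=-1))  # ~[[0.002, 0.018, 0.980]]  temperature < 1 sharpens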

def _remove_invalid_values(self, scores: torch.Tensor) -> torch.Tensor:
"""Removes nan and inf values to prevent generation from failing when using sampling"""
scores[scores != scores] = 0.0  # NaN != NaN, so this zeroes out NaN entries
scores[scores == float("inf")] = torch.finfo(scores.dtype).max  # clamp +inf to the largest finite value
return scores
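
Taken together, a hypothetical way the new helpers could be chained into one sampling step (illustrative only: `sample_next_token`, the call order, and the final renormalization are assumptions, not part of this commit):

import torch
import torch.nn.functional as F

def sample_next_token(utils, logits, top_k=50, top_p=0.9, temperature=1.0):
    # utils is a GenerationUtils instance; logits has shape (batch_size, vocab_size)
    probs = F.softmax(logits, dim=-1)
    probs = utils._apply_temperature(probs, temperature)
    probs = utils._get_top_k_restriction(probs, top_k)
    probs = utils._get_top_p_restriction(probs, top_p)
    probs = utils._remove_invalid_values(probs)
    probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalize after zeroing entries
    return torch.multinomial(probs, num_samples=1)   # sampled token ids, shape (batch_size, 1)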
