
Commit

Expose sample from generation utils
joecummings committed Sep 12, 2024
1 parent d7fae96 commit 32e8f8c
Showing 3 changed files with 14 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_generation.rst
@@ -11,3 +11,4 @@ torchtune.generation
    :nosignatures:

    generate
+   sample
4 changes: 2 additions & 2 deletions torchtune/generation/__init__.py
@@ -4,6 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from ._generation import generate, generate_next_token
+from ._generation import generate, generate_next_token, sample

-__all__ = ["generate", "generate_next_token"]
+__all__ = ["generate", "generate_next_token", "sample"]
13 changes: 11 additions & 2 deletions torchtune/generation/_generation.py
@@ -17,9 +17,18 @@ def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor:


 def sample(
-    logits: torch.Tensor, temperature: float = 1.0, top_k: int = None
+    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
 ) -> torch.Tensor:
-    """Generic sample from a probability distribution."""
+    """Generic sample from a probability distribution.
+
+    Args:
+        logits (torch.Tensor): logits from which to sample
+        temperature (float): value to scale the predicted logits by, default 1.0.
+        top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities
+
+    Returns:
+        torch.Tensor: sampled token id
+    """
     # scale the logits based on temperature
     logits = logits / max(temperature, 1e-5)
     if top_k is not None:
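
The diff is truncated after the `top_k` check. For reference, here is a self-contained sketch of the same temperature/top-k sampling pattern; this is not the torchtune body, which presumably finishes by drawing through the `multinomial_sample_one` helper visible in the hunk header:

```python
from typing import Optional

import torch


def sample_sketch(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """Illustrative stand-in for the truncated function body above."""
    # scale the logits based on temperature; the 1e-5 floor avoids
    # dividing by zero when temperature == 0
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        # keep only the top_k largest logits and mask the rest to -inf
        topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(
            logits < topk_vals[..., -1:],
            torch.full_like(logits, float("-inf")),
            logits,
        )
    probs = torch.softmax(logits, dim=-1)
    # draw one token id per batch row from the categorical distribution
    return torch.multinomial(probs, num_samples=1)
```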
