Commit
Merged generate.py and full.py and removed generate.py
Showing 8 changed files with 101 additions and 189 deletions.
@@ -0,0 +1,77 @@
import lightning as L
import torch
from typing import Optional
from lit_llama import LLaMA


@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.
    """
    T = idx.size(0)
    T_new = T + max_new_tokens
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    if idx.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    # generate max_new_tokens tokens
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # advance
        input_pos = input_pos[-1:] + 1

        if idx.device.type == "xla":
            xm.mark_step()

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[:input_pos]  # include the EOS token

    return idx
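
For context, below is a minimal sketch of how this generate function could be driven from a small script. It is not part of the commit: the checkpoint and tokenizer paths, the "7B" model size, and the sampling settings are placeholders, and the repository's own inference scripts additionally handle Fabric setup, lazy checkpoint loading, and optional quantization, which are omitted here.

# Hypothetical usage sketch -- not part of this diff; paths and sizes are placeholders.
from pathlib import Path

import torch
from lit_llama import LLaMA, Tokenizer

checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")   # placeholder
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")     # placeholder

model = LLaMA.from_name("7B")
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

tokenizer = Tokenizer(tokenizer_path)
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False)

# sample 50 new tokens; logits are divided by temperature and cropped to the top-k options
output = generate(model, encoded, max_new_tokens=50, temperature=0.8, top_k=200, eos_id=tokenizer.eos_id)
print(tokenizer.decode(output))

With settings like these, every logit below the 200 highest is masked to -inf before the softmax, so sampling is restricted to the most likely candidates, while the temperature controls how peaked that restricted distribution is.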