Skip to content

Commit

Permalink
Merged generate.py and full.py and removed generate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ruoyu61 committed Jul 26, 2023
1 parent 03f5d5e commit 47c38a7
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 189 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ To generate text predictions, you need to download the model weights. **If you d
Run inference:

```bash
python generate.py --prompt "Hello, my name is"
python generate/full.py --prompt "Hello, my name is"
```

This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
Expand All @@ -86,14 +86,14 @@ This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).

### Run Lit-LLaMA on consumer devices

On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
On GPUs with `bfloat16` support, the `full.py` script will automatically convert the weights and consume about ~14 GB.
For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):

```bash
python generate.py --quantize llm.int8 --prompt "Hello, my name is"
python generate/full.py --quantize llm.int8 --prompt "Hello, my name is"
```

See `python generate.py --help` for more options.
See `python generate/full.py --help` for more options.

You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:

Expand Down
170 changes: 0 additions & 170 deletions generate.py

This file was deleted.

2 changes: 1 addition & 1 deletion generate/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from generate.generate_utils import generate
from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
Expand Down
2 changes: 1 addition & 1 deletion generate/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from generate.generate_utils import generate
from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
Expand Down
23 changes: 14 additions & 9 deletions generate/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import quantization
from lit_llama.utils import quantization, lazy_load, llama_model_lookup
from scripts.prepare_alpaca import generate_prompt
from generate import generate

from generate.generate_utils import generate

def main(
prompt: str = "Hello, my name is",
Expand All @@ -28,6 +27,7 @@ def main(
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
model_size: str = "7B",
quantize: Optional[str] = None,
instruction_tuning: Optional[bool] = False
) -> None:
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
Expand All @@ -44,6 +44,7 @@ def main(
quantize: Whether to quantize the model and using which method:
``"llm.int8"``: LLM.int8() mode,
``"gptq.int4"``: GPTQ 4-bit mode.
instruction_tuning: Whether to regenerate sample in instruction turning format.
"""
if not checkpoint_path:
checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
Expand All @@ -56,19 +57,23 @@ def main(
print("Loading model ...", file=sys.stderr)
t0 = time.time()

with fabric.init_module(empty_init=True), quantization(mode=quantize):
model = LLaMA.from_name(model_size)
with lazy_load(checkpoint_path) as checkpoint:
name = llama_model_lookup(checkpoint)

with fabric.init_module(empty_init=True), quantization(mode=quantize):
model = LLaMA.from_name(name)

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

model.eval()
model = fabric.setup(model)

tokenizer = Tokenizer(tokenizer_path)
sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)

if instruction_tuning:
sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
prompt_length = encoded.size(0)

Expand Down
77 changes: 77 additions & 0 deletions generate/generate_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import lightning as L
import torch
from typing import Optional
from lit_llama import LLaMA

@torch.no_grad()
def generate(
model: LLaMA,
idx: torch.Tensor,
max_new_tokens: int,
*,
max_seq_length: Optional[int] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Args:
model: The model to use.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
eos_id: If specified, stop generating any more token once the <eos> token is triggered
"""
# create an empty tensor of the expected final shape and fill in the current tokens
T = idx.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.block_size)

device, dtype = idx.device, idx.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = idx
idx = empty
input_pos = torch.arange(0, T, device=device)

if idx.device.type == "xla":
import torch_xla.core.xla_model as xm

xm.mark_step()

# generate max_new_tokens tokens
for _ in range(max_new_tokens):
x = idx.index_select(0, input_pos).view(1, -1)

# forward
logits = model(x, max_seq_length, input_pos)
logits = logits[0, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

# advance
input_pos = input_pos[-1:] + 1

if idx.device.type == "xla":
xm.mark_step()

# concatenate the new generation
idx = idx.index_copy(0, input_pos, idx_next)

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
return idx[:input_pos] # include the EOS token

return idx
2 changes: 1 addition & 1 deletion generate/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from lit_llama import Tokenizer, LLaMA
from lit_llama.lora import lora
from lit_llama.utils import lazy_load, llama_model_lookup
from generate.generate_utils import generate
from scripts.prepare_alpaca import generate_prompt

lora_r = 8
Expand Down
6 changes: 3 additions & 3 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
def load_generate_script():
sys.path.append(str(wd))

import generate as generate
from generate import full

return generate
return full


def test_generate():
Expand Down Expand Up @@ -111,7 +111,7 @@ def init_module(self, empty_init):


def test_cli():
cli_path = wd / "generate.py"
cli_path = wd / "generate/full.py"
output = subprocess.check_output([sys.executable, cli_path, "-h"])
output = str(output.decode())
assert "Generates text samples" in output

0 comments on commit 47c38a7

Please sign in to comment.