Skip to content

Commit

Permalink
TP llama with continuous batching (#2709)
Browse files Browse the repository at this point in the history
* Enabled left side padding in llama2 model

* Add unit test for tp_llama

* Use no_grad context manager

* [WIP] Converting llama_tp into cont batching handler

* [WIP]Implement prefill and decode for tp_llama

* WIP fixing decode

* Fix return prefill vs decode format in tp_llama handler

* Fix current_position index in llama_handler

* Fix kv caching issue with different results for different padding lengths

* Fix liniting error

* Remove cuda dependency for tp_llama test

* fix handler mock

* Adjust expected result for 13b model in tp llama test

* Make continuous batching work with tp llama

* Adjust sample txt

* Add missing requirements

* Add support for chat dialogs

* Use model archiver config in tp_llama test

---------

Co-authored-by: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
  • Loading branch information
mreso and HamidShojanazeri authored Dec 14, 2023
1 parent be5ff32 commit df94a56
Show file tree
Hide file tree
Showing 9 changed files with 1,373 additions and 392 deletions.
8 changes: 4 additions & 4 deletions examples/large_models/tp_llama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Make sure to have PyTorch Nighlies installed.
```
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install transformers
pip install transformers fire sentencepiece
```

Expand All @@ -53,7 +53,7 @@ The script prints the path where the model is downloaded as below.

### Step 3: Convert the "Meta" checkpoints to PyTorch Distributed compliant checkpoints

Convert the checkpoints to PT-D compliant checkpoints as follows, note that for 7B `--model_parallel_size 1` for 13B would be `--model_parallel_size 2` and 70B `model_parallel_size 8`, you can also set `--nproc_per_node ` accordingly. PT-D compliant support flexible world_size when loading back the checkpoints into TP(lized) model.
Convert the checkpoints to PT-D compliant checkpoints as follows, note that for 7B `--model_parallel_size 1` for 13B would be `--model_parallel_size 2` and 70B `model_parallel_size 8`, you can also set `--nproc_per_node ` accordingly. PT-D compliant support flexible world_size when loading back the checkpoints into TP(lized) model.

You would be able to use larger number of processes/ TP size when load the model back. For example if you have converted the `13B` checkpoints with `--nproc_per_node 2`, during the inference you can use `--nproc_per_node` be `[2, max_num_available_gpu]` which you are changing the world_size and effectively the TP size. The recommendation here is to keep the TP size as shown above respective to model size, 7B (TP Size =1), 13B (TP Size =2), 70B (TP Size =8), unless your benchmark and your batch size/ compute load compensate for communication cost.

Expand All @@ -69,7 +69,7 @@ torchrun --nnodes 1 --nproc_per_node 8 convert_checkpoints.py --original_ckpt_di

### Step 4: set up the configs:

Lets setup configs in `model-config.yaml`
Lets setup configs in `model-config.yaml`

```
#frontend settings
Expand Down Expand Up @@ -97,7 +97,7 @@ handler:
```

### step 5: Create the mar file:
Create the mar file using the following command here.
Create the mar file using the following command here.

```
torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format no-archive --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py"
Expand Down
10 changes: 6 additions & 4 deletions examples/large_models/tp_llama/dialogs.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[
{
"dialog":
[
{
"role": "user",
"content": "what is the recipe of mayonnaise?"
}
]

]
],
"max_new_tokens": 50,
"mode":"chat"
}
223 changes: 114 additions & 109 deletions examples/large_models/tp_llama/generate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import torch
from llama2 import Llama
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import List, Literal, Optional, Tuple, TypedDict
import abc
import logging
import os
import sys
import fire
from typing import List, Literal, Optional, Tuple, TypedDict

import torch

current_working_directory = os.getcwd()
sys.path.insert(0,current_working_directory)
sys.path.insert(0, current_working_directory)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Role = Literal["system", "user", "assistant"]

Expand Down Expand Up @@ -38,6 +39,7 @@ class ChatPrediction(TypedDict, total=False):
SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."


def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Expand All @@ -62,110 +64,116 @@ def sample_top_p(probs, p):
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token



# @torch.inference_mode()
with torch.no_grad():
def generate(model,
tokenizer,
prompt_tokens: List[List[int]],
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
Generate text sequences based on provided prompts using the language generation model.
Args:
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
bsz = len(prompt_tokens)
assert bsz <= model.max_batch_size, (bsz, model.max_batch_size)

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= model.max_seq_len
total_len = min(model.max_seq_len, max_gen_len + max_prompt_len)

pad_id = tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
@torch.no_grad()
def generate(
model,
tokenizer,
prompt_tokens: List[List[int]],
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
Generate text sequences based on provided prompts using the language generation model.
Args:
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
bsz = len(prompt_tokens)
assert bsz <= model.max_batch_size, (bsz, model.max_batch_size)

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= model.max_seq_len
total_len = min(model.max_seq_len, max_gen_len + max_prompt_len)

pad_id = tokenizer.eos_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, max_prompt_len - len(t) : max_prompt_len] = torch.tensor(
t, dtype=torch.long, device="cuda"
)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

padding = torch.tensor(
[max_prompt_len - len(t) for t in prompt_tokens],
dtype=torch.int64,
device="cuda",
)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = model.forward(tokens, prev_pos, padding=padding)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)

for cur_pos in range(max_prompt_len, total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos, padding=padding)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)

next_token = next_token.reshape(-1)
tokens[:, cur_pos] = next_token
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (next_token == tokenizer.eos_id)
prev_pos = cur_pos
if all(eos_reached):
break

for cur_pos in range(min_prompt_len, total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)

next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
next_token == tokenizer.eos_id
)
prev_pos = cur_pos
if all(eos_reached):
break

if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else padding[i] + len(prompt_tokens[i])
toks = toks[start : padding[i] + len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
if tokenizer.eos_id in toks:
eos_idx = toks.index(tokenizer.eos_id)
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)
probs = token_logprobs[i][
start : padding[i] + len(prompt_tokens[i]) + max_gen_len
]
# cut to eos tok if any
if tokenizer.eos_id in toks:
eos_idx = toks.index(tokenizer.eos_id)
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)


def text_completion(
Expand Down Expand Up @@ -222,6 +230,7 @@ def text_completion(
]
return [{"generation": tokenizer.decode(t)} for t in generation_tokens]


def chat_completion(
model,
tokenizer,
Expand Down Expand Up @@ -317,9 +326,7 @@ def chat_completion(
{
"generation": {
"role": "assistant",
"content": tokenizer.decode(t)
if not unsafe
else UNSAFE_ERROR,
"content": tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
},
"tokens": [tokenizer.decode(x) for x in t],
"logprobs": logprobs_i,
Expand All @@ -337,5 +344,3 @@ def chat_completion(
}
for t, unsafe in zip(generation_tokens, unsafe_requests)
]


Loading

0 comments on commit df94a56

Please sign in to comment.