[Distributed] Implement universal batch_decode & decode_in_flight for llama2 & llama3, with deterministic or multinomial (topk) decoding (handle both sentencepiece (llama2) and tiktoken (llama3)) (pytorch#1234)

* working multi-prompt same lengths

* working multi-prompt multiple lengths

* tighten up results decoding and display

* improve batch_decode_next_tokens

* update _decode_in_flight

* move prompt outside of main, auto-update batch size based on prompt

* faster batch_decode_next_tokens, add topk/temperature option

* ruff format and check

* simplify decode step, remove old comments

* add explanatory comment on topk min check
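
For readers skimming the diff, here is a minimal, self-contained sketch of the decoding scheme this commit introduces: greedy argmax by default, temperature-scaled top-k multinomial sampling otherwise. The function name `sample_next_token` is illustrative only and does not appear in the commit.

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0, topk: int = 10) -> torch.Tensor:
    """Pick one token id per batch row from [batch, vocab] logits.

    Deterministic argmax when temperature == 1.0; otherwise top-k multinomial
    sampling over temperature-scaled logits (mirrors the diff below).
    """
    if temperature == 1.0:
        return torch.argmax(logits, dim=-1)
    scaled = logits / temperature
    k = min(topk, scaled.size(-1))  # top-k must not exceed the vocab size
    top_logits, top_indices = torch.topk(scaled, k=k, dim=-1)
    probs = torch.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)      # [batch, 1] index into the top-k set
    return top_indices.gather(-1, choice).squeeze(-1)     # map back to vocab ids, shape [batch]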
lessw2020 authored Oct 2, 2024
1 parent 58185b6 commit 77bac00
127 changes: 81 additions & 46 deletions dist_run.py
@@ -16,6 +16,8 @@

import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

from torchchat.distributed.logging_utils import SingletonLogger

@@ -33,8 +35,6 @@
get_num_params,
GPUMemoryMonitor,
)
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
from torchchat.model import ModelArgs, Transformer, TransformerArgs
from torchchat.utils.build_utils import set_precision

@@ -189,23 +189,49 @@ def _create_padded_prompts(

def _batch_decode_next_tokens(
output: torch.Tensor,
pos: int,
pos: List[int],
step: int = -1,
temperature: float = 1.0,
topk: int = 10,
) -> torch.Tensor:
"""
Decode the next token for each prompt in the batch.
Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.
Args:
output (torch.Tensor): The output tensor to decode.
pos: the position of the `output` to decode in the sequence length dimension.
pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
temperature (float): Sampling temperature for non-deterministic decoding.
Returns:
Decoded token ids.
torch.Tensor: Decoded token ids.
"""
# Take the next token logits for each prompt
next_token_logits = output[:, pos, :]
# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)
# Token ids in int tensor form
return next_token
batch_size, seq_len, vocab_size = output.shape

if step != -1:
next_token_logits = output[:, 0, :]
else:
# get the logits for each prompt at the specified positions
next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]

if temperature != 1.0:
next_token_logits = next_token_logits / temperature

# Uses top-k sampling if temperature is not 1.0, otherwise use argmax
if temperature != 1.0:
top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_tokens = top_k_indices.gather(
-1, next_token_indices.unsqueeze(-1)
).squeeze(-1)
else:
# Argmax (deterministic)
next_tokens = torch.argmax(next_token_logits, dim=-1)

logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
return next_tokens
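
A hypothetical call showing the shapes the rewritten helper expects; the logits and prompt lengths below are random placeholders, and the snippet assumes it runs inside dist_run.py (the helper uses the module-level logger and color).

# Illustrative only: fake prefill logits for a batch of 3 prompts,
# sequence length 1024, vocab size 32000.
output = torch.randn(3, 1024, 32000)
prompt_lengths = [5, 9, 7]  # true (unpadded) length of each prompt

# Greedy decode: picks the logits at index length-1 for each prompt, shape [3].
tokens = _batch_decode_next_tokens(output, prompt_lengths)

# Sampled decode: temperature != 1.0 switches to top-k multinomial sampling.
sampled = _batch_decode_next_tokens(output, prompt_lengths, temperature=0.8, topk=5)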


def _update_padded_sequence(
@@ -218,11 +244,32 @@ def _update_padded_sequence(
# logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


# Decode token id into string and print it
def _decode_in_flight(token, tokenizer, tp_rank):
"""decode token ids for all prompts in the batch and log them"""
token_str = tokenizer.decode(token.tolist())
# print the token string on tp rank 0
if tp_rank == 0:
logger.info(
f"{color.green} responses ====>>>> "
f"{color.blue} {token_str} {color.reset}"
)
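
A hypothetical call, assuming `tokenizer` and `tp_rank` are the objects already set up in `main()`; the token ids are placeholders, and the decoded text is logged only when `tp_rank == 0`.

# Illustrative only: one freshly generated id per prompt in the batch.
latest = torch.tensor([450, 3087, 1724])
_decode_in_flight(latest, tokenizer, tp_rank)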


def _cleanup():
dist.barrier()
dist.destroy_process_group()


prompt = [
"What is Snow?",
"Who is Santa Claus?",
"Where does Santa live?",
# "Who is Abraham Lincoln?",
# "How are models trained?",
]


def main(args):
model_name = args.model_name
pp_degree = args.pp
@@ -293,7 +340,7 @@ def main(args):
# Batch size. Since we push batches dynamically through the pipeline rather
# than chunking them, this is effectively micro-batch size in pipeline
# sense. Thus it is interchangeable with micro-batch size below.
batch_size = 4
batch_size = len(prompt)
seqlen_prefill = 1024 # sequence length
dim = 4096 # embedding dimension

@@ -331,7 +378,9 @@

# Helper function to get example inputs and outputs for the stages.
def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
mb_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), device=device
)
activation = torch.rand(
batch_size, seqlen, dim, device=device, dtype=model_dtype
)
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# pipelining effect.
prefiller = ScheduleGPipe(prefill_stage, 1)

prompt = [
"What is a computer?",
"Where does Santa live?",
"Who is Abraham Lincoln?",
"How are models trained?",
]

start_pos = 0

# Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
padded_sequence, prompt_lengths = _create_padded_prompts(
input_ids, tokenizer, seqlen_prefill, start_pos, device
)
# TODO: figure out how to set input_pos for each prompt in the batch then we
# can remove this limitation.
s = set(prompt_lengths)
assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# New token generated each iteration
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
# Store the generated tokens
res = []

@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Decode token id into string and print it
def decode_in_flight(token):
# Make a 2D tensor with ids on row dimension
unsqueezed = torch.unsqueeze(token, 1)
token_str = tokenizer.decode(unsqueezed.tolist())
if tp_rank == 0:
logger.info(
f"{color.green} responses ====>>>> "
f"{color.blue} {token_str} {color.reset}"
)

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
new_token = _batch_decode_next_tokens(output, prompt_lengths)
res.append(new_token)
if not args.disable_in_flight_decode:
decode_in_flight(new_token)
_decode_in_flight(new_token, tokenizer, tp_rank)

# seqlen = 1 now
seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):

# Decode the output
if pp_rank == last_pp_rank:
new_token = _batch_decode_next_tokens(output, 0)
# logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
res.append(new_token)
if not args.disable_in_flight_decode:
decode_in_flight(new_token)
_decode_in_flight(new_token, tokenizer, tp_rank)

# Increment input position
input_pos += 1
@@ -499,12 +529,17 @@ def decode_in_flight(token):
# output formatted response via last pp group and tp rank 0
if pp_rank == last_pp_rank and tp_rank == 0:
# `res` is a list of tensors, each being a batch of generated token ids
res = torch.stack(res, dim=1)
res_list = res.tolist()
response = tokenizer.decode(res_list)
for i in range(len(response)):
logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
logger.info(f"Response: {color.red}{response[i]} {color.reset}")

res_stacked = torch.stack(res, dim=1)
res_list = res_stacked.tolist()

# Decode the output as comprehension instead of loop
responses = [tokenizer.decode(sequence) for sequence in res_list]

# Show prompts and responses
for prompt_text, response_text in zip(prompt, responses):
logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
logger.info(f"Response: {color.red}{response_text} {color.reset}")

# Cleanup
_cleanup()
