From 77bac0021586c70dd494ef5778baa608a261062b Mon Sep 17 00:00:00 2001
From: Less Wright
Date: Tue, 1 Oct 2024 17:40:05 -0700
Subject: [PATCH] [Distributed] Implement universal batch_decode &
 decode_in_flight for llama2 & llama3, with deterministic or multinomial
 (topk) decoding (handle both sentencepiece (llama2) and tiktoken (llama3))
 (#1234)

* working multi-prompt same lengths
* working multi-prompt multiple lengths
* tighten up results decoding and display
* improve batch_decode_next_tokens
* update _decode_in_flight
* move prompt outside of main, auto-update batch size based on prompt
* faster batch_decode_next_tokens, add topk/temperature option
* ruff format and check
* simplify decode step, remove old comments
* add explanatory comment on topk min check
---
 dist_run.py | 127 +++++++++++++++++++++++++++++++++-------------------
 1 file changed, 81 insertions(+), 46 deletions(-)
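For reference, the sampling logic this patch adds to `_batch_decode_next_tokens` is temperature-scaled top-k multinomial sampling. Below is a minimal, self-contained sketch of that technique in plain PyTorch; the names (`sample_next_tokens`, `fake_logits`) are illustrative and not part of dist_run.py.

import torch

def sample_next_tokens(
    logits: torch.Tensor, temperature: float = 1.0, topk: int = 10
) -> torch.Tensor:
    # logits: (batch_size, vocab_size) scores for the next token of each prompt
    if temperature != 1.0:
        scaled = logits / temperature
        k = min(topk, scaled.size(-1))  # top-k cannot exceed the vocab size
        top_logits, top_indices = torch.topk(scaled, k=k, dim=-1)
        probs = torch.softmax(top_logits, dim=-1)
        # Sample an index into the top-k set, then map it back to vocab ids
        choice = torch.multinomial(probs, num_samples=1)
        return top_indices.gather(-1, choice).squeeze(-1)
    # temperature == 1.0: deterministic argmax decoding
    return torch.argmax(logits, dim=-1)

# Toy usage: a batch of 3 "prompts" over a 32-token vocabulary
fake_logits = torch.randn(3, 32)
print(sample_next_tokens(fake_logits, temperature=0.8, topk=5))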
""" - # Take the next token logits for each prompt - next_token_logits = output[:, pos, :] - # Argmax (deterministic) TODO: add temperature - next_token = torch.argmax(next_token_logits, dim=-1) - # Token ids in int tensor form - return next_token + batch_size, seq_len, vocab_size = output.shape + + if step != -1: + next_token_logits = output[:, 0, :] + else: + # get the logits for each prompt at the specified positions + next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1] + + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + + # Uses top-k sampling if temperature is not 1.0, otherwise use argmax + if temperature != 1.0: + top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size + top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1) + probs = torch.softmax(top_k_logits, dim=-1) + next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1) + next_tokens = top_k_indices.gather( + -1, next_token_indices.unsqueeze(-1) + ).squeeze(-1) + else: + # Argmax (deterministic) + next_tokens = torch.argmax(next_token_logits, dim=-1) + + logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}") + return next_tokens def _update_padded_sequence( @@ -218,11 +244,32 @@ def _update_padded_sequence( # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") +# Decode token id into string and print it +def _decode_in_flight(token, tokenizer, tp_rank): + """decode token ids for all prompts in the batch and log them""" + token_str = tokenizer.decode(token.tolist()) + # print the token string on tp rank 0 + if tp_rank == 0: + logger.info( + f"{color.green} responses ====>>>> " + f"{color.blue} {token_str} {color.reset}" + ) + + def _cleanup(): dist.barrier() dist.destroy_process_group() +prompt = [ + "What is Snow?", + "Who is Santa Claus?", + "Where does Santa live?", + # "Who is Abraham Lincoln?", + # "How are models trained?", +] + + def main(args): model_name = args.model_name pp_degree = args.pp @@ -293,7 +340,7 @@ def main(args): # Batch size. Since we push batches dynamically through the pipeline rather # than chunking them, this is effectively micro-batch size in pipeline # sense. Thus it is interchangeable with micro-batch size below. - batch_size = 4 + batch_size = len(prompt) seqlen_prefill = 1024 # sequence length dim = 4096 # embedding dimension @@ -331,7 +378,9 @@ def main(args): # Helper function to get example inputs and outputs for the stages. def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: - mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device) + mb_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), device=device + ) activation = torch.rand( batch_size, seqlen, dim, device=device, dtype=model_dtype ) @@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # pipelining effect. 
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # pipelining effect.
     prefiller = ScheduleGPipe(prefill_stage, 1)
 
-    prompt = [
-        "What is a computer?",
-        "Where does Santa live?",
-        "Who is Abraham Lincoln?",
-        "How are models trained?",
-    ]
-
     start_pos = 0
 
     # Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-    # TODO: figure out how to set input_pos for each prompt in the batch then we
-    # can remove this limitation.
-    s = set(prompt_lengths)
-    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
 
     # Store the generated tokens
     res = []
@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
-    # Decode token id into string and print it
-    def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
-        if tp_rank == 0:
-            logger.info(
-                f"{color.green} responses ====>>>> "
-                f"{color.blue} {token_str} {color.reset}"
-            )
-
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
-        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
+        new_token = _batch_decode_next_tokens(output, prompt_lengths)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # seqlen = 1 now
     seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):
 
         # Decode the output
        if pp_rank == last_pp_rank:
-            new_token = _batch_decode_next_tokens(output, 0)
+            # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
+            new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
             res.append(new_token)
             if not args.disable_in_flight_decode:
-                decode_in_flight(new_token)
+                _decode_in_flight(new_token, tokenizer, tp_rank)
 
         # Increment input position
         input_pos += 1
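In the decode loop above, each forward pass produces logits for exactly one new position per prompt, which is why `_batch_decode_next_tokens` now takes `step`: for step >= 0 it reads position 0 for every row instead of gathering per-prompt positions as in prefill. A small sketch of the two read paths (toy shapes, illustrative names, not code from dist_run.py):

import torch

batch_size, vocab_size = 3, 16
prompt_lengths = [3, 5, 8]

# Prefill (step == -1): the output covers the whole padded sequence, so each
# prompt gathers the logits at its own last-token position.
prefill_out = torch.randn(batch_size, 1024, vocab_size)
prefill_logits = prefill_out[torch.arange(batch_size), torch.tensor(prompt_lengths) - 1]

# Decode (step >= 0): seqlen is 1, so position 0 is the only choice and is
# correct for every prompt in the batch.
decode_out = torch.randn(batch_size, 1, vocab_size)
decode_logits = decode_out[:, 0, :]

assert prefill_logits.shape == decode_logits.shape == (batch_size, vocab_size)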
@@ -499,12 +529,17 @@ def decode_in_flight(token):
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
         # `res` is a list of tensors, each being a batch of generated token ids
-        res = torch.stack(res, dim=1)
-        res_list = res.tolist()
-        response = tokenizer.decode(res_list)
-        for i in range(len(response)):
-            logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
-            logger.info(f"Response: {color.red}{response[i]} {color.reset}")
+
+        res_stacked = torch.stack(res, dim=1)
+        res_list = res_stacked.tolist()
+
+        # Decode the output as a comprehension instead of a loop
+        responses = [tokenizer.decode(sequence) for sequence in res_list]
+
+        # Show prompts and responses
+        for prompt_text, response_text in zip(prompt, responses):
+            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
+            logger.info(f"Response: {color.red}{response_text} {color.reset}")
 
     # Cleanup
     _cleanup()
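To make the final assembly concrete: `res` collects one tensor of shape (batch_size,) per generated step, so stacking along dim=1 yields one row of token ids per prompt, and each row is decoded into that prompt's response. A standalone sketch with a stand-in decoder (`fake_decode` is illustrative; the real code decodes through torchchat's tokenizer, whether sentencepiece or tiktoken):

import torch

batch_size, steps, vocab_size = 3, 4, 100

# One (batch_size,) tensor of token ids per decode step, mirroring res.append(new_token)
res = [torch.randint(0, vocab_size, (batch_size,)) for _ in range(steps)]

# Stack along dim=1 -> (batch_size, steps): one row of generated ids per prompt
res_stacked = torch.stack(res, dim=1)
print(res_stacked.shape)  # torch.Size([3, 4])

# Stand-in for tokenizer.decode: render ids as text so the per-prompt grouping is visible
def fake_decode(token_ids):
    return " ".join(str(t) for t in token_ids)

responses = [fake_decode(sequence) for sequence in res_stacked.tolist()]
for response_text in responses:
    print(response_text)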