[distributed] add batch decoding (pytorch#1151)
* enable batch decoding, optimize dst/src creation outside of decoding loop

* remove logging, update formatting for display

* ruff formatting

* use Ke's variable names for send/rcv

* add formatting exception for llama2 "".join(res)

* fix prompt incrementing
  add formatting exception for llama2 "".join(res)

* revert prompt incrementing to pp=1 state
lessw2020 authored Sep 15, 2024
1 parent 8b45633 commit 03c9819
Showing 2 changed files with 45 additions and 20 deletions.
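
The heart of the change is that token bookkeeping becomes per-prompt: new_token grows from a single scalar tensor to a (batch, 1) tensor, res becomes one list of decoded strings per prompt, and each prompt's next token is written at that prompt's own length offset in the padded sequence. The snippet below is a minimal standalone sketch of that bookkeeping, outside the pipeline-parallel machinery; the next_token_logits helper and the greedy argmax are illustrative stand-ins for the schedule step and tokenizer decode in dist_run.py, not part of the actual code.

import torch

# Toy dimensions; dist_run.py uses seqlen = 4096 and a real tokenizer/model.
batch_size, seqlen, vocab_size = 3, 16, 100
prompt_lengths = [4, 2, 7]  # number of tokens already present in each row
padded_sequence = torch.zeros(batch_size, seqlen, dtype=torch.int64)

# One token row per prompt (shape [batch, 1]) instead of a single scalar tensor.
new_token = torch.zeros(batch_size, 1, dtype=torch.int64)
res = [[] for _ in range(batch_size)]  # per-prompt decoded strings

def next_token_logits(seq: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the pipeline forward pass.
    return torch.randn(seq.size(0), vocab_size)

num_tokens = 5
for _ in range(num_tokens):
    logits = next_token_logits(padded_sequence)
    new_token[:, 0] = logits.argmax(dim=-1)  # one new token id per prompt

    for i in range(batch_size):
        res[i].append(f"<tok{int(new_token[i, 0])}>")  # stands in for tokenizer.decode
        # write prompt i's token at its own offset, then advance that prompt's length
        padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
        prompt_lengths[i] += 1

for i in range(batch_size):
    print(f"prompt {i}: {''.join(res[i])}")
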
63 changes: 44 additions & 19 deletions dist_run.py
@@ -31,7 +31,6 @@
     get_num_params,
     GPUMemoryMonitor,
 )
-from distributed.verification_utils import find_cpu_tensors
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
@@ -219,10 +218,9 @@ def _update_padded_sequence(
     new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
-    # TODO: this is a hacky way to update the padded sequence: when there is
-    # more than one prompt, the for loop and the assignment is incompatible.
     for i in range(len(prompt_lengths)):
-        padded_sequence[i, prompt_lengths[i]] = new_token
+        padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
+        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
 def _cleanup():
@@ -242,7 +240,7 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    config = ModelArgs.from_name(distribution).transformer_args['text']
+    config = ModelArgs.from_name(distribution).transformer_args["text"]
     logger.info(f"Chat Model Config: {config}")
 
     tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +293,7 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 1  # micro-batch size
+    mb_size = 5  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -343,6 +341,10 @@ def main(args):
 
     prompt = [
         "What is snow?",
+        "Where does Santa Claus live?",
+        "What is PyTorch?",
+        "Write a poem about the beauty of the night sky.",
+        "What is the capital of France, Germany and Switzerland?",
     ]
 
     """
@@ -366,17 +368,23 @@ def main(args):
 
     start_pos = 0
 
+    # pipeline comms setup
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
+    # Need these global ids due to the API definition of dist.send and recv
+    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
+    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
+
     # encode the prompt
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0:8]=}")
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_lengths=}")
 
     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
@@ -389,8 +397,10 @@ def main(args):
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    new_token = torch.zeros(1, device=device, dtype=torch.int64)
-    res = []
+    total_prompts = len(prompt_lengths)
+    # need a new token dimension (row) for each prompt in the batch
+    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
     # Decoding
@@ -415,8 +425,11 @@ def main(args):
                f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
            )
            # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-            new_token = torch.tensor([decode_results[0][0]], device=device)
-            res.append(decode_results[0][1])
+            for i in range(len(decode_results)):
+                res[i].append(decode_results[i][1])
+                new_token[i, 0] = torch.tensor(
+                    [decode_results[i][0]], device=device
+                )  # decode_results[i][0]
 
         # sendrecv between last and first ranks, only if:
         # first_pp_rank != last_pp_rank.
@@ -435,20 +448,27 @@ def main(args):
 
         # Update input sequence with new token
         if pp_rank == first_pp_rank:
-            _update_padded_sequence(
-                padded_sequence, new_token, prompt_lengths
-            )
+            _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
 
+        # increment prompt lengths for next token
+        for i in range(len(prompt_lengths)):
+            prompt_lengths[i] += 1
 
     # Display the decoding results
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = " ".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        for i in range(len(prompt_lengths)):
+            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")
+
+            # TODO: resolve issue with llama2-7b-chat model and "".join
+            if model_name != "llama2-7b-chat":
+                formatted_response = "".join(res[i])
+            else:
+                formatted_response = " ".join(res[i])
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
 
     # Cleanup
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
    )
@@ -457,7 +477,12 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument(
+        "model_name",
+        type=str,
+        help="Name of the model to load",
+        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+    )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     args = parser.parse_args()
 
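A note on the batch sizing above: mb_size is bumped from 1 to 5 so that batch_size = mbs * mb_size matches the five prompts, and each prompt is padded out to seqlen before being fed to the pipeline schedule. The helper below is a hypothetical illustration of that padding step, analogous in spirit to _create_padded_prompts (whose implementation is not part of this diff); the pad_prompts name and pad_id argument are assumptions for the sketch.

from typing import List, Tuple

import torch

def pad_prompts(
    input_ids: List[torch.Tensor], seqlen: int, pad_id: int = 0
) -> Tuple[torch.Tensor, List[int]]:
    # Pack variable-length prompts into a fixed (batch, seqlen) int64 tensor
    # and remember each prompt's true length for later token insertion.
    batch = len(input_ids)
    padded = torch.full((batch, seqlen), pad_id, dtype=torch.int64)
    lengths = []
    for i, ids in enumerate(input_ids):
        padded[i, : ids.numel()] = ids
        lengths.append(ids.numel())
    return padded, lengths

# batch_size = mbs * mb_size must cover every prompt in the list.
prompts = [torch.tensor([1, 5, 9]), torch.tensor([1, 7]), torch.tensor([1, 2, 3, 4])]
padded_sequence, prompt_lengths = pad_prompts(prompts, seqlen=16)
print(padded_sequence.shape, prompt_lengths)  # torch.Size([3, 16]) [3, 2, 4]
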
2 changes: 1 addition & 1 deletion run_dist.sh
@@ -4,4 +4,4 @@ NGPU=${NGPU:-"4"}
 LOG_RANK=${LOG_RANK:-0,1,2,3}
 torchrun --nproc-per-node=$NGPU --master_port=$PORT \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-dist_run.py
+dist_run.py --pp 2 llama3
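
For reference, the new positional model_name argument and the --pp flag used in run_dist.sh parse as shown in this small sketch mirroring the argparse setup above; NAME_TO_DISTRIBUTION_AND_DTYPE is stubbed here with the two model names that appear in this diff, since the real table lives in dist_run.py.

import argparse

# Stub of the name -> (distribution, dtype) table; the real keys are defined in dist_run.py.
NAME_TO_DISTRIBUTION_AND_DTYPE = {"llama2-7b-chat": None, "llama3": None}

parser = argparse.ArgumentParser()
parser.add_argument(
    "model_name",
    type=str,
    help="Name of the model to load",
    choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
)
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")

args = parser.parse_args(["--pp", "2", "llama3"])  # same arguments run_dist.sh passes
print(args.model_name, args.pp)  # llama3 2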
