From e90d41c16e9748a89896ccb32cca14c1521abd20 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 21 Dec 2023 15:39:39 -0800
Subject: [PATCH 1/6] 2D working without TP self attention

---
 examples/llama/2d_llama.py | 78 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 examples/llama/2d_llama.py

diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
new file mode 100644
index 000000000..db164cc84
--- /dev/null
+++ b/examples/llama/2d_llama.py
@@ -0,0 +1,78 @@
+# $ torchrun --nproc-per-node 8 2d_llama.py
+import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
+from torch.distributed._tensor import init_device_mesh
+
+# Grab the model
+llama = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
+)
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+prompts = (
+    "How do you", "I like to", "Can I help", "You need to",
+    "The weather is", "I found a", "What is your", "You are so",
+) # bs = 8
+tokenizer.pad_token = tokenizer.eos_token
+
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+pp_group_size = 2
+tp_group_size = 4
+mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
+pp_group = mesh_2d["pp"].get_group()
+
+llama.to(device).eval()
+inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+
+# Cut model by equal number of layers per rank
+layers_per_stage = llama.config.num_hidden_layers // pp_group_size
+for i in range(1, pp_group_size):
+    annotate_split_points(llama,
+        {f"model.layers.{i * layers_per_stage}": PipeSplitWrapper.SplitPoint.BEGINNING})
+
+# Create a pipeline representation from the model
+llama_pipe = Pipe.from_tracing(llama, pp_group_size, example_args=(inputs["input_ids"],))
+
+# Create pipeline stage for each rank
+stage_idx = rank // tp_group_size
+stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
+
+# Tensor parallel
+from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
+starting_layer = stage_idx * layers_per_stage
+plan = {}
+for i in range(layers_per_stage):
+    # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
+    extra = "_mod" if starting_layer > 0 and i == 0 else ""
+    layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
+    plan.update({
+        # Parallel self attention not working yet due to the dimension mismatch
+        # after TP in view operation
+        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
+        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
+        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
+        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+        f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
+        f"{layer_name}_mlp_up_proj": ColwiseParallel(),
+        f"{layer_name}_mlp_down_proj": RowwiseParallel(),
+    })
+tp_mesh = mesh_2d["tp"]
+parallelize_module(stage.submod, tp_mesh, plan)
+
+# Run
+if stage_idx == 0:
+    args = inputs["input_ids"]
+else:
+    args = None
+output = stage(args)
+
+# Decode
+if output is not None:
+    next_token_logits = output[0][:, -1, :]
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    print(tokenizer.batch_decode(next_token))
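The commented-out self-attention plan in the patch above hits a dimension mismatch because ColwiseParallel leaves each TP rank with only a shard of the q/k/v projection output, while the traced model still calls view() with the full head count. The following standalone sketch (not part of the patch series) reproduces that mismatch on CPU; the numbers assume Llama-2-7B shapes (hidden size 4096, 32 heads, head_dim 128) and a TP degree of 4:

import torch

# Assumed shapes: hidden 4096 = 32 heads x 128 head_dim; TP degree 4.
bsz, seq, num_heads, head_dim, tp = 2, 16, 32, 128, 4

# After ColwiseParallel, each rank's q_proj output is a (bsz, seq, 4096 // tp) shard.
local_q = torch.randn(bsz, seq, num_heads * head_dim // tp)

try:
    # The unmodified model still does view(bsz, seq, num_heads, head_dim),
    # which expects the full 4096 features and fails on the shard.
    local_q.view(bsz, seq, num_heads, head_dim)
except RuntimeError as e:
    print("view with the full head count fails on the shard:", e)

# Rescaling the head count by the TP degree (32 -> 8) matches the local shard.
print(local_q.view(bsz, seq, num_heads // tp, head_dim).shape)  # torch.Size([2, 16, 8, 128])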
From 78ccf593ac75085aa3acb4f1328b0a11f2ec0a57 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 28 Dec 2023 09:07:07 -0800
Subject: [PATCH 2/6] Modify view ops to make them compatible with TP

---
 examples/llama/2d_llama.py | 38 +++++++++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index db164cc84..e09342134 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -5,6 +5,23 @@
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
 from torch.distributed._tensor import init_device_mesh
 
+
+def modify_view(
+    gm: torch.fx.GraphModule,
+    tp: int
+):
+    """
+    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    """
+    for node in gm.graph.nodes:
+        if node.op == "call_method" and (
+            node.target == "view" or node.target == "reshape"
+        ):
+            assert len(node.args) >= 4
+            node.update_arg(3, node.args[3] // tp)
+    gm.recompile()
+
+
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
 )
@@ -42,27 +59,34 @@
 stage_idx = rank // tp_group_size
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
 
+modify_view(stage.submod, tp_group_size)
+
 # Tensor parallel
 from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
-plan = {}
+attn_plan = {}
+mlp_plan = {}
 for i in range(layers_per_stage):
     # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
     extra = "_mod" if starting_layer > 0 and i == 0 else ""
     layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
-    plan.update({
+    attn_plan.update({
         # Parallel self attention not working yet due to the dimension mismatch
         # after TP in view operation
-        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+        f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+    })
+    mlp_plan.update({
         f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
         f"{layer_name}_mlp_up_proj": ColwiseParallel(),
         f"{layer_name}_mlp_down_proj": RowwiseParallel(),
     })
 tp_mesh = mesh_2d["tp"]
-parallelize_module(stage.submod, tp_mesh, plan)
+parallelize_module(
+    stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
+)
 
 # Run
 if stage_idx == 0:
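To see what the modify_view pass in the patch above does in isolation, here is a minimal, self-contained sketch on a toy module; the module name, shapes, and TP degree are made up for illustration and are not taken from the patches:

import torch
from torch import fx

class ToyAttnReshape(torch.nn.Module):
    # Hard-coded shapes, assumed for the example: (2, 16, 4096) -> (2, 16, 32, 128)
    def forward(self, x):
        return x.view(2, 16, 32, 128)

gm = fx.symbolic_trace(ToyAttnReshape())

tp = 4
for node in gm.graph.nodes:
    if node.op == "call_method" and node.target in ("view", "reshape"):
        # args = (self_tensor, bsz, seq, num_heads, head_dim); rescale the head count.
        node.update_arg(3, node.args[3] // tp)
gm.recompile()

# A local TP shard of the projection output now reshapes cleanly: 32 heads -> 8 per rank.
print(gm(torch.randn(2, 16, 1024)).shape)  # torch.Size([2, 16, 8, 128])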
""" for node in gm.graph.nodes: if node.op == "call_method" and ( From 0cc816d1669a5d4174183fd81907b75ab9221456 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Tue, 2 Jan 2024 20:35:36 +0000 Subject: [PATCH 4/6] adding deferred init as Ke advised --- examples/llama/2d_llama.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py index f3764e68f..a3e66e054 100644 --- a/examples/llama/2d_llama.py +++ b/examples/llama/2d_llama.py @@ -27,7 +27,8 @@ def modify_view( # Grab the model llama = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True + "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True, + torch_dtype=torch.float16 ) tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") @@ -46,8 +47,8 @@ def modify_view( mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp")) pp_group = mesh_2d["pp"].get_group() -llama.to(device).eval() -inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device) +llama.eval() +inputs = tokenizer(prompts, return_tensors="pt", padding=True) # Cut model by equal number of layers per rank layers_per_stage = llama.config.num_hidden_layers // pp_group_size @@ -90,7 +91,7 @@ def modify_view( parallelize_module( stage.submod, tp_mesh, {**attn_plan, **mlp_plan} ) - +inputs = inputs.to(device) # Run if stage_idx == 0: args = inputs["input_ids"] From 37e110c0900914e767c730b9d83fc2b20182bf56 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 2 Jan 2024 15:02:18 -0800 Subject: [PATCH 5/6] Remove modify_view() --- examples/llama/2d_llama.py | 42 +++++++++++++++----------------------- pippy/PipelineStage.py | 4 +++- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py index a3e66e054..bc7e380ff 100644 --- a/examples/llama/2d_llama.py +++ b/examples/llama/2d_llama.py @@ -4,25 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage from torch.distributed._tensor import init_device_mesh +from torch.distributed._tensor import DTensor -# Utility -def modify_view( - gm: torch.fx.GraphModule, - tp: int -): - """ - Adjust dimension size of view ops to make them compatible with tensor - parallelism. For example, when TP is 4, we need to adjust `num_heads` from - 32 to 8. This is needed for attention layers. - """ - for node in gm.graph.nodes: - if node.op == "call_method" and ( - node.target == "view" or node.target == "reshape" - ): - assert len(node.args) >= 4 - node.update_arg(3, node.args[3] // tp) - gm.recompile() +# We set this flag to true to allow operations on a mix of tensor and dtensor +# arguments. 
From 37e110c0900914e767c730b9d83fc2b20182bf56 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 2 Jan 2024 15:02:18 -0800
Subject: [PATCH 5/6] Remove modify_view()

---
 examples/llama/2d_llama.py | 42 +++++++++++++++++++-----------------------
 pippy/PipelineStage.py     |  4 +++-
 2 files changed, 19 insertions(+), 27 deletions(-)

diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index a3e66e054..bc7e380ff 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -4,25 +4,12 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
 from torch.distributed._tensor import init_device_mesh
+from torch.distributed._tensor import DTensor
 
 
-# Utility
-def modify_view(
-    gm: torch.fx.GraphModule,
-    tp: int
-):
-    """
-    Adjust dimension size of view ops to make them compatible with tensor
-    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
-    32 to 8. This is needed for attention layers.
-    """
-    for node in gm.graph.nodes:
-        if node.op == "call_method" and (
-            node.target == "view" or node.target == "reshape"
-        ):
-            assert len(node.args) >= 4
-            node.update_arg(3, node.args[3] // tp)
-    gm.recompile()
+# We set this flag to true to allow operations on a mix of tensor and dtensor
+# arguments. The mix is a result of `use_local_output=False`
+DTensor._op_dispatcher._allow_implicit_replication = True
 
 
 # Grab the model
@@ -63,8 +50,6 @@ def modify_view(
 stage_idx = rank // tp_group_size
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
 
-modify_view(stage.submod, tp_group_size)
-
 # Tensor parallel
 from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
@@ -75,12 +60,13 @@ def modify_view(
     extra = "_mod" if starting_layer > 0 and i == 0 else ""
     layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
     attn_plan.update({
-        # Parallel self attention not working yet due to the dimension mismatch
-        # after TP in view operation
-        f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
-        f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
-        f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
-        f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+        # We set `use_local_output` to False to keep the output tensor in
+        # DTensor form, so that it works with the view/reshape operations
+        # without code change.
+        f"{layer_name}_self_attn_q_proj": ColwiseParallel(use_local_output=False),
+        f"{layer_name}_self_attn_k_proj": ColwiseParallel(use_local_output=False),
+        f"{layer_name}_self_attn_v_proj": ColwiseParallel(use_local_output=False),
+        f"{layer_name}_self_attn_o_proj": RowwiseParallel(use_local_output=False),
     })
     mlp_plan.update({
         f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
@@ -101,6 +87,10 @@ def modify_view(
 
 # Decode
 if output is not None:
-    next_token_logits = output[0][:, -1, :]
+    next_token_logits = output[0]
+    if isinstance(next_token_logits, DTensor):
+        # Convert DTensor back to regular tensor
+        next_token_logits = next_token_logits.to_local()
+    next_token_logits = next_token_logits[:, -1, :]
     next_token = torch.argmax(next_token_logits, dim=-1)
     print(tokenizer.batch_decode(next_token))
diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py
index fe39a15d6..767939113 100644
--- a/pippy/PipelineStage.py
+++ b/pippy/PipelineStage.py
@@ -476,7 +476,9 @@ def _send_activations(
             )
             peer_rank = self.stage_index_to_group_rank[dst]
             work = dist.isend(
-                out,
+                # HACK: we convert DTensor to regular tensor here for it to
+                # work with send ops. DTensor may show up in PP + TP cases.
+                out.to_local() if isinstance(out, torch.distributed._tensor.DTensor) else out,
                 peer_rank
                 if self.group is None
                 else dist.get_global_rank(self.group, peer_rank),
                 # TODO
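Patch 5 above leans on the distinction between a DTensor and its local shard: with use_local_output=False the parallelized projections hand back DTensors, and to_local() recovers the plain per-rank tensor that point-to-point ops like dist.isend() accept. A minimal sketch of that distinction, assuming it is launched with torchrun on 4 ranks; the file name is hypothetical and a CPU mesh is used here only so the sketch runs without GPUs:

# $ torchrun --nproc-per-node 4 dtensor_to_local_sketch.py
import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor, init_device_mesh

# A 1-D mesh over 4 ranks ("cpu" chosen only for illustration).
mesh = init_device_mesh("cpu", (4,))

# Column-shard a weight-like tensor, similar to what ColwiseParallel does to q/k/v projections.
full = torch.randn(8, 4096)
dt = distribute_tensor(full, mesh, [Shard(1)])
assert isinstance(dt, DTensor)

# to_local() recovers the plain per-rank shard that send/recv ops can handle;
# here each rank holds an (8, 1024) slice of the 4096-wide dimension.
local = dt.to_local()
print(type(dt).__name__, type(local).__name__, tuple(local.shape))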
From c0f6152d4cbd79946da974631600d0493f98c9d3 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 2 Jan 2024 15:22:27 -0800
Subject: [PATCH 6/6] Rearrange code

---
 examples/llama/2d_llama.py | 17 ++++++++---------
 pippy/PipelineStage.py     |  4 +++-
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index bc7e380ff..f967e6d2b 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -3,8 +3,8 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
-from torch.distributed._tensor import init_device_mesh
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import init_device_mesh, DTensor
+from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 
 
 # We set this flag to true to allow operations on a mix of tensor and dtensor
@@ -15,28 +15,27 @@
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
-    torch_dtype=torch.float16
 )
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+llama.eval()
 
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 prompts = (
     "How do you", "I like to", "Can I help", "You need to",
     "The weather is", "I found a", "What is your", "You are so",
 ) # bs = 8
 tokenizer.pad_token = tokenizer.eos_token
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
 rank = int(os.environ["RANK"])
 world_size = int(os.environ["WORLD_SIZE"])
 device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
 
+# Initialize 2D device mesh
 pp_group_size = 2
 tp_group_size = 4
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()
 
-llama.eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True)
-
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
 for i in range(1, pp_group_size):
@@ -48,7 +47,6 @@
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
 
 # Tensor parallel
-from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
 attn_plan = {}
 mlp_plan = {}
@@ -77,8 +75,9 @@
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-inputs = inputs.to(device)
+
 # Run
+inputs = inputs.to(device)
 if stage_idx == 0:
     args = inputs["input_ids"]
 else:
     args = None
diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py
index 767939113..3efca131d 100644
--- a/pippy/PipelineStage.py
+++ b/pippy/PipelineStage.py
@@ -478,7 +478,9 @@ def _send_activations(
             work = dist.isend(
                 # HACK: we convert DTensor to regular tensor here for it to
                 # work with send ops. DTensor may show up in PP + TP cases.
-                out.to_local() if isinstance(out, torch.distributed._tensor.DTensor) else out,
+                out.to_local()
+                if isinstance(out, torch.distributed._tensor.DTensor)
+                else out,
                 peer_rank
                 if self.group is None
                 else dist.get_global_rank(self.group, peer_rank),
                 # TODO