Simplify Tensor Parallel implementation with PyTorch TP #34184

Merged: 26 commits, merged Nov 18, 2024.

Commits (26):
e60fb87: Simplify Tensor Parallel implementation with PyTorch TP (kwen2501, Oct 15, 2024)
fd7f7c7: Move tp_plan to config (kwen2501, Oct 23, 2024)
9224cab: Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Oct 30, 2024)
79cc524: Lint (kwen2501, Oct 30, 2024)
a2934b3: Format and warning (kwen2501, Oct 30, 2024)
a8fc418: Disable copy-from check (kwen2501, Oct 30, 2024)
e84a388: Conditionally get attr from config (kwen2501, Oct 31, 2024)
396d158: make fix-copies (kwen2501, Oct 31, 2024)
7b346b5: Move base_model_tp_plan to PretrainedConfig (kwen2501, Oct 31, 2024)
d60679b: Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Oct 31, 2024)
dda058a: Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 1, 2024)
12fbbe7: Move TP into from_pretrained (kwen2501, Nov 7, 2024)
02c8c39: Add device context for load (kwen2501, Nov 7, 2024)
073c521: Do not serialize (kwen2501, Nov 7, 2024)
db6e5ee: Move _tp_plan setting to post_init (kwen2501, Nov 7, 2024)
5bb294e: Add has_tp_plan (kwen2501, Nov 14, 2024)
290a7f1: Add test_tp (kwen2501, Nov 15, 2024)
bd2e89c: Add 'Multi-gpu inference' doc (kwen2501, Nov 15, 2024)
4892cef: Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 15, 2024)
9648f31: Add backward support for device type identification (kwen2501, Nov 15, 2024)
93ba283: Auto-detect accelerator (kwen2501, Nov 16, 2024)
73524c9: supports_tp_plan (kwen2501, Nov 16, 2024)
f312e55: copyright year (kwen2501, Nov 16, 2024)
ca93bdb: Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 17, 2024)
dc2672f: Merge branch 'main' into tp_llama (kwen2501, Nov 18, 2024)
1e27d6f: Fix copy (kwen2501, Nov 18, 2024)
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -218,6 +218,8 @@
title: CPU inference
- local: perf_infer_gpu_one
title: GPU inference
- local: perf_infer_gpu_multi
title: Multi-GPU inference
title: Optimizing inference
- local: big_models
title: Instantiate a big model
68 changes: 68 additions & 0 deletions docs/source/en/perf_infer_gpu_multi.md
@@ -0,0 +1,68 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Multi-GPU inference

Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model onto multiple GPUs, which makes it possible to fit larger models and parallelizes computations such as matrix multiplication.

To enable tensor parallelism, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Initialize distributed
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
model_id,
tp_plan="auto",
)

# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Distributed run
outputs = model(inputs)
```

You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:

```bash
torchrun --nproc-per-node 4 demo.py
```

PyTorch tensor parallelism is currently supported for the following models:
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)

You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request.

### Expected speedups

You can benefit from considerable inference speedups, especially for inputs with large batch sizes or long sequences.

For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct, seqlen = 512, python, w_ compile.png">
</div>
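An aside on the usage example in the new doc above (not part of the PR's documentation): the script stops at a raw forward pass. A minimal sketch of turning that pass into a greedy next-token prediction, assuming the same `model`, `tokenizer`, and `inputs` as above, could look like this:

```python
# Sketch only: greedy next-token prediction on top of the doc's forward pass.
# Assumes `model`, `tokenizer`, and `inputs` from the example above.
with torch.no_grad():
    outputs = model(inputs)

# logits have shape (batch, seq_len, vocab_size); take the last position.
next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_token_id[0]))
```

Every rank runs the same script under `torchrun`, so each process should print the same prediction; the tensor parallel sharding is transparent at this level.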
2 changes: 1 addition & 1 deletion docs/source/en/performance.md
@@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se

* [Inference on a single CPU](perf_infer_cpu)
* [Inference on a single GPU](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_multi)
* [XLA Integration for TensorFlow Models](tf_xla)


9 changes: 9 additions & 0 deletions src/transformers/configuration_utils.py
@@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
outputs of the model during inference.
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-module FQNs of a base model to a tensor
parallel plan applied to the sub-module when `model.tensor_parallel` is called.

Common attributes (present in all subclasses):

@@ -194,6 +196,7 @@ class PretrainedConfig(PushToHubMixin):
sub_configs: Dict[str, "PretrainedConfig"] = {}
is_composition: bool = False
attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None
_auto_class: Optional[str] = None

def __setattr__(self, key, value):
@@ -848,6 +851,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]

return serializable_config_dict

@@ -867,6 +873,9 @@ def to_dict(self) -> Dict[str, Any]:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
142 changes: 116 additions & 26 deletions src/transformers/modeling_utils.py
@@ -55,6 +55,7 @@
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False

# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -1370,6 +1377,9 @@ def post_init(self):
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
[Inline review thread on `if self.base_model is self:`]

Collaborator (suggested change): use `if self._tp_plan is None:` instead of `if self.base_model is self:`. "Feels simpler."

kwen2501 (author): Oh, the reason for `self.base_model is self` is that we are attaching the `base_model_tp_plan` to the base model only. If we attach it to `LlamaForCausalLM`, the FQNs won't match, because the `base_model_tp_plan` FQNs start with "layers", not "model.layers".

Collaborator: Yeah, but `self._tp_plan` should be None only for the base model, no? We could maybe add something like `if not self.base_model is self and self._tp_plan is None and self.supports_tp_plan`, raise an error in the future, to force people to add the TP plan.

kwen2501 (author): That's not always the case. For example, `LlamaForSequenceClassification` and `LlamaForQuestionAnswering` have `_tp_plan=None` (at their top level), while `LlamaForCausalLM` has a `_tp_plan = {"lm_head": ...}`.

Collaborator: Yep, indeed. I thought we should enforce TP plan definition for all classes to avoid user errors, but it's fine like this!

self._tp_plan = self.config.base_model_tp_plan
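To make the FQN point from the thread above concrete, here is a toy sketch (not code from the PR; `Base` and `Wrapper` are hypothetical stand-ins for a base model and its task head):

```python
# Toy sketch of why the plan is attached to the base model only: the same
# module has a different FQN depending on where you start walking the tree.
import torch.nn as nn

class Base(nn.Module):          # stand-in for e.g. LlamaModel
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4)])

class Wrapper(nn.Module):       # stand-in for e.g. LlamaForCausalLM
    def __init__(self):
        super().__init__()
        self.model = Base()
        self.lm_head = nn.Linear(4, 4)

w = Wrapper()
print([name for name, _ in w.named_modules()])        # includes 'model.layers.0'
print([name for name, _ in w.model.named_modules()])  # includes 'layers.0'
```

Since `base_model_tp_plan` keys use the prefix-free form, attaching the plan to the wrapper would leave every key unmatched.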

def dequantize(self):
"""
@@ -3399,6 +3409,11 @@ def from_pretrained(
# Cache path to the GGUF file
gguf_path = None

tp_plan = kwargs.pop("tp_plan", None)
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

if is_fsdp_enabled():
low_cpu_mem_usage = True

@@ -4000,6 +4015,7 @@ def from_pretrained(

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
@@ -4012,6 +4028,16 @@
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
elif tp_plan is not None:
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
init_contexts.append(tp_device)
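A small worked example of the device selection above, under the stated assumption of an equal number of accelerators per host:

```python
# Illustrative only: with 4 GPUs per host, device_count() == 4, so global
# rank 5 lands on index 5 % 4 == 1, i.e. cuda:1 of its host.
rank, devices_per_host = 5, 4
print(f"cuda:{rank % devices_per_host}")  # cuda:1
```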

if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
@@ -4145,32 +4171,38 @@ def from_pretrained(
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)
load_contexts = []
# Make sure we load onto targeted device
if tp_device is not None:
load_contexts.append(tp_device)

with ContextManagers(load_contexts):
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4254,6 +4286,16 @@ def from_pretrained(
}
return model, loading_info

if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
if not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# Assume the model is sharded across the entire world (all ranks)
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)

return model
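For readers who want to drive this path by hand instead of passing `tp_plan="auto"`, a rough sketch of the equivalent manual calls follows; it is an approximation under the assumption that `torch.distributed` is already initialized, and it skips the targeted-device loading that the `from_pretrained` path above performs.

```python
import torch
from transformers import AutoModelForCausalLM

# Rough manual equivalent of the tp_plan="auto" branch above (sketch only).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
if model.supports_tp_plan:
    world_size = torch.distributed.get_world_size()
    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
    model.tensor_parallel(device_mesh)
```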

@classmethod
@@ -4943,6 +4985,54 @@ def _is_quantized_training_enabled(self):

return self.hf_quantizer.is_trainable

@property
def supports_tp_plan(self):
"""
Returns whether the model has a tensor parallelism plan.
"""
if self._tp_plan is not None:
return True
# Check if base model has a TP plan
if getattr(self.base_model, "_tp_plan", None) is not None:
return True
return False

def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.

Args:
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
# In model configs, we use a neutral type (string) to specify
# parallel styles, here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
tp_plan,
)
# Apply TP to current module.
torch.distributed.tensor.parallel.parallelize_module(
mod,
device_mesh=device_mesh,
parallelize_plan=tp_plan,
)

# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
self.apply(tplize)
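The `translate_to_torch_parallel_style` helper imported at the top of this diff is not shown here; the sketch below is a plausible shape for such a translation (an assumption for illustration, including the hypothetical `_sketch` suffix, not the PR's actual implementation):

```python
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def translate_to_torch_parallel_style_sketch(style: str):
    # Map the neutral string styles used in TP plans to torch TP style objects.
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    if style == "colwise_rep":
        # Column-wise sharding with the output replicated on every rank, as
        # used for `lm_head` in the model diffs below.
        return ColwiseParallel(output_layouts=Replicate())
    raise ValueError(f"Unsupported parallel style: {style}")
```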

@property
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

4 changes: 4 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -720,7 +720,10 @@ def __init__(self, config: GemmaConfig):
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
@@ -982,6 +985,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
4 changes: 4 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -740,7 +740,10 @@ def __init__(self, config: Gemma2Config):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
@@ -961,6 +964,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
3 changes: 3 additions & 0 deletions src/transformers/models/glm/modeling_glm.py
@@ -708,6 +708,8 @@ def __init__(self, config: GlmConfig):
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
@@ -967,6 +969,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config: GlmConfig):
super().__init__(config)