Simplify Tensor Parallel implementation with PyTorch TP #34184

Merged: 26 commits, merged on Nov 18, 2024.
Changes shown below are from 15 of the 26 commits.

Commits:
e60fb87  Simplify Tensor Parallel implementation with PyTorch TP (kwen2501, Oct 15, 2024)
fd7f7c7  Move tp_plan to config (kwen2501, Oct 23, 2024)
9224cab  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Oct 30, 2024)
79cc524  Lint (kwen2501, Oct 30, 2024)
a2934b3  Format and warning (kwen2501, Oct 30, 2024)
a8fc418  Disable copy-from check (kwen2501, Oct 30, 2024)
e84a388  Conditionally get attr from config (kwen2501, Oct 31, 2024)
396d158  make fix-copies (kwen2501, Oct 31, 2024)
7b346b5  Move base_model_tp_plan to PretrainedConfig (kwen2501, Oct 31, 2024)
d60679b  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Oct 31, 2024)
dda058a  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 1, 2024)
12fbbe7  Move TP into from_pretrained (kwen2501, Nov 7, 2024)
02c8c39  Add device context for load (kwen2501, Nov 7, 2024)
073c521  Do not serialize (kwen2501, Nov 7, 2024)
db6e5ee  Move _tp_plan setting to post_init (kwen2501, Nov 7, 2024)
5bb294e  Add has_tp_plan (kwen2501, Nov 14, 2024)
290a7f1  Add test_tp (kwen2501, Nov 15, 2024)
bd2e89c  Add 'Multi-gpu inference' doc (kwen2501, Nov 15, 2024)
4892cef  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 15, 2024)
9648f31  Add backward support for device type identification (kwen2501, Nov 15, 2024)
93ba283  Auto-detect accelerator (kwen2501, Nov 16, 2024)
73524c9  supports_tp_plan (kwen2501, Nov 16, 2024)
f312e55  copyright year (kwen2501, Nov 16, 2024)
ca93bdb  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 17, 2024)
dc2672f  Merge branch 'main' into tp_llama (kwen2501, Nov 18, 2024)
1e27d6f  Fix copy (kwen2501, Nov 18, 2024)
9 changes: 9 additions & 0 deletions src/transformers/configuration_utils.py
@@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
outputs of the model during inference.
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-module FQNs of a base model to a tensor
parallel plan applied to that sub-module when `model.tensor_parallel` is called.

Common attributes (present in all subclasses):

@@ -192,6 +194,7 @@ class PretrainedConfig(PushToHubMixin):
model_type: str = ""
is_composition: bool = False
attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None
_auto_class: Optional[str] = None

def __setattr__(self, key, value):
@@ -835,6 +838,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]

return serializable_config_dict

@@ -854,6 +860,9 @@ def to_dict(self) -> Dict[str, Any]:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
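As a quick illustration of the config-side change above, a minimal sketch (assuming an installation that includes this PR; `LlamaConfig` is used only because it ships a default plan, see the llama diff further down): the plan is readable as a class attribute but is stripped from the serialized dict.

```python
# Sketch only: `base_model_tp_plan` is a class-level attribute on configs that
# define one (e.g. LlamaConfig), and `to_dict`/`to_diff_dict` strip it so it is
# not written into config.json for now.
from transformers import LlamaConfig

config = LlamaConfig()
print(config.base_model_tp_plan)  # {"layers.*.self_attn.q_proj": "colwise", ...}

assert "base_model_tp_plan" not in config.to_dict()
assert "base_model_tp_plan" not in config.to_diff_dict()
```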
129 changes: 103 additions & 26 deletions src/transformers/modeling_utils.py
@@ -55,6 +55,7 @@
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -1398,6 +1399,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False

# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -1442,6 +1449,9 @@ def post_init(self):
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
[Review thread on the line above]

Reviewer (Collaborator) suggested change: replace `if self.base_model is self:` with `if self._tp_plan is None:` ("feels simpler").

kwen2501 (author), Nov 14, 2024:
Oh, the reason for `self.base_model is self` is that we are attaching the `base_model_tp_plan` to the base model only. If we attach it to `LlamaForCausalLM`, the FQNs won't match, because the `base_model_tp_plan` FQNs start with "layers", not "model.layers".

Reviewer (Collaborator):
Yeah, but `self._tp_plan` should be None only for the base model, no?
We could maybe add something like: if `self.base_model is not self` and `self._tp_plan is None` and `self.supports_tp_plan`, raise an error in the future, to force people to add the TP plan.

kwen2501 (author), Nov 16, 2024:
> but self._tp_plan should be None only for the base model no?

That's not always the case. For example, `LlamaForSequenceClassification` and `LlamaForQuestionAnswering` have `_tp_plan=None` (at their top level), while `LlamaForCausalLM` has a `_tp_plan = {"lm_head": ...}`.

Reviewer (Collaborator):
Yep, indeed. I thought we should enforce TP plan definition for all classes to avoid user errors, but it's fine like this!

[End of review thread]
self._tp_plan = self.config.base_model_tp_plan
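To make the FQN point from the review thread concrete, a small sketch (hypothetical tiny config values, purely illustrative): the same projection module has a different fully qualified name depending on whether it is looked up from the wrapper or from the base model, which is why the plan is attached to the base model here.

```python
# Illustration of the FQN mismatch discussed in the review thread (not part of
# the diff). Plan keys such as "layers.*.self_attn.q_proj" are relative to the
# base LlamaModel; seen from LlamaForCausalLM the same module is prefixed with
# "model.", so the keys would not match if the plan were attached to the wrapper.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(num_hidden_layers=1, hidden_size=64,
                     intermediate_size=128, num_attention_heads=4)
lm = LlamaForCausalLM(config)

assert "model.layers.0.self_attn.q_proj" in dict(lm.named_modules())
assert "layers.0.self_attn.q_proj" in dict(lm.model.named_modules())
```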

def dequantize(self):
"""
@@ -3472,6 +3482,11 @@ def from_pretrained(
# Cache path to the GGUF file
gguf_path = None

tp_plan = kwargs.pop("tp_plan", None)
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

if is_fsdp_enabled():
low_cpu_mem_usage = True

@@ -4073,6 +4088,7 @@ def from_pretrained(

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
@@ -4085,6 +4101,17 @@
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
elif tp_plan is not None:
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

# Get device type (e.g. "cuda")
device_type = torch.distributed.distributed_c10d._device_capability()[0]
# Get torch device module (e.g. torch.cuda) based on device type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
init_contexts.append(tp_device)

config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False):
@@ -4215,32 +4242,38 @@
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)
load_contexts = []
# Make sure we load onto targeted device
if tp_device is not None:
load_contexts.append(tp_device)

with ContextManagers(load_contexts):
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4324,6 +4357,14 @@ def from_pretrained(
}
return model, loading_info

if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
# Assume the model is sharded across the whole world (all ranks)
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)

return model
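The `model.tensor_parallel(device_mesh)` call used here is a public method (defined further down in this file), so the same sharding can also be applied manually to an already-loaded model. A sketch, assuming torch.distributed is initialized and each rank can hold the unsharded weights while loading; the backend and checkpoint id are placeholders:

```python
# Manual sketch: build a 1-D device mesh over all ranks and apply the model's
# TP plan directly, mirroring what from_pretrained(tp_plan="auto") does above.
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

dist.init_process_group("nccl")  # placeholder backend; pick one for your hardware
device_mesh = torch.distributed.init_device_mesh("cuda", (dist.get_world_size(),))

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder id
model.tensor_parallel(device_mesh)  # applies config.base_model_tp_plan and the lm_head plan
```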

@classmethod
@@ -5013,6 +5054,42 @@ def _is_quantized_training_enabled(self):

return self.hf_quantizer.is_trainable

def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.

Args:
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
# In model configs, we use a neutral type (string) to specify
# parallel styles; here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
tp_plan,
)
# Apply TP to current module.
torch.distributed.tensor.parallel.parallelize_module(
mod,
device_mesh=device_mesh,
parallelize_plan=tp_plan,
)

# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
self.apply(tplize)
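The `translate_to_torch_parallel_style` helper imported at the top of this file is not shown in the diff; the sketch below is only a guess at the kind of mapping it performs (config plan strings to `torch.distributed.tensor.parallel` styles, assuming a recent torch with the public `torch.distributed.tensor` namespace) and is not the PR's actual implementation.

```python
# Hypothetical sketch of the string -> torch TP style translation; the real
# helper is imported alongside the pruning utilities above and may differ.
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def translate_to_torch_parallel_style_sketch(style: str):
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    if style == "colwise_rep":
        # column-sharded weight, replicated output (used for `lm_head` in this PR)
        return ColwiseParallel(output_layouts=Replicate())
    raise ValueError(f"Unsupported parallel style: {style}")
```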

@property
@lru_cache
def loss_function(self):
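Putting the `from_pretrained` changes in this file together, an end-to-end usage sketch; the launcher flags, checkpoint id, dtype, and backend are placeholders, and only `tp_plan="auto"` plus the requirement of an initialized process group are taken from the diff itself.

```python
# Usage sketch (not from the PR): launch with a distributed launcher, e.g.
#   torchrun --nproc-per-node 4 run_tp.py
# `tp_plan="auto"` requires torch.distributed to be initialized; from_pretrained
# then loads onto each rank's device and calls model.tensor_parallel() internally.
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder checkpoint

dist.init_process_group("nccl")  # assumes NVIDIA GPUs; pick the backend for your hardware
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallel inference test:", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)  # each rank holds a shard; the forward runs collectively
print(outputs.logits.shape)
```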
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

4 changes: 4 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -720,7 +720,10 @@ def __init__(self, config: GemmaConfig):
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
@@ -982,6 +985,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
4 changes: 4 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -740,7 +740,10 @@ def __init__(self, config: Gemma2Config):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
@@ -961,6 +964,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
3 changes: 3 additions & 0 deletions src/transformers/models/glm/modeling_glm.py
@@ -708,6 +708,8 @@ def __init__(self, config: GlmConfig):
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
@@ -967,6 +969,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config: GlmConfig):
super().__init__(config)
10 changes: 10 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):

model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `LlamaModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
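The `*` in these plan keys stands for the decoder-layer index, so a single entry covers every layer of the base `LlamaModel`. Below is only an approximation of which module FQNs each pattern is meant to cover; fnmatch is used purely for illustration and is not the mechanism the library or torch uses when the plan is applied.

```python
# Approximate illustration of the wildcard keys above (assumption: '*' expands
# over the layer index). The actual matching happens when the plan is applied
# by model.tensor_parallel(), not via fnmatch.
import fnmatch

plan_keys = ["layers.*.self_attn.q_proj", "layers.*.mlp.down_proj"]
module_fqns = [
    "layers.0.self_attn.q_proj",
    "layers.1.self_attn.q_proj",
    "layers.0.mlp.down_proj",
    "norm",          # not in the plan -> left as-is
    "embed_tokens",  # not in the plan -> left as-is
]
for fqn in module_fqns:
    matched = any(fnmatch.fnmatch(fqn, key) for key in plan_keys)
    print(f"{fqn}: {'sharded per plan' if matched else 'not covered by the plan'}")
```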